Source code for mbrs.functional
1from typing import Optional
2
3import torch
4from torch import Tensor
5
6
[docs]
def expectation(matrix: Tensor, lprobs: Optional[Tensor] = None) -> Tensor:
    """Compute expectation values for each row.

    Without ``lprobs`` every column is weighted uniformly, i.e. the row mean
    is returned. With ``lprobs`` the columns are weighted by
    ``softmax(lprobs)``, so the log-weights do not need to be normalized.

    Args:
        matrix (Tensor): Input matrix of shape `(H, R)`.
        lprobs (Tensor, optional): Log-probabilities (or unnormalized
          log-weights) for each column of shape `(R,)`.

    Returns:
        Tensor: Expected values for each row of shape `(H,)`, cast back to
          the dtype and device of ``matrix``.

    Raises:
        ValueError: If ``lprobs.shape`` does not match ``matrix.shape[1:]``.
    """
    if lprobs is None:
        return matrix.mean(dim=-1)

    if list(lprobs.shape) != list(matrix.shape)[1:]:
        raise ValueError(
            f"`lprobs` must have {list(matrix.shape)[1:]} elements, "
            f"but got {list(lprobs.shape)}"
        )

    # Accumulate in at least float32: half-precision inputs are upcast, while
    # float64 inputs keep their full precision instead of being truncated.
    compute_dtype = torch.promote_types(matrix.dtype, torch.float32)
    # Only move ``lprobs`` to the matrix's device. Casting it to the matrix's
    # dtype first (e.g. float16 or an integer type) would round/truncate the
    # log-probabilities before the softmax and distort the weights.
    weights = lprobs.to(device=matrix.device).softmax(dim=-1, dtype=compute_dtype)
    # ``weights`` has shape ``matrix.shape[1:]`` and broadcasts over the
    # leading dimension of ``matrix``.
    expected = (matrix.to(compute_dtype) * weights).sum(dim=-1)
    return expected.to(matrix)