Source code for mbrs.functional

 1from typing import Optional
 2
 3import torch
 4from torch import Tensor
 5
 6
def expectation(matrix: Tensor, lprobs: Optional[Tensor] = None) -> Tensor:
    """Compute expectation values for each row.

    Without ``lprobs`` every column is weighted uniformly, i.e. the result is
    the row-wise mean. With ``lprobs`` the columns are weighted by
    ``softmax(lprobs)``, so the log-probabilities do not need to be
    normalized beforehand.

    Args:
        matrix (Tensor): Input matrix of shape `(H, R)`.
        lprobs (Tensor, optional): Log-probabilities (or unnormalized
            log-weights) for each column of shape `(R,)`.

    Returns:
        Tensor: Expected values for each row of shape `(H,)`. In the weighted
        case the result is cast back to ``matrix``'s dtype.

    Raises:
        ValueError: If the shape of ``lprobs`` does not match the trailing
            (column) dimensions of ``matrix``.
    """
    if lprobs is None:
        return matrix.mean(dim=-1)

    if lprobs.shape != matrix.shape[1:]:
        raise ValueError(
            f"`lprobs` must have shape {list(matrix.shape)[1:]}, "
            f"but got {list(lprobs.shape)}"
        )

    # Accumulate in at least float32, but keep float64 inputs in float64.
    compute_dtype = torch.promote_types(matrix.dtype, torch.float32)
    # Only move lprobs to matrix's device. Casting it to matrix's dtype first
    # (e.g. half/bfloat16 or an integer type) would round or truncate the
    # log-probabilities before the softmax and distort the weights.
    weights = lprobs.to(matrix.device).softmax(dim=-1, dtype=compute_dtype)
    expected = (matrix.to(compute_dtype) * weights[None, :]).sum(dim=-1)
    return expected.to(matrix.dtype)