Source code for mbrs.selectors.nbest

 1from __future__ import annotations
 2
 3from typing import Optional
 4
 5from torch import Tensor
 6
 7from mbrs.selectors import Selector, register
 8
 9
@register("nbest")
class SelectorNbest(Selector):
    """Selector that returns the top-``nbest`` hypotheses by expected score."""

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> SelectorNbest.Output:
        """Pick the ``nbest`` best-scoring hypotheses.

        Args:
            hypotheses (list[str]): Candidate hypotheses.
            expected_scores (Tensor): The expected score of each hypothesis.
                The indices chosen by ``self.topk`` are used to index
                ``hypotheses``, so this is presumably aligned with
                ``hypotheses`` element-for-element — TODO confirm with callers.
            nbest (int): Number of hypotheses to return; clamped to
                ``len(hypotheses)``.
            source (str, optional): A source. Not read by this selector.
            maximize (bool): Whether to maximize the scores; forwarded to
                ``self.topk``.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            SelectorNbest.Output: ``idx`` holds the selected indices,
            ``sentence`` the corresponding hypotheses, and ``score`` the
            corresponding scores, all in the order produced by ``self.topk``.
        """
        # Never request more candidates than there are hypotheses.
        num_selected = min(len(hypotheses), nbest)
        best_scores, best_indices = self.topk(
            expected_scores, k=num_selected, maximize=maximize
        )
        # Map each selected index back to its hypothesis string.
        best_sentences = [hypotheses[i] for i in best_indices]
        return self.Output(
            idx=best_indices,
            sentence=best_sentences,
            score=best_scores,
        )