Source code for mbrs.selectors.nbest
1from __future__ import annotations
2
3from typing import Optional
4
5from torch import Tensor
6
7from mbrs.selectors import Selector, register
8
9
[docs]
@register("nbest")
class SelectorNbest(Selector):
    """Selector that returns the ``nbest`` hypotheses ranked by expected score.

    Registered under the name ``"nbest"``.
    """

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> SelectorNbest.Output:
        """Select the n-best hypotheses by their expected scores.

        Args:
            hypotheses (list[str]): Candidate hypotheses.
            expected_scores (Tensor): The expected score of each hypothesis;
                presumably aligned index-for-index with ``hypotheses``
                (TODO confirm with callers).
            nbest (int): Number of hypotheses to return. Values larger than
                ``len(hypotheses)`` are clamped to ``len(hypotheses)``.
            source (str, optional): A source. Not read by this selector.
            maximize (bool): Whether to maximize the scores or not; forwarded
                to :meth:`topk`.
            **kwargs: Not read by this selector.

        Returns:
            SelectorNbest.Output: The selected hypotheses: their indices
            (``idx``), texts (``sentence``) and scores (``score``), in the
            order produced by :meth:`topk`.
        """
        # Never request more candidates than there are hypotheses.
        nbest = min(len(hypotheses), nbest)
        topk_scores, topk_indices = self.topk(
            expected_scores, k=nbest, maximize=maximize
        )
        return self.Output(
            idx=topk_indices,
            # Map each selected index back to its hypothesis text.
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
        )