Source code for mbrs.selectors.base
1from __future__ import annotations
2
3import abc
4from dataclasses import dataclass
5from typing import Optional
6
7import torch
8from torch import Tensor
9
10from mbrs import registry
11
12
[docs]
class Selector(abc.ABC):
    """Abstract base class for hypothesis selectors.

    A selector takes candidate hypotheses and their expected scores and
    decides which of them make up the final n-best output. Concrete
    subclasses implement :meth:`select`. The helpers :meth:`topk`,
    :meth:`argbest` and :meth:`superior` all accept a ``maximize`` flag, so
    the same code path can handle "higher is better" and "lower is better"
    scores.
    """

    def __init__(self, cfg: Selector.Config) -> None:
        # Store the configuration object on the instance for later use.
        self.cfg = cfg

    @dataclass
    class Config:
        """Configuration for the selector (the base version has no fields)."""

    @dataclass
    class Output:
        """Result returned by :meth:`Selector.select`.

        Attributes:
            idx (list[int]): Index numbers of the n-best hypotheses.
            sentence (list[str]): Sentences of the n-best hypotheses.
            score (list[float]): Scores of the n-best hypotheses.
        """

        idx: list[int]
        sentence: list[str]
        score: list[float]

    def topk(
        self, x: Tensor, k: int = 1, maximize: bool = True
    ) -> tuple[list[float], list[int]]:
        """Return the ``k`` best elements of ``x`` together with their indices.

        ``k`` is clamped to ``len(x)``. Requesting more elements than ``x``
        holds therefore returns all of them instead of raising.

        Args:
            x (Tensor): Input 1-D array.
            k (int): Number of elements to return.
            maximize (bool): If True, larger values are better; otherwise
                smaller values are better.

        Returns:
            tuple[list[float], list[int]]
              - list[float]: The top-k values.
              - list[int]: The top-k indices into ``x``.
        """
        num_best = min(k, len(x))
        best = torch.topk(x, k=num_best, largest=maximize)
        return best.values.tolist(), best.indices.tolist()

    def argbest(self, x: Tensor, maximize: bool = True) -> Tensor:
        """Return the position of the best element of ``x``.

        Args:
            x (Tensor): Input 1-D array.
            maximize (bool): If True, pick the largest value; otherwise pick
                the smallest.

        Returns:
            Tensor: A scalar tensor of the best index.
        """
        pick = torch.argmax if maximize else torch.argmin
        return pick(x)

    def superior(self, a: float, b: float, maximize: bool = True) -> bool:
        """Tell whether score ``a`` strictly beats score ``b``.

        Args:
            a (float): Candidate score.
            b (float): Score to compare against.
            maximize (bool): If True, larger scores win; otherwise smaller
                scores win.

        Returns:
            bool: True when ``a`` is strictly better than ``b``; ties give
            False.
        """
        return (a > b) if maximize else (a < b)

    @abc.abstractmethod
    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> Selector.Output:
        """Choose the final output list from the given hypotheses.

        Subclasses must override this method.

        Args:
            hypotheses (list[str]): Hypotheses.
            expected_scores (Tensor): The expected scores for each hypothesis.
            nbest (int): Return the n-best hypotheses based on the selection rule.
            source (str, optional): A source.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            Selector.Output: Selected hypotheses.
        """
104
105
# Module-level registry bound to the Selector base class. Presumably
# `register` adds Selector subclasses to it and `get_selector` looks them up
# by name — NOTE(review): confirm against mbrs.registry.Registry.get_closure.
register, get_selector = registry.Registry(Selector).get_closure()