from __future__ import annotations

import abc
from dataclasses import dataclass, fields, is_dataclass, make_dataclass
from typing import Any, Optional

from torch import Tensor

from mbrs import registry
from mbrs.metrics.base import Metric, MetricBase, MetricReferenceless
from mbrs.selectors import SELECTOR_NBEST, Selector


class DecoderBase(abc.ABC):
    """Decoder base class.

    Holds the decoder configuration, the metric used to score hypotheses, and
    a selector that turns scores into the final n-best list. The optimization
    direction (maximize vs. minimize) is derived from the metric.
    """

    def __init__(
        self,
        cfg: DecoderBase.Config,
        metric: MetricBase,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Initialize the decoder.

        Args:
            cfg (DecoderBase.Config): Decoder configuration.
            metric (MetricBase): Metric used to score hypotheses. Its
              `HIGHER_IS_BETTER` attribute decides `self.maximize`.
            selector (Selector): Selection rule that `topk`, `argbest`,
              `superior`, and `select` delegate to.
        """
        self.cfg = cfg
        self.metric = metric
        self.selector = selector

    @property
    def maximize(self) -> bool:
        """Return `True` when maximizing the objective score."""
        return self.metric.HIGHER_IS_BETTER

    @dataclass
    class Config:
        """Configuration for the decoder."""

    @dataclass
    class Output:
        """Output of a decoder.

        - idx (list[int]): Index numbers of the n-best hypotheses.
        - sentence (list[str]): Sentences of the n-best hypotheses.
        - score (list[float]): Scores of the n-best hypotheses.
        """

        idx: list[int]
        sentence: list[str]
        score: list[float]

        def __or__(self, other: Any) -> DecoderBase.Output:
            """Returns the union of dataclasses.

            The result is an instance of a newly created dataclass (named
            ``Output``) that subclasses ``type(self)``. Its fields are the
            fields of ``self``, followed by the fields of ``other`` that
            ``self`` does not define. For a name defined by both, the field
            definition and the value of ``self`` win.

            Args:
                other (Any): An other dataclass instance.

            Returns:
                Output: New dataclass with the merged attributes of `self`
                  and `other`, or `NotImplemented` when `other` is not a
                  dataclass instance. In that case Python falls back to
                  `other.__ror__` and otherwise raises `TypeError`.
            """
            # `is_dataclass` is also true for dataclass *classes*, whose
            # fields have no per-instance values, so reject those too.
            if not is_dataclass(other) or isinstance(other, type):
                return NotImplemented

            self_fields = fields(self)
            # Build the name set once instead of once per field of `other`.
            self_names = {f.name for f in self_fields}
            extra_fields = [f for f in fields(other) if f.name not in self_names]

            new_dc_type = make_dataclass(
                "Output",
                fields=[(f.name, f.type, f) for f in (*self_fields, *extra_fields)],
                bases=(type(self),),
            )
            # Values from `self` override those of `other` for shared names.
            attrs = {f.name: getattr(other, f.name) for f in extra_fields}
            attrs.update({f.name: getattr(self, f.name) for f in self_fields})
            return new_dc_type(**attrs)

    def topk(self, x: Tensor, k: int = 1) -> tuple[list[float], list[int]]:
        """Return the top-k best elements and corresponding indices.

        "Best" follows `self.maximize`. The work is delegated to
        `self.selector.topk`.

        Args:
            x (Tensor): Input 1-D array.
            k (int): Return the top-k values and indices.

        Returns:
            tuple[list[float], list[int]]
              - list[float]: The top-k values.
              - list[int]: The top-k indices.
        """
        return self.selector.topk(x, k=k, maximize=self.maximize)

    def argbest(self, x: Tensor) -> Tensor:
        """Return the index of the best element.

        "Best" follows `self.maximize`. The work is delegated to
        `self.selector.argbest`.

        Args:
            x (Tensor): Input 1-D array.

        Returns:
            Tensor: A scalar tensor of the best index.
        """
        return self.selector.argbest(x, maximize=self.maximize)

    def superior(self, a: float, b: float) -> bool:
        """Return whether the score `a` is superior to the score `b`.

        The comparison direction follows `self.maximize`. The work is
        delegated to `self.selector.superior`.

        Args:
            a (float): A score.
            b (float): A score.

        Returns:
            bool: Return True when `a` is superior to `b`.
        """
        return self.selector.superior(a, b, maximize=self.maximize)

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        **kwargs,
    ) -> Selector.Output:
        """Select the final output list.

        The optimization direction is not a parameter. It is always taken from
        `self.maximize`, i.e. the metric's `HIGHER_IS_BETTER`.

        Args:
            hypotheses (list[str]): Hypotheses.
            expected_scores (Tensor): The expected scores for each hypothesis.
            nbest (int): Return the n-best hypotheses based on the selection rule.
            source (str, optional): A source.
            **kwargs: Extra keyword arguments, forwarded unchanged to
              `self.selector.select`. Do not pass `maximize`: it is already
              supplied, so passing it raises `TypeError` for a duplicate
              keyword argument.

        Returns:
            Selector.Output: Outputs.
        """
        return self.selector.select(
            hypotheses,
            expected_scores,
            nbest=nbest,
            source=source,
            maximize=self.maximize,
            **kwargs,
        )


class DecoderReferenceBased(DecoderBase, metaclass=abc.ABCMeta):
    """Decoder base class for strategies that use references like MBR decoding."""

    # Annotation only (no value is assigned): declares that `self.metric`,
    # set by `DecoderBase.__init__`, is a reference-based `Metric` here.
    metric: Metric

    @abc.abstractmethod
    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderReferenceBased.Output:
        """Select the n-best hypotheses based on the strategy.

        Abstract: concrete decoders must implement this method.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderReferenceBased.Output: The n-best hypotheses.
        """


class DecoderReferenceless(DecoderBase, metaclass=abc.ABCMeta):
    """Decoder base class for reference-free strategies."""

    # Annotation only (no value is assigned): declares that `self.metric`,
    # set by `DecoderBase.__init__`, is a `MetricReferenceless` here.
    metric: MetricReferenceless

    @abc.abstractmethod
    def decode(
        self, hypotheses: list[str], source: str, nbest: int = 1
    ) -> DecoderReferenceless.Output:
        """Select the n-best hypotheses based on the strategy.

        Abstract: concrete decoders must implement this method. Unlike the
        reference-based variant, `source` is required here.

        Args:
            hypotheses (list[str]): Hypotheses.
            source (str): A source.
            nbest (int): Return the n-best hypotheses.

        Returns:
            DecoderReferenceless.Output: The n-best hypotheses.
        """


# Build one registry from the union of both decoder families and expose the
# two callables returned by `get_closure()` at module level as `register` and
# `get_decoder`. Presumably these register decoder classes and look them up;
# confirm the semantics in `mbrs.registry`.
register, get_decoder = (
    registry.Registry(DecoderReferenceBased | DecoderReferenceless)
    .get_closure()
)