Source code for mbrs.decoders.base

  1from __future__ import annotations
  2
  3import abc
  4from dataclasses import dataclass, fields, make_dataclass
  5from typing import Any, Optional
  6
  7from torch import Tensor
  8
  9from mbrs import registry
 10from mbrs.metrics.base import Metric, MetricBase, MetricReferenceless
 11from mbrs.selectors import SELECTOR_NBEST, Selector
 12
 13
class DecoderBase(abc.ABC):
    """Decoder base class.

    Stores the configuration, the metric, and the selector of a decoder, and
    exposes ranking helpers that delegate to the selector with the
    optimization direction read from the metric.
    """

    def __init__(
        self,
        cfg: DecoderBase.Config,
        metric: MetricBase,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Store the components of the decoder.

        Args:
            cfg (DecoderBase.Config): Configuration for the decoder.
            metric (MetricBase): Metric; its `HIGHER_IS_BETTER` attribute
              determines :attr:`maximize`.
            selector (Selector): Selector that `topk`, `argbest`, `superior`,
              and `select` delegate to.
        """
        self.cfg = cfg
        self.metric = metric
        self.selector = selector

    @property
    def maximize(self) -> bool:
        """Return `True` when maximizing the objective score."""
        return self.metric.HIGHER_IS_BETTER

    @dataclass
    class Config:
        """Configuration for the decoder."""

    @dataclass
    class Output:
        """
        - idx (list[int]): Index numbers of the n-best hypotheses.
        - sentence (list[str]): Sentences of the n-best hypotheses.
        - score (list[float]): Scores of the n-best hypotheses.
        """

        idx: list[int]
        sentence: list[str]
        score: list[float]

        def __or__(self, other: Any) -> DecoderBase.Output:
            """Returns the union of dataclasses.

            A new dataclass type named `Output` is created on the fly. It
            derives from `type(self)` and additionally declares every field of
            `other` that `self` does not have. When both objects define a field
            with the same name, the value of `self` is kept.

            Args:
                other (Any): An other dataclass.

            Returns:
                Output: New dataclass with the merged attributes of `self` and `other`.
            """
            # Look up our own fields once; the name set is reused for every
            # field of `other` instead of being rebuilt per iteration.
            own_fields = fields(self)
            own_names = {f.name for f in own_fields}

            new_fields = [(f.name, f.type, f) for f in own_fields]
            new_fields += [
                (f.name, f.type, f) for f in fields(other) if f.name not in own_names
            ]
            new_dc_type = make_dataclass(
                "Output", fields=new_fields, bases=(type(self),)
            )
            # The right-hand operand of the dict union wins, so the values of
            # `self` take precedence over those of `other`.
            attrs = {f.name: getattr(other, f.name) for f in fields(other)} | {
                f.name: getattr(self, f.name) for f in own_fields
            }
            return new_dc_type(**attrs)

    def topk(self, x: Tensor, k: int = 1) -> tuple[list[float], list[int]]:
        """Return the top-k best elements and corresponding indices.

        Delegates to the selector, passing :attr:`maximize` as the direction.

        Args:
            x (Tensor): Input 1-D array.
            k (int): Return the top-k values and indices.

        Returns:
            tuple[list[float], list[int]]
              - list[float]: The top-k values.
              - list[int]: The top-k indices.
        """
        return self.selector.topk(x, k=k, maximize=self.maximize)

    def argbest(self, x: Tensor) -> Tensor:
        """Return the index of the best element.

        Delegates to the selector, passing :attr:`maximize` as the direction.

        Args:
            x (Tensor): Input 1-D array.

        Returns:
            Tensor: A scalar tensor of the best index.
        """
        return self.selector.argbest(x, maximize=self.maximize)

    def superior(self, a: float, b: float) -> bool:
        """Return whether the score `a` is superior to the score `b`.

        Delegates to the selector, passing :attr:`maximize` as the direction.

        Args:
            a (float): A score.
            b (float): A score.

        Returns:
            bool: Return True when `a` is superior to `b`.
        """
        return self.selector.superior(a, b, maximize=self.maximize)

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        **kwargs,
    ) -> Selector.Output:
        """Select the final output list.

        The optimization direction is not an argument of this method: it is
        always taken from :attr:`maximize` and passed to the selector.

        Args:
            hypotheses (list[str]): Hypotheses.
            expected_scores (Tensor): The expected scores for each hypothesis.
            nbest (int): Return the n-best hypotheses based on the selection rule.
            source (str, optional): A source.
            **kwargs: Additional keyword arguments forwarded unchanged to the
              selector.

        Returns:
            Selector.Output: Outputs.
        """
        return self.selector.select(
            hypotheses,
            expected_scores,
            nbest=nbest,
            source=source,
            maximize=self.maximize,
            **kwargs,
        )
136 137
class DecoderReferenceBased(DecoderBase, metaclass=abc.ABCMeta):
    """Decoder base class for strategies that use references like MBR decoding.

    Subclasses must implement :meth:`decode`.
    """

    metric: Metric

    @abc.abstractmethod
    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderReferenceBased.Output:
        """Select the n-best hypotheses based on the strategy.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderReferenceBased.Output: The n-best hypotheses.
        """
165 166
class DecoderReferenceless(DecoderBase, metaclass=abc.ABCMeta):
    """Decoder base class for reference-free strategies.

    Subclasses must implement :meth:`decode`.
    """

    metric: MetricReferenceless

    @abc.abstractmethod
    def decode(
        self, hypotheses: list[str], source: str, nbest: int = 1
    ) -> DecoderReferenceless.Output:
        """Select the n-best hypotheses based on the strategy.

        Args:
            hypotheses (list[str]): Hypotheses.
            source (str): A source.
            nbest (int): Return the n-best hypotheses.

        Returns:
            DecoderReferenceless.Output: The n-best hypotheses.
        """
# Module-level registry for decoders: `get_closure()` is unpacked into the
# public names `register` and `get_decoder`.
# NOTE(review): presumably `register` is the registration decorator and
# `get_decoder` the lookup by name, with the union type declaring which
# classes may be registered -- confirm against `mbrs.registry`.
register, get_decoder = registry.Registry(
    DecoderReferenceBased | DecoderReferenceless
).get_closure()