Source code for mbrs.decoders.aggregate_mbr

 1from __future__ import annotations
 2
 3from typing import Optional
 4
 5from torch import Tensor
 6
 7from mbrs.metrics import MetricAggregatable
 8
 9from . import register
10from .mbr import DecoderMBR
11
12
@register("aggregate_mbr")
class DecoderAggregateMBR(DecoderMBR):
    """MBR decoder that scores hypotheses against aggregated references.

    - Time complexity: O(N)
    - Space complexity: O(N)

    References:
        J. DeNero et al., 2009,
        "Fast Consensus Decoding over Translation Forests".
        https://aclanthology.org/P09-1064/

        J. Vamvas and R. Sennrich, 2024,
        "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation".
        https://arxiv.org/abs/2402.04251
    """

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderAggregateMBR.Output:
        """Score the hypotheses via reference aggregation and pick the n-best.

        The expected scores are obtained from the metric's
        `expected_scores_reference_aggregation` and then handed to
        `self.select`, which produces the final selection.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each
              reference sample. The shape must be `(len(references),)`.
              See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderAggregateMBR.Output: The n-best hypotheses, combined with
              the selector's output through the `|` operator.

        Raises:
            AssertionError: If `self.metric` is not a `MetricAggregatable`
              (only while assertions are enabled).
        """
        # Reference aggregation is only offered by aggregatable metrics; the
        # assertion also narrows the metric's type for static checkers.
        assert isinstance(self.metric, MetricAggregatable)

        aggregated_scores = self.metric.expected_scores_reference_aggregation(
            hypotheses,
            references,
            source=source,
            reference_lprobs=reference_lprobs,
        )
        selection = self.select(
            hypotheses, aggregated_scores, nbest=nbest, source=source
        )

        # Build this decoder's own output first, then combine it with the
        # selector's output.
        decoder_output = self.Output(
            idx=selection.idx,
            sentence=selection.sentence,
            score=selection.score,
        )
        return decoder_output | selection