Source code for mbrs.decoders.aggregate_mbr
1from __future__ import annotations
2
3from typing import Optional
4
5from torch import Tensor
6
7from mbrs.metrics import MetricAggregatable
8
9from . import register
10from .mbr import DecoderMBR
11
12
[docs]
@register("aggregate_mbr")
class DecoderAggregateMBR(DecoderMBR):
    """MBR decoder that scores hypotheses against aggregated references.

    The expected scores are obtained from the metric's reference-aggregation
    routine rather than computed in this class.

    - Time complexity: O(N)
    - Space complexity: O(N)

    References:
        J. DeNero et al., 2009,
        "Fast Consensus Decoding over Translation Forests".
        https://aclanthology.org/P09-1064/

        J. Vamvas and R. Sennrich, 2024,
        "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation".
        https://arxiv.org/abs/2402.04251
    """

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderAggregateMBR.Output:
        """Pick the n-best hypotheses using reference-aggregated scores.

        The configured metric must be a `MetricAggregatable`; this is
        checked with an assertion before any scoring happens.

        Args:
            hypotheses (list[str]): Candidate sentences to choose from.
            references (list[str]): Reference sentences used for scoring.
            source (str, optional): Source sentence, forwarded to both the
              metric and the selector.
            nbest (int): Number of top hypotheses to return.
            reference_lprobs (Tensor, optional): Log-probabilities for each
              reference sample. The shape must be `(len(references),)`.
              See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderAggregateMBR.Output: The n-best hypotheses.
        """
        assert isinstance(self.metric, MetricAggregatable)

        # Step 1: let the metric compute the expected score of every
        # hypothesis through its reference-aggregation entry point.
        aggregated_scores = self.metric.expected_scores_reference_aggregation(
            hypotheses,
            references,
            source=source,
            reference_lprobs=reference_lprobs,
        )

        # Step 2: hand the scores to the selector to choose the n-best.
        selection = self.select(
            hypotheses, aggregated_scores, source=source, nbest=nbest
        )

        # Step 3: wrap the core fields in this decoder's Output type, then
        # combine with the selector result via `|` (presumably carrying over
        # any extra selector fields -- confirm against Output's `__or__`).
        decoded = self.Output(
            idx=selection.idx,
            sentence=selection.sentence,
            score=selection.score,
        )
        return decoded | selection