Source code for mbrs.decoders.mbr
1from __future__ import annotations
2
3from typing import Optional
4
5from torch import Tensor
6
7from . import DecoderReferenceBased, register
8
9
[docs]
@register("mbr")
class DecoderMBR(DecoderReferenceBased):
    """Naive minimum Bayes-risk (MBR) decoder.

    Each hypothesis is scored by the configured metric's expected score
    against the given references, and the best-scoring ones are selected.

    - Time complexity: O(N^2)
    - Space complexity: O(N^2)

    References:
        S. Kumar and W. Byrne, 2004,
        "Minimum Bayes-Risk Decoding for Statistical Machine Translation".
        https://aclanthology.org/N04-1022

        B. Eikema and W. Aziz, 2020,
        "Is MAP Decoding All You Need?
        The Inadequacy of the Mode in Neural Machine Translation".
        https://aclanthology.org/2020.coling-main.398
    """

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Return the ``nbest`` hypotheses chosen by expected metric score.

        The metric's ``expected_scores`` is computed from the hypotheses and
        references. Those scores are then passed to the decoder's selector,
        which picks the n-best entries.

        Args:
            hypotheses (list[str]): Candidate outputs to choose from.
            references (list[str]): Reference samples used to estimate the
                expected scores.
            source (str, optional): The source input. It is forwarded to both
                the metric and the selector.
            nbest (int): Number of hypotheses to return.
            reference_lprobs (Tensor, optional): Log-probability of each
                reference sample, with shape `(len(references),)`.
                See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: An ``Output`` built from the selector's
            ``idx``, ``sentence`` and ``score``, combined with the selector
            output via ``|``.
        """
        # Expected score of the hypotheses under the metric, estimated over
        # the references (presumably one score per hypothesis; see metric).
        scores = self.metric.expected_scores(
            hypotheses, references, source, reference_lprobs=reference_lprobs
        )
        chosen = self.select(hypotheses, scores, nbest=nbest, source=source)

        result = self.Output(
            idx=chosen.idx,
            sentence=chosen.sentence,
            score=chosen.score,
        )
        # NOTE(review): the meaning of `|` is defined on the Output type (not
        # visible here); presumably it carries over any extra selector fields.
        return result | chosen