Source code for mbrs.decoders.mbr

 1from __future__ import annotations
 2
 3from typing import Optional
 4
 5from torch import Tensor
 6
 7from . import DecoderReferenceBased, register
 8
 9
@register("mbr")
class DecoderMBR(DecoderReferenceBased):
    """Naive MBR decoder.

    The expected score of every hypothesis is obtained from
    ``self.metric.expected_scores`` and the best candidates are then picked by
    ``self.select``.

    - Time complexity: O(N^2)
    - Space complexity: O(N^2)

    References:
        S. Kumar and W. Byrne, 2004,
        "Minimum Bayes-Risk Decoding for Statistical Machine Translation".
        https://aclanthology.org/N04-1022

        B. Eikema and W. Aziz, 2020,
        "Is MAP Decoding All You Need?
        The Inadequacy of the Mode in Neural Machine Translation".
        https://aclanthology.org/2020.coling-main.398
    """

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Pick the ``nbest`` hypotheses according to the selection strategy.

        Args:
            hypotheses (list[str]): Candidate outputs to be ranked.
            references (list[str]): Reference samples the metric scores
                the hypotheses against.
            source (str, optional): The source input, forwarded to both the
                metric and the selector.
            nbest (int): How many hypotheses to return.
            reference_lprobs (Tensor, optional): Log-probability of each
                reference sample; expected shape is `(len(references),)`.
                See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: The selected n-best hypotheses.
        """
        # Per-hypothesis expected scores; the estimation itself lives in the metric.
        # ``source`` is passed positionally, exactly as the metric API is called elsewhere.
        per_hypothesis_scores = self.metric.expected_scores(
            hypotheses, references, source, reference_lprobs=reference_lprobs
        )

        # Let the configured selection strategy choose the winners.
        picked = self.select(
            hypotheses, per_hypothesis_scores, nbest=nbest, source=source
        )

        # Build the decoder's own Output from the selection result, then merge
        # the selector output into it via ``|``.
        # NOTE(review): presumably ``|`` carries over any extra fields the
        # selector produced — confirm against the Output class definition.
        decoder_output = self.Output(
            idx=picked.idx,
            sentence=picked.sentence,
            score=picked.score,
        )
        return decoder_output | picked