Source code for mbrs.decoders.probabilistic_mbr

  1from __future__ import annotations
  2
  3import math
  4from dataclasses import dataclass
  5from typing import Optional
  6
  7import torch
  8from torch import Tensor
  9
 10from mbrs import functional, timer
 11from mbrs.metrics import MetricCacheable
 12from mbrs.modules.als import MatrixFactorizationALS
 13
 14from . import register
 15from .mbr import DecoderMBR
 16
 17
@register("probabilistic_mbr")
class DecoderProbabilisticMBR(DecoderMBR):
    """Probabilistic MBR decoder using alternating least squares (ALS) approximation.

    Instead of scoring every hypothesis/reference pair with the metric, only a
    random subset of pairs is scored and the full pairwise score matrix is
    reconstructed from a low-rank factorization of the observed entries.

    References:
        F. Trabelsi et al., 2024,
        "Efficient Minimum Bayes Risk Decoding using Low-Rank Matrix Completion Algorithms".
        https://arxiv.org/abs/2406.02832
    """

    cfg: Config

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - reduction_factor (float): Reduction factor.
          The computational budget will be reduced to `1 / reduction_factor`.
        - regularization_weight (float): Weight of L2 regularization.
        - rank (int): Rank of the factorized matrices.
        - niter (int): The number of alternating steps performed.
        - seed (int): Random seed.
        """

        reduction_factor: float = 8.0
        regularization_weight: float = 0.1
        rank: int = 8
        niter: int = 10
        seed: int = 0

    def pairwise_scores_probabilistic(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
    ) -> Tensor:
        """Compute approximated pairwise scores using the probabilistic MBR algorithm.

        `ceil(H * R / reduction_factor)` distinct (hypothesis, reference) pairs are
        drawn with a generator seeded by `cfg.seed`, scored with the metric, and
        the remaining entries are filled in by ALS matrix factorization.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Approximated pairwise scores of shape `(H, R)`.
        """
        rng = torch.Generator().manual_seed(self.cfg.seed)
        H = len(hypotheses)
        R = len(references)
        num_ucalcs = math.ceil(H * R / self.cfg.reduction_factor)

        pairwise_scores = torch.zeros((H, R), device=self.metric.device)
        # A prefix of a random permutation of the flattened (H * R) cell indices
        # gives a sample of distinct cells; `//` and `%` recover row and column.
        pairwise_sample_indices = torch.randperm(H * R, generator=rng)[:num_ucalcs]
        hypothesis_sample_indices: list[int] = (pairwise_sample_indices // R).tolist()
        reference_sample_indices: list[int] = (pairwise_sample_indices % R).tolist()

        # Algorithm 2 in the paper.
        if isinstance(self.metric, MetricCacheable):
            with timer.measure("encode/hypotheses"):
                hypotheses_ir = self.metric.encode(hypotheses)
            if hypotheses == references:
                # Identical lists: reuse the encoded hypotheses instead of
                # encoding the same sentences a second time.
                references_ir = hypotheses_ir
            else:
                with timer.measure("encode/references"):
                    references_ir = self.metric.encode(references)
            if source is None:
                source_ir = None
            else:
                with timer.measure("encode/source"):
                    source_ir = self.metric.encode([source])

            num_hyp_samples = len(hypothesis_sample_indices)
            # Score the sampled pairs in chunks of at most H pairs per metric call.
            for i in range(0, num_hyp_samples, H):
                hypothesis_chunk = hypothesis_sample_indices[i : i + H]
                reference_chunk = reference_sample_indices[i : i + H]
                pairwise_scores[hypothesis_chunk, reference_chunk] = (
                    self.metric.scores_from_ir(
                        hypotheses_ir[hypothesis_chunk],
                        references_ir[reference_chunk],
                        # The single encoded source is repeated once per pair in
                        # the chunk; the last chunk may be shorter than H.
                        source_ir.repeat(len(hypothesis_chunk))
                        if source_ir is not None
                        else None,
                    ).float()
                )
        else:
            hypothesis_samples = [hypotheses[i] for i in hypothesis_sample_indices]
            reference_samples = [references[j] for j in reference_sample_indices]
            pairwise_scores[hypothesis_sample_indices, reference_sample_indices] = (
                self.metric.scores(
                    hypothesis_samples,
                    reference_samples,
                    [source] * len(hypothesis_samples) if source is not None else None,
                ).float()
            )

        # Boolean mask marking which cells of `pairwise_scores` were actually scored.
        observed_mask = pairwise_scores.new_zeros((H, R), dtype=torch.bool)
        observed_mask[hypothesis_sample_indices, reference_sample_indices] = True

        # Algorithm 1 in the paper.
        als = MatrixFactorizationALS(
            regularization_weight=self.cfg.regularization_weight, rank=self.cfg.rank
        )
        X, Y = als.factorize(
            pairwise_scores,
            observed_mask=observed_mask,
            niter=self.cfg.niter,
            seed=self.cfg.seed,
        )
        reconstructed_pairwise_scores = X @ Y.T
        return reconstructed_pairwise_scores

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        When `cfg.reduction_factor` is at most 1.0 the expected scores come
        directly from `self.metric.expected_scores`; otherwise they are computed
        from the approximated pairwise score matrix.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: The n-best hypotheses.
        """

        if self.cfg.reduction_factor <= 1.0:
            # No budget reduction requested: skip the approximation entirely.
            expected_scores = self.metric.expected_scores(
                hypotheses, references, source, reference_lprobs=reference_lprobs
            )
        else:  # Probabilistic MBR decoding
            pairwise_scores = self.pairwise_scores_probabilistic(
                hypotheses, references, source
            )
            expected_scores = functional.expectation(
                pairwise_scores, lprobs=reference_lprobs
            )

        selector_outputs = self.select(
            hypotheses, expected_scores, nbest=nbest, source=source
        )
        # NOTE(review): `|` presumably merges any extra selector fields into the
        # decoder output; confirm against the `Output` implementation.
        return (
            self.Output(
                idx=selector_outputs.idx,
                sentence=selector_outputs.sentence,
                score=selector_outputs.score,
            )
            | selector_outputs
        )