1from __future__ import annotations
2
3import math
4from dataclasses import dataclass
5from typing import Optional
6
7import torch
8from torch import Tensor
9
10from mbrs import functional, timer
11from mbrs.metrics import MetricCacheable
12from mbrs.modules.als import MatrixFactorizationALS
13
14from . import register
15from .mbr import DecoderMBR
16
17
@register("probabilistic_mbr")
class DecoderProbabilisticMBR(DecoderMBR):
    """Probabilistic MBR decoder using alternating least squares (ALS) approximation.

    Only a random subset of the hypothesis-reference pairs is scored with the
    metric. The full pairwise score matrix is then reconstructed from a
    factorization fitted to the observed entries.

    References:
        F. Trabelsi et al., 2024,
        "Efficient Minimum Bayes Risk Decoding using Low-Rank Matrix Completion Algorithms".
        https://arxiv.org/abs/2406.02832
    """

    cfg: Config

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - reduction_factor (float): Reduction factor.
          The computational budget will be reduced to `1 / reduction_factor`,
          i.e., `ceil(H * R / reduction_factor)` of the `H * R` pairs are
          scored with the metric. When the value is `<= 1.0`, `decode()` skips
          the approximation and calls `metric.expected_scores()` directly.
        - regularization_weight (float): Weight of L2 regularization.
        - rank (int): Rank of the factorized matrices.
        - niter (int): The number of alternating steps performed.
        - seed (int): Random seed. Used both for sampling the scored pairs and
          for the ALS factorization.
        """

        reduction_factor: float = 8.0
        regularization_weight: float = 0.1
        rank: int = 8
        niter: int = 10
        seed: int = 0

    def pairwise_scores_probabilistic(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
    ) -> Tensor:
        """Compute approximated pairwise scores using the probabilistic MBR algorithm.

        A seeded random subset of the `H * R` hypothesis-reference pairs is
        scored with the metric, and the observed entries are passed to
        `MatrixFactorizationALS`. The returned matrix is the reconstruction
        `X @ Y.T`, so the observed entries are also replaced by their
        reconstructed values.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Approximated pairwise scores of shape `(H, R)`.
              If `hypotheses` or `references` is empty, an empty tensor of
              shape `(H, R)` is returned without calling the metric.
        """
        H = len(hypotheses)
        R = len(references)
        pairwise_scores = torch.zeros((H, R), device=self.metric.device)

        # With an empty side there is no pair to score and nothing to
        # factorize. Returning early also avoids `range(..., step=0)` and
        # `% 0` / `// 0` in the index arithmetic below.
        if H == 0 or R == 0:
            return pairwise_scores

        rng = torch.Generator().manual_seed(self.cfg.seed)
        num_ucalcs = math.ceil(H * R / self.cfg.reduction_factor)

        # Draw flat indices into the (H, R) matrix without replacement, then
        # unravel them into (row, column) = (hypothesis, reference) indices.
        pairwise_sample_indices = torch.randperm(H * R, generator=rng)[:num_ucalcs]
        hypothesis_sample_indices: list[int] = (pairwise_sample_indices // R).tolist()
        reference_sample_indices: list[int] = (pairwise_sample_indices % R).tolist()

        # Algorithm 2 in the paper.
        if isinstance(self.metric, MetricCacheable):
            with timer.measure("encode/hypotheses"):
                hypotheses_ir = self.metric.encode(hypotheses)
            if hypotheses == references:
                # Same sentences on both sides: reuse the encoded hypotheses.
                references_ir = hypotheses_ir
            else:
                with timer.measure("encode/references"):
                    references_ir = self.metric.encode(references)
            if source is None:
                source_ir = None
            else:
                with timer.measure("encode/source"):
                    source_ir = self.metric.encode([source])

            # Score the sampled pairs in chunks of at most H pairs per call.
            num_hyp_samples = len(hypothesis_sample_indices)
            for start in range(0, num_hyp_samples, H):
                hypothesis_chunk = hypothesis_sample_indices[start : start + H]
                reference_chunk = reference_sample_indices[start : start + H]
                pairwise_scores[hypothesis_chunk, reference_chunk] = (
                    self.metric.scores_from_ir(
                        hypotheses_ir[hypothesis_chunk],
                        references_ir[reference_chunk],
                        # The single encoded source is repeated once per pair
                        # in the chunk (the last chunk may be shorter than H).
                        source_ir.repeat(len(hypothesis_chunk))
                        if source_ir is not None
                        else None,
                    ).float()
                )
        else:
            hypothesis_samples = [hypotheses[i] for i in hypothesis_sample_indices]
            reference_samples = [references[j] for j in reference_sample_indices]
            pairwise_scores[hypothesis_sample_indices, reference_sample_indices] = (
                self.metric.scores(
                    hypothesis_samples,
                    reference_samples,
                    [source] * len(hypothesis_samples) if source is not None else None,
                ).float()
            )

        # Mark which entries of the score matrix were actually computed.
        observed_mask = pairwise_scores.new_zeros((H, R), dtype=torch.bool)
        observed_mask[hypothesis_sample_indices, reference_sample_indices] = True

        # Algorithm 1 in the paper.
        als = MatrixFactorizationALS(
            regularization_weight=self.cfg.regularization_weight, rank=self.cfg.rank
        )
        X, Y = als.factorize(
            pairwise_scores,
            observed_mask=observed_mask,
            niter=self.cfg.niter,
            seed=self.cfg.seed,
        )
        reconstructed_pairwise_scores = X @ Y.T
        return reconstructed_pairwise_scores

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        When `cfg.reduction_factor <= 1.0`, the expected scores are obtained
        from `metric.expected_scores()`. Otherwise they are the expectation of
        the approximated pairwise scores returned by
        `pairwise_scores_probabilistic()`.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: The n-best hypotheses.
        """

        if self.cfg.reduction_factor <= 1.0:
            # No budget reduction requested: skip the approximation.
            expected_scores = self.metric.expected_scores(
                hypotheses, references, source, reference_lprobs=reference_lprobs
            )
        else:  # Probabilistic MBR decoding
            pairwise_scores = self.pairwise_scores_probabilistic(
                hypotheses, references, source
            )
            expected_scores = functional.expectation(
                pairwise_scores, lprobs=reference_lprobs
            )

        selector_outputs = self.select(
            hypotheses, expected_scores, nbest=nbest, source=source
        )
        return (
            self.Output(
                idx=selector_outputs.idx,
                sentence=selector_outputs.sentence,
                score=selector_outputs.score,
            )
            | selector_outputs
        )