1from __future__ import annotations
2
3from dataclasses import dataclass, field
4from typing import Optional
5
6import torch
7from torch import Tensor
8
9from mbrs import functional, timer
10from mbrs.metrics import Metric, MetricCacheable
11from mbrs.selectors import SELECTOR_NBEST, Selector, SelectorNbest
12
13from . import register
14from .mbr import DecoderMBR
15
16
@register("pruning_mbr")
class DecoderPruningMBR(DecoderMBR):
    """Pruning MBR decoder class.

    Hypotheses are scored against a growing prefix of the references, following
    ``Config.sampling_scheduler``. After each step, hypotheses whose bootstrap
    win rate against the current best falls to ``1 - alpha`` or below are
    dropped, so later (larger) steps score fewer hypotheses.

    References:
        J. Cheng and A. Vlachos, 2023,
        "Faster Minimum Bayes Risk Decoding with Confidence-based Pruning".
        https://aclanthology.org/2023.emnlp-main.767/
    """

    def __init__(
        self,
        cfg: DecoderPruningMBR.Config,
        metric: Metric,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Initialize the decoder.

        Args:
            cfg (DecoderPruningMBR.Config): Decoder configuration.
            metric (Metric): Metric used to score hypotheses against references.
            selector (Selector): Final selector. Must be a `SelectorNbest`.

        Raises:
            ValueError: If `selector` is not a `SelectorNbest` instance.
        """
        if not isinstance(selector, SelectorNbest):
            raise ValueError(
                "Confidence-based pruning cannot be combined with other selectors than the nbest."
            )
        super().__init__(cfg, metric, selector)

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - alpha (float): Prune hypotheses based on this confidence threshold.
        - sampling_scheduler (list[int]): Sample size scheduler. For each step, the
            number of samples will be the t-th number.
        - num_bootstrap_samples (int): Number of bootstrap samples.
        - seed (int): Random seed for bootstrap sampling.
        """

        alpha: float = 0.99
        sampling_scheduler: list[int] = field(
            default_factory=lambda: [8, 16, 32, 64, 128, 256]
        )
        num_bootstrap_samples: int = 500
        seed: int = 0

    cfg: Config

    def decode_pruning(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> tuple[list[float], list[int]]:
        """Select the n-best hypotheses using pruning MBR decoding.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            - list[float]: Top-k scores.
            - list[int]: Top-k indices, referring to positions in the `hypotheses`
              list that was passed in (not to the pruned candidate set).
        """
        device = self.metric.device
        num_references = len(references)
        rng = torch.Generator(device=device).manual_seed(self.cfg.seed)
        H = len(hypotheses)
        max_r = min(max(self.cfg.sampling_scheduler), num_references)
        pairwise_scores = torch.zeros((H, max_r), device=device)
        # Maps each surviving row of `pairwise_scores` back to its original index.
        orig_indices = torch.arange(H, device=device)

        if isinstance(self.metric, MetricCacheable):
            with timer.measure("encode/hypotheses"):
                hypotheses_ir = self.metric.encode(hypotheses)
            # When the references are the hypotheses themselves, reuse the encoding.
            references_ir = hypotheses_ir if hypotheses == references else None
            if source is None:
                source_ir = None
            else:
                with timer.measure("encode/source"):
                    source_ir = self.metric.encode([source])

        with timer.measure("pruning_mbr"):
            # Algorithm 1 in the paper.
            prev_r = 0
            for r in self.cfg.sampling_scheduler:
                r = min(r, num_references)
                if r <= prev_r:
                    # No new references are left to score at this step.
                    break

                # Equation 5 and Algorithm 2 in the paper.
                # Only the new columns [prev_r, r) are scored at each step.
                if isinstance(self.metric, MetricCacheable):
                    if references_ir is None:
                        with timer.measure("encode/references"):
                            references_ir_t = self.metric.encode(references[prev_r:r])
                    else:
                        references_ir_t = references_ir[prev_r:r]

                    pairwise_scores[:, prev_r:r] = self.metric.pairwise_scores_from_ir(
                        hypotheses_ir, references_ir_t, source_ir
                    )
                else:
                    pairwise_scores[:, prev_r:r] = self.metric.pairwise_scores(
                        hypotheses, references[prev_r:r], source
                    )

                expected_scores = functional.expectation(
                    pairwise_scores[:, :r],
                    lprobs=reference_lprobs[:r]
                    if reference_lprobs is not None
                    else None,
                )
                current_best_idx = self.argbest(expected_scores)

                # Resample the r scored references with replacement.
                sample_indices = torch.randint(
                    r,
                    size=(self.cfg.num_bootstrap_samples, r),
                    device=device,
                    generator=rng,
                )
                bootstrap_expected_scores = functional.expectation(
                    pairwise_scores[:, sample_indices],
                    lprobs=reference_lprobs[sample_indices]
                    if reference_lprobs is not None
                    else None,
                )

                # A hypothesis "wins" a bootstrap sample when it ties or beats
                # the current best hypothesis on that sample.
                best_bootstrap_scores = bootstrap_expected_scores[current_best_idx]
                if self.maximize:
                    num_wins = bootstrap_expected_scores >= best_bootstrap_scores
                else:
                    num_wins = bootstrap_expected_scores <= best_bootstrap_scores
                win_rates = num_wins.float().mean(dim=1)
                winners = (win_rates > 1 - self.cfg.alpha).nonzero(as_tuple=True)[0]

                # Columns [0, r) are now scored for every remaining hypothesis.
                # Record that before deciding whether to prune, so the final
                # expectation never runs over a stale (or empty) column range.
                prev_r = r

                if len(winners) < nbest:
                    # Pruning would leave fewer than `nbest` candidates: stop
                    # here and rank the current candidates on columns [0, r).
                    break

                if isinstance(self.metric, MetricCacheable):
                    hypotheses_ir = hypotheses_ir[winners]
                else:
                    hypotheses = [hypotheses[i] for i in winners.tolist()]
                pairwise_scores = pairwise_scores[winners]
                orig_indices = orig_indices[winners]

            expected_scores = functional.expectation(
                pairwise_scores[:, :prev_r],
                lprobs=reference_lprobs[:prev_r]
                if reference_lprobs is not None
                else None,
            )

        topk_scores, topk_indices = self.topk(expected_scores, k=nbest)
        return topk_scores, orig_indices[topk_indices].tolist()

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: The n-best hypotheses, with `idx` holding indices
              into `hypotheses`, `sentence` the corresponding strings, and `score`
              the scores returned by `decode_pruning`.
        """
        topk_scores, topk_indices = self.decode_pruning(
            hypotheses,
            references,
            source,
            nbest=nbest,
            reference_lprobs=reference_lprobs,
        )
        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
        )