Source code for mbrs.decoders.pruning_mbr

  1from __future__ import annotations
  2
  3from dataclasses import dataclass, field
  4from typing import Optional
  5
  6import torch
  7from torch import Tensor
  8
  9from mbrs import functional, timer
 10from mbrs.metrics import Metric, MetricCacheable
 11from mbrs.selectors import SELECTOR_NBEST, Selector, SelectorNbest
 12
 13from . import register
 14from .mbr import DecoderMBR
 15
 16
@register("pruning_mbr")
class DecoderPruningMBR(DecoderMBR):
    """Pruning MBR decoder class.

    Hypotheses are scored against an increasing number of references following
    ``cfg.sampling_scheduler``. After each step, hypotheses whose bootstrap
    win rate against the current best hypothesis is too low are pruned.

    References:
        J. Cheng and A. Vlachos, 2023,
        "Faster Minimum Bayes Risk Decoding with Confidence-based Pruning".
        https://aclanthology.org/2023.emnlp-main.767/
    """

    def __init__(
        self,
        cfg: DecoderPruningMBR.Config,
        metric: Metric,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Initialize the decoder.

        Args:
            cfg (DecoderPruningMBR.Config): Decoder configuration.
            metric (Metric): Metric used to score hypotheses against references.
            selector (Selector): Hypothesis selector. Only `SelectorNbest` is supported.

        Raises:
            ValueError: If `selector` is not a `SelectorNbest`.
        """
        if not isinstance(selector, SelectorNbest):
            raise ValueError(
                "Confidence-based pruning cannot be combined with other selectors than the nbest."
            )
        super().__init__(cfg, metric, selector)

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - alpha (float): Prune hypotheses based on this confidence threshold.
          A hypothesis survives a step only if its bootstrap win rate against the
          current best hypothesis is greater than `1 - alpha`.
        - sampling_scheduler (list[int]): Sample size scheduler. For each step, the
          number of samples will be the t-th number.
        - num_bootstrap_samples (int): Number of bootstrap samples.
        - seed (int): Random seed for bootstrap sampling.
        """

        alpha: float = 0.99
        sampling_scheduler: list[int] = field(
            default_factory=lambda: [8, 16, 32, 64, 128, 256]
        )
        num_bootstrap_samples: int = 500
        seed: int = 0

    cfg: Config

    def decode_pruning(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> tuple[list[float], list[int]]:
        """Select the n-best hypotheses using pruning MBR decoding.

        At each step of `cfg.sampling_scheduler`, the surviving hypotheses are
        scored against the next slice of references. Hypotheses whose bootstrap
        win rate against the current best is not above `1 - alpha` are pruned.
        The loop stops early if pruning would leave fewer than `nbest`
        hypotheses. In that case, the references already scored in that step
        are still used for the final ranking.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            - list[float]: Top-k scores.
            - list[int]: Top-k indices into the given `hypotheses` list.
        """
        device = self.metric.device
        rng = torch.Generator(device=device).manual_seed(self.cfg.seed)
        H = len(hypotheses)
        max_r = min(max(self.cfg.sampling_scheduler), len(references))
        # One column slice is filled per step; rows follow the surviving hypotheses.
        pairwise_scores = torch.zeros((H, max_r), device=device)
        # Row i of `pairwise_scores` corresponds to hypotheses[orig_indices[i]].
        orig_indices = torch.arange(H, device=device)

        if isinstance(self.metric, MetricCacheable):
            with timer.measure("encode/hypotheses"):
                hypotheses_ir = self.metric.encode(hypotheses)
            # Reuse the hypothesis encodings when the hypotheses double as references.
            # This keeps the full (unpruned) encodings even after `hypotheses_ir` is pruned.
            references_ir = hypotheses_ir if hypotheses == references else None
            if source is None:
                source_ir = None
            else:
                with timer.measure("encode/source"):
                    source_ir = self.metric.encode([source])

        with timer.measure("pruning_mbr"):
            # Algorithm 1 in the paper.
            prev_r = 0
            for r in self.cfg.sampling_scheduler:
                r = min(r, len(references))
                if r <= prev_r:
                    break

                # Equation 5 and Algorithm 2 in the paper: score the surviving
                # hypotheses against the newly added references [prev_r, r).
                if isinstance(self.metric, MetricCacheable):
                    if references_ir is None:
                        with timer.measure("encode/references"):
                            references_ir_t = self.metric.encode(references[prev_r:r])
                    else:
                        references_ir_t = references_ir[prev_r:r]

                    pairwise_scores[:, prev_r:r] = self.metric.pairwise_scores_from_ir(
                        hypotheses_ir, references_ir_t, source_ir
                    )
                else:
                    pairwise_scores[:, prev_r:r] = self.metric.pairwise_scores(
                        hypotheses, references[prev_r:r], source
                    )
                # Columns [:r] are now filled for every surviving hypothesis, so the
                # final estimate can use them even if we stop before pruning below.
                prev_r = r

                expected_scores = functional.expectation(
                    pairwise_scores[:, :r],
                    lprobs=reference_lprobs[:r]
                    if reference_lprobs is not None
                    else None,
                )
                current_best_idx = self.argbest(expected_scores)

                # Bootstrap: draw `num_bootstrap_samples` resamples (with replacement)
                # of the r scored references.
                sample_indices = torch.randint(
                    r,
                    size=(self.cfg.num_bootstrap_samples, r),
                    device=device,
                    generator=rng,
                )
                bootstrap_expected_scores = functional.expectation(
                    pairwise_scores[:, sample_indices],
                    lprobs=reference_lprobs[sample_indices]
                    if reference_lprobs is not None
                    else None,
                )
                best_bootstrap_scores = bootstrap_expected_scores[current_best_idx]
                if self.maximize:
                    wins = bootstrap_expected_scores >= best_bootstrap_scores
                else:
                    wins = bootstrap_expected_scores <= best_bootstrap_scores
                win_rates = wins.float().mean(dim=1)
                winners = (win_rates > 1 - self.cfg.alpha).nonzero(as_tuple=True)[0]
                if len(winners) < nbest:
                    # Pruning would leave fewer than `nbest` candidates: stop and rank
                    # the current (unpruned) candidates with all r scored references.
                    break

                if isinstance(self.metric, MetricCacheable):
                    hypotheses_ir = hypotheses_ir[winners]
                else:
                    hypotheses = [hypotheses[i] for i in winners]
                pairwise_scores = pairwise_scores[winners]
                orig_indices = orig_indices[winners]

        expected_scores = functional.expectation(
            pairwise_scores[:, :prev_r],
            lprobs=reference_lprobs[:prev_r]
            if reference_lprobs is not None
            else None,
        )

        topk_scores, topk_indices = self.topk(expected_scores, k=nbest)
        return topk_scores, orig_indices[topk_indices].tolist()

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            DecoderMBR.Output: The n-best hypotheses.
        """
        topk_scores, topk_indices = self.decode_pruning(
            hypotheses,
            references,
            source,
            nbest=nbest,
            reference_lprobs=reference_lprobs,
        )
        # `topk_indices` index into the caller's original (unpruned) `hypotheses` list.
        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
        )