Source code for mbrs.metrics.bleu

  1from __future__ import annotations
  2
  3import concurrent.futures
  4import itertools
  5import math
  6from collections import Counter
  7from dataclasses import dataclass
  8from typing import Optional
  9
 10import torch
 11from sacrebleu.metrics.bleu import BLEU, MAX_NGRAM_ORDER
 12from sacrebleu.metrics.helpers import extract_all_word_ngrams
 13from torch import Tensor
 14
 15from mbrs import timer
 16
 17from . import MetricAggregatable, register
 18
 19
@register("bleu")
class MetricBLEU(MetricAggregatable):
    """BLEU metric class.

    Wraps a sacrebleu ``BLEU`` scorer for sentence-level, pairwise and
    corpus-level scoring, and additionally computes expected scores against an
    aggregated (probability-weighted) bag of reference n-grams.
    """

    @dataclass
    class Config(MetricAggregatable.Config):
        """BLEU metric configuration.

        - lowercase (bool): If True, lowercased BLEU is computed.
        - force (bool): Ignore data that looks already tokenized.
        - tokenize (str, optional): The tokenizer to use. If None, defaults to
            language-specific tokenizers with '13a' as the fallback default.
        - smooth_method (str): The smoothing method to use
            ('floor', 'add-k', 'exp' or 'none').
        - smooth_value (float, optional): The smoothing value for `floor` and
            `add-k` methods. `None` falls back to default value.
        - max_ngram_order (int): If given, it overrides the maximum n-gram
            order (default: 4) when computing precisions.
        - effective_order (bool): If `True`, stop including n-gram orders for
            which precision is 0. This should be `True` if sentence-level BLEU
            will be computed. (default: True)
        - trg_lang (str): An optional language code to raise potential
            tokenizer warnings.
        - num_workers (int): Number of workers for multiprocessing.
        """

        lowercase: bool = False
        force: bool = False
        tokenize: Optional[str] = None
        smooth_method: str = "exp"
        smooth_value: Optional[float] = None
        max_ngram_order: int = 4
        effective_order: bool = True
        trg_lang: str = ""
        num_workers: int = 8

    cfg: Config

    @dataclass
    class AggregatedReference:
        """Aggregated reference representation.

        - ngrams (Counter[tuple[str, ...]]): Bags of expected n-gram counts.
        - length (float): Expected length of references.
        """

        ngrams: Counter[tuple[str, ...]]
        length: float

    def __init__(self, cfg: MetricBLEU.Config):
        """Build the metric and its underlying sacrebleu scorer.

        Args:
            cfg (MetricBLEU.Config): BLEU metric configuration.
        """
        super().__init__(cfg)
        self.scorer = self._initialize_bleu(cfg)

    @staticmethod
    def _initialize_bleu(cfg: MetricBLEU.Config) -> BLEU:
        """Create a sacrebleu scorer and register it for `_score_worker`.

        This is called from `__init__` and is also passed as the process-pool
        initializer in `scores` and `pairwise_scores`.

        Args:
            cfg (MetricBLEU.Config): BLEU metric configuration.

        Returns:
            BLEU: The constructed scorer.
        """
        scorer = BLEU(
            lowercase=cfg.lowercase,
            force=cfg.force,
            tokenize=cfg.tokenize,
            smooth_method=cfg.smooth_method,
            smooth_value=cfg.smooth_value,
            max_ngram_order=cfg.max_ngram_order,
            effective_order=cfg.effective_order,
            trg_lang=cfg.trg_lang,
        )
        # Stash the scorer as a function attribute so that `_score_worker` can
        # reach it without the scorer itself being sent along with each task.
        MetricBLEU._score_worker.scorer = scorer
        return scorer

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scorer.sentence_score(hypothesis, [reference]).score

    @staticmethod
    def _score_worker(hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Because ja-mecab tokenizer cannot be pickled, this method is necessary
        to use multiprocessing.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.

        Todo:
            - Replace this method with a better logic.
        """
        return MetricBLEU._score_worker.scorer.sentence_score(
            hypothesis, [reference]
        ).score

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        num_hypotheses = len(hypotheses)
        if num_hypotheses == 0:
            # A chunksize of 0 is rejected by `ProcessPoolExecutor.map`, and
            # there is nothing to score anyway: skip spawning the workers.
            return Tensor([])

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
            initializer=self._initialize_bleu,
            initargs=(self.cfg,),
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_hypotheses)
                return Tensor(
                    list(
                        executor.map(
                            self._score_worker,
                            hypotheses,
                            references,
                            chunksize=math.ceil(
                                num_hypotheses / self.cfg.num_workers
                            ),
                        )
                    )
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        num_hypotheses, num_references = len(hypotheses), len(references)
        if num_hypotheses == 0 or num_references == 0:
            # No pairs to score. With no hypotheses the chunksize below would
            # be 0, which `ProcessPoolExecutor.map` rejects.
            return torch.empty((num_hypotheses, num_references))

        # Row-major enumeration of all (hypothesis, reference) pairs, split
        # into two parallel argument sequences for `executor.map`.
        paired_hypotheses, paired_references = zip(
            *itertools.product(hypotheses, references)
        )

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
            initializer=self._initialize_bleu,
            initargs=(self.cfg,),
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_hypotheses * num_references)

                return Tensor(
                    list(
                        executor.map(
                            self._score_worker,
                            paired_hypotheses,
                            paired_references,
                            chunksize=num_hypotheses,
                        )
                    )
                ).view(num_hypotheses, num_references)

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources. Not used by this metric.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score

    @staticmethod
    def _compute_bleu(
        correct: list[float],
        total: list[float],
        sys_len: float,
        ref_len: float,
        smooth_method: str = "none",
        smooth_value: Optional[float] = None,
        effective_order: bool = False,
        max_ngram_order: int = MAX_NGRAM_ORDER,
    ) -> float:
        """Computes BLEU score from its sufficient statistics with smoothing.

        Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU",
        Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346)

        - none: No smoothing.
        - floor: Method 1 (requires small positive value (0.1 in the paper) to be set)
        - add-k: Method 2 (Generalizing Lin and Och, 2004)
        - exp: Method 3 (NIST smoothing method i.e. in use with mteval-v13a.pl)

        This method extends the original sacrebleu implementation to treat expected n-grams.
        The `correct` and `total` lists passed in are not modified.

        Args:
            correct (list[float]): List of counts of correct ngrams, 1 <= n <= max_ngram_order.
            total (list[float]): List of counts of total ngrams, 1 <= n <= max_ngram_order
            sys_len (float): The cumulative system length
            ref_len (float): The cumulative reference length
            smooth_method (str): The smoothing method to use ('floor', 'add-k', 'exp' or 'none')
            smooth_value (float, optional): The smoothing value for `floor` and `add-k` methods. `None` falls back to default value.
            effective_order (bool): If `True`, stop including n-gram orders for which precision is 0. This should be
                `True`, if sentence-level BLEU will be computed.
            max_ngram_order (int): If given, it overrides the maximum n-gram order (default: 4) when computing precisions.

        Returns:
            float: A BLEU score.
        """
        assert smooth_method in BLEU.SMOOTH_DEFAULTS.keys(), (
            f"Unknown smooth_method {smooth_method!r}"
        )

        # Fetch the default value for floor and add-k
        if smooth_value is None:
            smooth_value = BLEU.SMOOTH_DEFAULTS[smooth_method]

        # Compute brevity penalty
        if sys_len < ref_len:
            bp = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
        else:
            bp = 1.0

        # Early stop if there are no matches (#141)
        if not any(correct):
            return 0.0

        # add-k smoothing adjusts the counts below; work on copies so that the
        # caller's lists are left untouched.
        correct = list(correct)
        total = list(total)

        # n-gram precisions
        precisions = [0.0] * max_ngram_order

        smooth_mteval = 1.0
        eff_order = max_ngram_order
        for n in range(1, max_ngram_order + 1):
            if smooth_method == "add-k" and n > 1:
                correct[n - 1] += smooth_value
                total[n - 1] += smooth_value

            if total[n - 1] == 0:
                break

            # If the system guesses no i-grams, 1 <= i <= max_ngram_order,
            # the BLEU score is 0 (technically undefined). This is a problem for sentence
            # level BLEU or a corpus of short sentences, where systems will get
            # no credit if sentence lengths fall under the max_ngram_order threshold.
            # This fix scales max_ngram_order to the observed maximum order.
            # It is only available through the API and off by default
            if effective_order:
                eff_order = n

            if correct[n - 1] == 0:
                if smooth_method == "exp":
                    smooth_mteval *= 2
                    precisions[n - 1] = 100.0 / (smooth_mteval * total[n - 1])
                elif smooth_method == "floor":
                    precisions[n - 1] = 100.0 * smooth_value / total[n - 1]
            else:
                precisions[n - 1] = 100.0 * correct[n - 1] / total[n - 1]

        # Compute BLEU score: brevity penalty times the geometric mean of the
        # precisions; a zero precision contributes a very large negative log.
        log_precision_sum = sum(
            math.log(p) if p > 0.0 else -9999999999.0
            for p in precisions[:eff_order]
        )
        return bp * math.exp(log_precision_sum / eff_order)

    def _aggregate_references(
        self, references: list[str], reference_lprobs: Optional[Tensor] = None
    ) -> AggregatedReference:
        """Aggregate references.

        Each reference's n-gram counts and length are weighted by its
        (normalized) probability and summed into a single expected bag of
        n-grams and a single expected length.

        Args:
            references (list[str]): References.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.
              If omitted, every reference gets the uniform weight
              `1 / len(references)`.

        Returns:
            MetricBLEU.AggregatedReference: Aggregated reference representation.
        """
        num_references = len(references)
        if reference_lprobs is not None:
            lprobs = reference_lprobs.log_softmax(dim=-1, dtype=torch.float32).tolist()
        else:
            lprobs = [-math.log(num_references)] * num_references

        reference_stats = self.scorer._cache_references([references])

        expected_reference_length = sum(
            math.exp(math.log(stat["ref_lens"][0]) + lprob)
            if stat["ref_lens"][0] > 0.0
            else 0.0
            for stat, lprob in zip(reference_stats, lprobs)
        )

        # Accumulate into a fresh counter instead of rewriting the
        # per-reference counters in place.
        # Note: Counter has float values.
        acc_ngrams: Counter[tuple[str, ...]] = Counter()
        for i, stat in enumerate(reference_stats):
            lprob = lprobs[i]
            for ngram, count in stat["ref_ngrams"].items():
                acc_ngrams[ngram] += math.exp(math.log(count) + lprob)

        # Keep only strictly positive expected counts (a weight may underflow
        # to 0.0); this mirrors what in-place Counter addition does.
        acc_ngrams = Counter(
            {ngram: count for ngram, count in acc_ngrams.items() if count > 0}
        )

        return self.AggregatedReference(acc_ngrams, expected_reference_length)

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source. Not used by this metric.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("aggregate/references"):
            aggregated_reference = self._aggregate_references(
                references, reference_lprobs=reference_lprobs
            )

        max_order = self.scorer.max_ngram_order
        expected_ngrams = aggregated_reference.ngrams
        expected_length = float(aggregated_reference.length)

        expected_scores = torch.zeros((len(hypotheses),))
        for i, hypothesis in enumerate(hypotheses):
            with timer.measure("expectation"):
                segment = self.scorer._preprocess_segment(hypothesis)
                # Extract n-grams for the hypothesis
                hyp_ngrams, hyp_len = extract_all_word_ngrams(segment, 1, max_order)

                # Count the stats
                # Although counter has its internal & and | operators, this is faster
                correct = [0.0] * max_order
                total = [0.0] * max_order
                for hyp_ngram, hyp_count in hyp_ngrams.items():
                    # zero-based n-gram order
                    n = len(hyp_ngram) - 1
                    # count hypothesis n-grams
                    total[n] += float(hyp_count)
                    # count matched n-grams, clipped by the expected count
                    if hyp_ngram in expected_ngrams:
                        correct[n] += float(
                            min(hyp_count, expected_ngrams[hyp_ngram])
                        )

                expected_scores[i] = self._compute_bleu(
                    correct=correct,
                    total=total,
                    sys_len=float(hyp_len),
                    ref_len=expected_length,
                    smooth_method=self.scorer.smooth_method,
                    smooth_value=self.scorer.smooth_value,
                    effective_order=self.scorer.effective_order,
                    max_ngram_order=max_order,
                )

        return expected_scores