1from __future__ import annotations
2
3import concurrent.futures
4import itertools
5import math
6from collections import Counter
7from dataclasses import dataclass
8from typing import Optional
9
10import torch
11from sacrebleu.metrics.bleu import BLEU, MAX_NGRAM_ORDER
12from sacrebleu.metrics.helpers import extract_all_word_ngrams
13from torch import Tensor
14
15from mbrs import timer
16
17from . import MetricAggregatable, register
18
19
@register("bleu")
class MetricBLEU(MetricAggregatable):
    """BLEU metric class.

    Wraps a sacrebleu ``BLEU`` scorer configured from :class:`MetricBLEU.Config`
    and additionally supports scoring hypotheses against an aggregated
    (expected) reference representation.
    """

    @dataclass
    class Config(MetricAggregatable.Config):
        """BLEU metric configuration.

        - lowercase (bool): If True, lowercased BLEU is computed.
        - force (bool): Ignore data that looks already tokenized.
        - tokenize (str, optional): The tokenizer to use. If None, defaults to language-specific tokenizers with '13a' as the fallback default.
        - smooth_method (str): The smoothing method to use ('floor', 'add-k', 'exp' or 'none').
        - smooth_value (float, optional): The smoothing value for `floor` and `add-k` methods. `None` falls back to default value.
        - max_ngram_order (int): If given, it overrides the maximum n-gram order (default: 4) when computing precisions.
        - effective_order (bool): If `True`, stop including n-gram orders for which precision is 0.
            This should be `True`, if sentence-level BLEU will be computed. (default: True)
        - trg_lang (str): An optional language code to raise potential tokenizer warnings.
        - num_workers (int): Number of workers for multiprocessing.
        """

        lowercase: bool = False
        force: bool = False
        tokenize: Optional[str] = None
        smooth_method: str = "exp"
        smooth_value: Optional[float] = None
        max_ngram_order: int = 4
        effective_order: bool = True
        trg_lang: str = ""
        num_workers: int = 8

    cfg: Config

    @dataclass
    class AggregatedReference:
        """Aggregated reference representation.

        - ngrams (Counter[tuple[str, ...]]): Bags of expected n-gram counts.
        - length (float): Expected length of references.
        """

        ngrams: Counter[tuple[str, ...]]
        length: float

    def __init__(self, cfg: MetricBLEU.Config):
        """Build the metric and its underlying sacrebleu scorer.

        Args:
            cfg (MetricBLEU.Config): BLEU metric configuration.
        """
        super().__init__(cfg)
        self.scorer = self._initialize_bleu(cfg)

    @staticmethod
    def _initialize_bleu(cfg: MetricBLEU.Config) -> BLEU:
        """Create a sacrebleu scorer from the configuration.

        The scorer is also stored as an attribute of the `_score_worker`
        function so that `_score_worker` can reach it without an instance.
        This method is used both by `__init__` and as the process-pool
        initializer in `scores()` and `pairwise_scores()`.

        Args:
            cfg (MetricBLEU.Config): BLEU metric configuration.

        Returns:
            BLEU: The configured sacrebleu scorer.
        """
        scorer = BLEU(
            lowercase=cfg.lowercase,
            force=cfg.force,
            tokenize=cfg.tokenize,
            smooth_method=cfg.smooth_method,
            smooth_value=cfg.smooth_value,
            max_ngram_order=cfg.max_ngram_order,
            effective_order=cfg.effective_order,
            trg_lang=cfg.trg_lang,
        )
        MetricBLEU._score_worker.scorer = scorer
        return scorer

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scorer.sentence_score(hypothesis, [reference]).score

    @staticmethod
    def _score_worker(hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Because ja-mecab tokenizer cannot be pickled, this method is necessary to use
        multiprocessing. It reads the scorer that `_initialize_bleu()` attached
        to this function instead of receiving it as an argument.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.

        Todo:
            - Replace this method with a better logic.
        """
        return MetricBLEU._score_worker.scorer.sentence_score(
            hypothesis, [reference]
        ).score

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        num_hypotheses = len(hypotheses)
        # `executor.map` rejects chunksize < 1, which the plain ceil-division
        # would produce for an empty input.
        chunksize = max(1, math.ceil(num_hypotheses / self.cfg.num_workers))
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
            initializer=self._initialize_bleu,
            initargs=(self.cfg,),
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_hypotheses)
                return Tensor(
                    list(
                        executor.map(
                            self._score_worker,
                            hypotheses,
                            references,
                            chunksize=chunksize,
                        )
                    )
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        num_hypotheses = len(hypotheses)
        num_references = len(references)

        # Flatten the (hypothesis, reference) grid in row-major order so that
        # the flat result can be viewed as an `(H, R)` matrix.
        pairs = list(itertools.product(hypotheses, references))
        paired_hypotheses = [hyp for hyp, _ in pairs]
        paired_references = [ref for _, ref in pairs]

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
            initializer=self._initialize_bleu,
            initargs=(self.cfg,),
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_hypotheses * num_references)

                return Tensor(
                    list(
                        executor.map(
                            self._score_worker,
                            paired_hypotheses,
                            paired_references,
                            # `executor.map` rejects chunksize < 1 (empty hypotheses).
                            chunksize=max(1, num_hypotheses),
                        )
                    )
                ).view(num_hypotheses, num_references)

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources. Not used by this metric.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score

    @staticmethod
    def _compute_bleu(
        correct: list[float],
        total: list[float],
        sys_len: float,
        ref_len: float,
        smooth_method: str = "none",
        smooth_value: Optional[float] = None,
        effective_order: bool = False,
        max_ngram_order: int = MAX_NGRAM_ORDER,
    ) -> float:
        """Computes BLEU score from its sufficient statistics with smoothing.

        Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU",
        Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346)

        - none: No smoothing.
        - floor: Method 1 (requires small positive value (0.1 in the paper) to be set)
        - add-k: Method 2 (Generalizing Lin and Och, 2004)
        - exp: Method 3 (NIST smoothing method i.e. in use with mteval-v13a.pl)

        This method extends the original sacrebleu implementation to treat expected n-grams.
        The `correct` and `total` lists passed by the caller are not modified.

        Args:
            correct (list[float]): List of counts of correct ngrams, 1 <= n <= max_ngram_order.
            total (list[float]): List of counts of total ngrams, 1 <= n <= max_ngram_order
            sys_len (float): The cumulative system length
            ref_len (float): The cumulative reference length
            smooth_method (str): The smoothing method to use ('floor', 'add-k', 'exp' or 'none')
            smooth_value (float, optional): The smoothing value for `floor` and `add-k` methods. `None` falls back to default value.
            effective_order (bool): If `True`, stop including n-gram orders for which precision is 0. This should be
                `True`, if sentence-level BLEU will be computed.
            max_ngram_order (int): If given, it overrides the maximum n-gram order (default: 4) when computing precisions.

        Returns:
            float: A BLEU score.

        Raises:
            AssertionError: If `smooth_method` is not a key of `BLEU.SMOOTH_DEFAULTS`.
        """
        assert smooth_method in BLEU.SMOOTH_DEFAULTS.keys(), (
            f"Unknown smooth_method {smooth_method!r}"
        )

        # Fetch the default value for floor and add-k
        if smooth_value is None:
            smooth_value = BLEU.SMOOTH_DEFAULTS[smooth_method]

        # Work on copies: add-k smoothing below adjusts the counts in place and
        # must not leak into the caller's lists.
        correct = list(correct)
        total = list(total)

        # Compute brevity penalty
        if sys_len < ref_len:
            bp = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
        else:
            bp = 1.0

        # n-gram precisions
        precisions = [0.0] * max_ngram_order

        # Early stop if there are no matches (#141)
        if not any(correct):
            return 0.0

        smooth_mteval = 1.0
        eff_order = max_ngram_order
        for n in range(1, len(precisions) + 1):
            if smooth_method == "add-k" and n > 1:
                correct[n - 1] += smooth_value
                total[n - 1] += smooth_value

            if total[n - 1] == 0:
                break

            # If the system guesses no i-grams, 1 <= i <= max_ngram_order,
            # the BLEU score is 0 (technically undefined). This is a problem for sentence
            # level BLEU or a corpus of short sentences, where systems will get
            # no credit if sentence lengths fall under the max_ngram_order threshold.
            # This fix scales max_ngram_order to the observed maximum order.
            # It is only available through the API and off by default
            if effective_order:
                eff_order = n

            if correct[n - 1] == 0:
                if smooth_method == "exp":
                    smooth_mteval *= 2
                    precisions[n - 1] = 100.0 / (smooth_mteval * total[n - 1])
                elif smooth_method == "floor":
                    precisions[n - 1] = 100.0 * smooth_value / total[n - 1]
            else:
                precisions[n - 1] = 100.0 * correct[n - 1] / total[n - 1]

        # Compute BLEU score: brevity penalty times the geometric mean of the
        # precisions; a zero precision contributes a very large negative log.
        log_precision_sum = sum(
            math.log(p) if p > 0.0 else -9999999999.0 for p in precisions[:eff_order]
        )
        return bp * math.exp(log_precision_sum / eff_order)

    def _aggregate_references(
        self, references: list[str], reference_lprobs: Optional[Tensor] = None
    ) -> AggregatedReference:
        """Aggregate references.

        Each reference's n-gram counts and length are weighted by its
        probability and summed. If `reference_lprobs` is not given, every
        reference gets the uniform weight `1 / len(references)`.

        Args:
            references (list[str]): References.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            MetricBLEU.AggregatedReference: Aggregated reference representation.
        """
        num_references = len(references)
        if reference_lprobs is not None:
            # Normalize so that the weights sum to one.
            lprobs = reference_lprobs.log_softmax(dim=-1, dtype=torch.float32).tolist()
        else:
            lprobs = [-math.log(num_references)] * num_references

        # Passing `[references]` makes the scorer treat every reference as its
        # own segment, so one stats entry is produced per reference.
        reference_stats = self.scorer._cache_references([references])
        reference_ngrams: list[Counter[tuple[str, ...]]] = [
            stat["ref_ngrams"] for stat in reference_stats
        ]

        expected_reference_length = sum(
            [
                math.exp(math.log(stat["ref_lens"][0]) + lprob)
                if stat["ref_lens"][0] > 0.0
                else 0.0
                for stat, lprob in zip(reference_stats, lprobs)
            ]
        )

        acc_ngrams: Counter[tuple[str, ...]] = Counter()
        for ngrams, lprob in zip(reference_ngrams, lprobs):
            for ngram in ngrams:
                # Note: Counter has float values.
                ngrams[ngram] = math.exp(math.log(ngrams[ngram]) + lprob)
            acc_ngrams += ngrams

        return self.AggregatedReference(acc_ngrams, expected_reference_length)

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        The references are first aggregated into expected n-gram counts and
        an expected length; every hypothesis is then scored once against
        that aggregate.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source. Not used by this metric.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("aggregate/references"):
            aggregated_reference = self._aggregate_references(
                references, reference_lprobs=reference_lprobs
            )

        max_ngram_order = self.scorer.max_ngram_order
        expected_scores = torch.zeros((len(hypotheses),))
        for i, hypothesis in enumerate(hypotheses):
            with timer.measure("expectation"):
                hypothesis = self.scorer._preprocess_segment(hypothesis)
                # Extract n-grams for the hypothesis
                hyp_ngrams, hyp_len = extract_all_word_ngrams(
                    hypothesis, 1, max_ngram_order
                )

                # Count the stats
                # Although counter has its internal & and | operators, this is faster
                correct = [0.0] * max_ngram_order
                total = [0.0] * max_ngram_order
                for hyp_ngram, hyp_count in hyp_ngrams.items():
                    # n-gram order (zero-based index into the stats lists)
                    n = len(hyp_ngram) - 1
                    # count hypothesis n-grams
                    total[n] += float(hyp_count)
                    # count matched n-grams, clipped by the expected reference count
                    if hyp_ngram in aggregated_reference.ngrams:
                        correct[n] += float(
                            min(hyp_count, aggregated_reference.ngrams[hyp_ngram])
                        )

                expected_scores[i] = self._compute_bleu(
                    correct=correct,
                    total=total,
                    sys_len=float(hyp_len),
                    ref_len=float(aggregated_reference.length),
                    smooth_method=self.scorer.smooth_method,
                    smooth_value=self.scorer.smooth_value,
                    effective_order=self.scorer.effective_order,
                    max_ngram_order=max_ngram_order,
                )

        return expected_scores