1from __future__ import annotations
2
3import concurrent.futures
4import itertools
5import math
6from collections import Counter, defaultdict
7from dataclasses import dataclass
8from typing import Optional
9
10import fastchrf
11import torch
12from sacrebleu.metrics.chrf import CHRF
13from sacrebleu.metrics.helpers import extract_all_char_ngrams, extract_word_ngrams
14from torch import Tensor
15
16from mbrs import timer
17
18from . import Metric, MetricAggregatable, register
19
20
@register("chrf")
class MetricChrF(MetricAggregatable):
    """ChrF metric class.

    Scores are computed with sacreBLEU's `CHRF` by default, or with the Rust
    `fastchrf` package when `cfg.fastchrf` is `True`.
    """

    @dataclass
    class Config(Metric.Config):
        """ChrF metric configuration.

        - char_order (int): Character n-gram order.
        - word_order (int): Word n-gram order. If equals to 2, the metric is referred to as chrF++.
        - beta (int): Determine the importance of recall w.r.t precision.
        - lowercase (bool): Enable case-insensitivity.
        - whitespace (bool): If `True`, include whitespaces when extracting character n-grams.
        - eps_smoothing (bool): If `True`, applies epsilon smoothing similar to reference chrF++.py, NLTK and Moses implementations.
          Otherwise, it takes into account effective match order similar to sacreBLEU < 2.0.0.
        - num_workers (int): Number of workers for multiprocessing.
        - fastchrf (bool): Use the rust implementation of chrF. `word_order` is not
          supported in this mode; `lowercase` is honored by lowercasing the inputs
          before they are passed to fastchrf.
        """

        char_order: int = 6
        word_order: int = 0
        beta: int = 2
        lowercase: bool = False
        whitespace: bool = False
        eps_smoothing: bool = False
        num_workers: int = 8
        fastchrf: bool = False

        def __post_init__(self):
            if self.fastchrf and self.word_order > 0:
                raise ValueError("fastchrf does not support the `word_order` option.")

    cfg: Config

    @dataclass
    class AggregatedReference:
        """Aggregated reference representation.

        - ngrams (list[Counter]): Bags of n-grams for each order. The counts are
          weighted by the reference probabilities, so they hold float values.
        """

        ngrams: list[Counter]

    def __init__(self, cfg: MetricChrF.Config):
        super().__init__(cfg)
        # sacreBLEU scorer used by every code path that does not go through fastchrf.
        self.scorer = CHRF(
            char_order=cfg.char_order,
            word_order=cfg.word_order,
            beta=cfg.beta,
            lowercase=cfg.lowercase,
            whitespace=cfg.whitespace,
            eps_smoothing=cfg.eps_smoothing,
        )

    def _fastchrf_preprocess(self, segments_lists: list[list[str]]) -> list[list[str]]:
        """Apply the chrF preprocessing that fastchrf does not perform itself.

        fastchrf receives no lowercasing option, so `cfg.lowercase` is applied
        here by lowercasing every segment, as sacreBLEU's `CHRF` does when it
        is constructed with `lowercase=True`.

        Args:
            segments_lists (list[list[str]]): Lists of segments.

        Returns:
            list[list[str]]: The (possibly lowercased) lists of segments.
        """
        if not self.cfg.lowercase:
            return segments_lists
        return [[segment.lower() for segment in segments] for segments in segments_lists]

    def _fastchrf_pairwise_scores(
        self, hypotheses_lists: list[list[str]], references_lists: list[list[str]]
    ) -> Tensor:
        """Calculate the pairwise scores using fastchrf.

        Args:
            hypotheses_lists (list[list[str]]): N lists of hypotheses.
            references_lists (list[list[str]]): N lists of references.

        Returns:
            Tensor: Score matrix of shape `(N, H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        return Tensor(
            fastchrf.pairwise_chrf(
                self._fastchrf_preprocess(hypotheses_lists),
                self._fastchrf_preprocess(references_lists),
                char_order=self.cfg.char_order,
                beta=float(self.cfg.beta),
                remove_whitespace=not self.cfg.whitespace,
                eps_smoothing=self.cfg.eps_smoothing,
            )
        )

    def _fastchrf_expected_scores_reference_aggregation(
        self, hypotheses_lists: list[list[str]], references_lists: list[list[str]]
    ) -> Tensor:
        """Calculate the expected scores with reference aggregation using fastchrf.

        Args:
            hypotheses_lists (list[list[str]]): N lists of hypotheses.
            references_lists (list[list[str]]): N lists of references.

        Returns:
            Tensor: Score matrix of shape `(N, H)`, where `H` is the number
              of hypotheses.
        """
        return Tensor(
            fastchrf.aggregate_chrf(
                self._fastchrf_preprocess(hypotheses_lists),
                self._fastchrf_preprocess(references_lists),
                char_order=self.cfg.char_order,
                beta=float(self.cfg.beta),
                remove_whitespace=not self.cfg.whitespace,
                eps_smoothing=self.cfg.eps_smoothing,
            )
        )

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.
        """
        if self.cfg.fastchrf:
            return self._fastchrf_pairwise_scores([[hypothesis]], [[reference]]).item()

        return self.scorer.sentence_score(hypothesis, [reference]).score

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        if self.cfg.fastchrf:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return self._fastchrf_pairwise_scores(
                    [[hypothesis] for hypothesis in hypotheses],
                    [[reference] for reference in references],
                ).flatten()

        # `Executor.map` rejects a chunksize below 1, which the plain ceil
        # division would produce for an empty input.
        chunksize = max(1, math.ceil(len(hypotheses) / self.cfg.num_workers))
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return Tensor(
                    list(
                        executor.map(
                            self.score,
                            hypotheses,
                            references,
                            chunksize=chunksize,
                        )
                    )
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        num_pairs = len(hypotheses) * len(references)
        if self.cfg.fastchrf:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_pairs)
                return self._fastchrf_pairwise_scores(
                    [hypotheses], [references]
                ).squeeze(0)

        # Row-major (hypothesis, reference) pairs, so the flat list of results
        # can be reshaped into `(H, R)` below.
        pairs = list(itertools.product(hypotheses, references))
        pair_hypotheses = [hypothesis for hypothesis, _ in pairs]
        pair_references = [reference for _, reference in pairs]
        # Balanced chunks, consistent with `scores`; clamped to the minimum of 1
        # that `Executor.map` accepts so empty inputs do not raise.
        chunksize = max(1, math.ceil(num_pairs / self.cfg.num_workers))

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(num_pairs)
                return Tensor(
                    list(
                        executor.map(
                            self.score,
                            pair_hypotheses,
                            pair_references,
                            chunksize=chunksize,
                        )
                    )
                ).view(len(hypotheses), len(references))

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources. Not used by chrF.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score

    def _aggregate_references(
        self, references: list[str], reference_lprobs: Optional[Tensor] = None
    ) -> AggregatedReference:
        """Aggregate references.

        Each reference's n-gram counts are scaled by its probability (uniform
        when `reference_lprobs` is not given) and summed per n-gram order.

        Args:
            references (list[str]): References.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            MetricChrF.AggregatedReference: Aggregated reference representation.

        Raises:
            ValueError: If `references` is empty or `reference_lprobs` has the
              wrong shape.
        """
        num_references = len(references)
        if num_references == 0:
            raise ValueError("`references` must not be empty.")

        # One reference "stream" per sample, so sacreBLEU returns the n-grams
        # of all samples for the single (first) segment.
        reference_ngrams: list[list[Counter[str]]] = self.scorer._cache_references(
            [[ref] for ref in references]
        )[0]["ref_ngrams"]

        if reference_lprobs is not None:
            if reference_lprobs.shape != (num_references,):
                raise ValueError(
                    "`reference_lprobs` must have shape (len(references),), "
                    f"but got {tuple(reference_lprobs.shape)}."
                )
            # Renormalize so the weights sum to one.
            lprobs = reference_lprobs.log_softmax(dim=-1).tolist()
        else:
            lprobs = [-math.log(num_references)] * num_references

        acc_ngrams: defaultdict[int, Counter[str]] = defaultdict(Counter)
        for lprob, ngrams in zip(lprobs, reference_ngrams):
            for order, ngram_counts in enumerate(ngrams):
                # Note: the weighted Counter has float values.
                weighted_counts = Counter(
                    {
                        ngram: math.exp(math.log(count) + lprob)
                        for ngram, count in ngram_counts.items()
                    }
                )
                acc_ngrams[order] += weighted_counts

        return self.AggregatedReference(
            [acc_ngrams[order] for order in range(len(acc_ngrams))]
        )

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis, of shape `(H,)`.

        Raises:
            ValueError: If fastchrf is enabled and `reference_lprobs` is given.
        """
        if self.cfg.fastchrf:
            if reference_lprobs is not None:
                raise ValueError("fastchrf does not support model-based aggregation.")

            with timer.measure("expectation"):
                return self._fastchrf_expected_scores_reference_aggregation(
                    [hypotheses], [references]
                ).squeeze(0)

        with timer.measure("aggregate/references"):
            aggregated_reference = self._aggregate_references(
                references, reference_lprobs=reference_lprobs
            )

        expected_scores = torch.zeros((len(hypotheses),))
        for i, hypothesis in enumerate(hypotheses):
            with timer.measure("expectation"):
                segment = self.scorer._preprocess_segment(hypothesis)

                # Extract character n-grams
                all_hyp_ngrams = extract_all_char_ngrams(
                    segment, self.scorer.char_order, self.scorer.whitespace
                )

                # Check chrF+ mode to see if we'll add word n-grams as well
                if self.scorer.word_order > 0:
                    # Primitive tokenization: separate out punctuations
                    hwords = self.scorer._remove_punctuation(segment)
                    _range = range(1, self.scorer.word_order + 1)
                    all_hyp_ngrams.extend(
                        [extract_word_ngrams(hwords, n) for n in _range]
                    )

                # Match against the single aggregated (probability-weighted)
                # reference instead of each reference separately.
                stats = []
                # Traverse all orders
                for h, r in zip(all_hyp_ngrams, aggregated_reference.ngrams):
                    stats.extend(self.scorer._get_match_statistics(h, r))
                expected_scores[i] = self.scorer._compute_f_score(stats)

        return expected_scores