Source code for mbrs.metrics.chrf

  1from __future__ import annotations
  2
  3import concurrent.futures
  4import itertools
  5import math
  6from collections import Counter, defaultdict
  7from dataclasses import dataclass
  8from typing import Optional
  9
 10import fastchrf
 11import torch
 12from sacrebleu.metrics.chrf import CHRF
 13from sacrebleu.metrics.helpers import extract_all_char_ngrams, extract_word_ngrams
 14from torch import Tensor
 15
 16from mbrs import timer
 17
 18from . import Metric, MetricAggregatable, register
 19
 20
@register("chrf")
class MetricChrF(MetricAggregatable):
    """ChrF metric class.

    Scores are computed with sacreBLEU's ``CHRF`` scorer, or with the ``fastchrf``
    package when ``Config.fastchrf`` is enabled (corpus-level scoring always uses
    sacreBLEU).
    """

    @dataclass
    class Config(Metric.Config):
        """ChrF metric configuration.

        - char_order (int): Character n-gram order.
        - word_order (int): Word n-gram order. If equals to 2, the metric is referred to as chrF++.
        - beta (int): Determine the importance of recall w.r.t precision.
        - lowercase (bool): Enable case-insensitivity.
        - whitespace (bool): If `True`, include whitespaces when extracting character n-grams.
        - eps_smoothing (bool): If `True`, applies epsilon smoothing similar to reference chrF++.py, NLTK and Moses implementations.
            Otherwise, it takes into account effective match order similar to sacreBLEU < 2.0.0.
        - num_workers (int): Number of workers for multiprocessing.
        - fastchrf (bool): Use the rust implementation of chrF.
            Cannot be combined with `word_order > 0`. When `lowercase` is set,
            inputs are lowercased before being passed to fastchrf.
        """

        char_order: int = 6
        word_order: int = 0
        beta: int = 2
        lowercase: bool = False
        whitespace: bool = False
        eps_smoothing: bool = False
        num_workers: int = 8
        fastchrf: bool = False

        def __post_init__(self):
            # The fastchrf calls in this class only receive character-level options.
            if self.fastchrf and self.word_order > 0:
                raise ValueError("fastchrf does not support the `word_order` option.")

    cfg: Config

    @dataclass
    class AggregatedReference:
        """Aggregated reference representation.

        - ngrams (list[Counter]): Bags of n-grams for each order, with
            probability-weighted (float) counts.
        """

        ngrams: list[Counter]

    def __init__(self, cfg: MetricChrF.Config):
        super().__init__(cfg)
        self.scorer = CHRF(
            char_order=cfg.char_order,
            word_order=cfg.word_order,
            beta=cfg.beta,
            lowercase=cfg.lowercase,
            whitespace=cfg.whitespace,
            eps_smoothing=cfg.eps_smoothing,
        )

    def _fastchrf_preprocess(self, texts_lists: list[list[str]]) -> list[list[str]]:
        """Apply the preprocessing options that are not forwarded to fastchrf.

        The fastchrf calls below take no lowercasing option, so `cfg.lowercase`
        is applied here to keep fastchrf results consistent with the sacreBLEU
        scorer, which is constructed with `lowercase=cfg.lowercase`.

        Args:
            texts_lists (list[list[str]]): Lists of texts.

        Returns:
            list[list[str]]: The (possibly lowercased) lists of texts.
        """
        if not self.cfg.lowercase:
            return texts_lists
        return [[text.lower() for text in texts] for texts in texts_lists]

    def _fastchrf_pairwise_scores(
        self, hypotheses_lists: list[list[str]], references_lists: list[list[str]]
    ) -> Tensor:
        """Calculate the pairwise scores using fastchrf.

        Args:
            hypotheses_lists (list[list[str]]): N lists of hypotheses.
            references_lists (list[list[str]]): N lists of references.

        Returns:
            Tensor: Score matrix of shape `(N, H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        return Tensor(
            fastchrf.pairwise_chrf(
                self._fastchrf_preprocess(hypotheses_lists),
                self._fastchrf_preprocess(references_lists),
                char_order=self.cfg.char_order,
                beta=float(self.cfg.beta),
                remove_whitespace=not self.cfg.whitespace,
                eps_smoothing=self.cfg.eps_smoothing,
            )
        )

    def _fastchrf_expected_scores_reference_aggregation(
        self, hypotheses_lists: list[list[str]], references_lists: list[list[str]]
    ) -> Tensor:
        """Calculate the expected scores with reference aggregation using fastchrf.

        Args:
            hypotheses_lists (list[list[str]]): N lists of hypotheses.
            references_lists (list[list[str]]): N lists of references.

        Returns:
            Tensor: Score matrix of shape `(N, H)`, where `H` is the number
              of hypotheses.
        """
        return Tensor(
            fastchrf.aggregate_chrf(
                self._fastchrf_preprocess(hypotheses_lists),
                self._fastchrf_preprocess(references_lists),
                char_order=self.cfg.char_order,
                beta=float(self.cfg.beta),
                remove_whitespace=not self.cfg.whitespace,
                eps_smoothing=self.cfg.eps_smoothing,
            )
        )

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.
        """
        if self.cfg.fastchrf:
            return self._fastchrf_pairwise_scores([[hypothesis]], [[reference]]).item()

        return self.scorer.sentence_score(hypothesis, [reference]).score

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Without fastchrf, the sentence scores are computed in a process pool of
        `cfg.num_workers` workers.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        if self.cfg.fastchrf:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return self._fastchrf_pairwise_scores(
                    [[hypothesis] for hypothesis in hypotheses],
                    [[reference] for reference in references],
                ).flatten()

        # `Executor.map` rejects chunksize < 1, which the plain ceil-division
        # would produce for empty input.
        chunksize = max(1, math.ceil(len(hypotheses) / self.cfg.num_workers))
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers,
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return Tensor(
                    list(
                        executor.map(
                            self.score,
                            hypotheses,
                            references,
                            chunksize=chunksize,
                        )
                    )
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        if self.cfg.fastchrf:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses) * len(references))
                return self._fastchrf_pairwise_scores(
                    [hypotheses], [references]
                ).squeeze(0)

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses) * len(references))

                # itertools.product enumerates pairs in row-major (hypothesis,
                # reference) order, matching the final `.view(H, R)`.
                return Tensor(
                    list(
                        executor.map(
                            self.score,
                            *zip(*itertools.product(hypotheses, references)),
                            # Clamp to 1: an empty hypothesis list would
                            # otherwise give an invalid chunksize of 0.
                            chunksize=max(1, len(hypotheses)),
                        )
                    )
                ).view(len(hypotheses), len(references))

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Always computed with sacreBLEU's `CHRF.corpus_score`, even when
        `cfg.fastchrf` is enabled.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references, passed
              unchanged to sacreBLEU.
            sources (list[str], optional): Sources. Ignored by this metric.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score

    def _aggregate_references(
        self, references: list[str], reference_lprobs: Optional[Tensor] = None
    ) -> AggregatedReference:
        """Aggregate references.

        Each reference's n-gram counts are scaled by its probability (uniform
        when `reference_lprobs` is not given) and summed per n-gram order.

        Args:
            references (list[str]): References.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            MetricChrF.AggregatedReference: Aggregated reference representation.
        """
        num_references = len(references)
        # Passing each reference as its own single-element stream groups them
        # all into one segment, so `[0]["ref_ngrams"]` holds one n-gram list
        # per reference.
        reference_ngrams: list[list[Counter[str]]] = self.scorer._cache_references(
            [[ref] for ref in references]
        )[0]["ref_ngrams"]

        if reference_lprobs is not None:
            # Renormalize so the weights form a probability distribution.
            lprobs = reference_lprobs.log_softmax(dim=-1).tolist()
        else:
            lprobs = [-math.log(num_references)] * num_references

        acc_ngrams: defaultdict[int, Counter[str]] = defaultdict(Counter)
        for i, ngrams in enumerate(reference_ngrams):
            lprob = lprobs[i]
            for order, ngram_counts in enumerate(ngrams):
                # Build a new weighted Counter instead of overwriting the cached
                # one; note the resulting Counter holds float values.
                weighted_counts = Counter(
                    {
                        ngram: math.exp(math.log(count) + lprob)
                        for ngram, count in ngram_counts.items()
                    }
                )
                acc_ngrams[order] += weighted_counts

        return self.AggregatedReference(
            [acc_ngrams[order] for order in range(len(acc_ngrams))]
        )

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.

        Raises:
            ValueError: If `cfg.fastchrf` is enabled and `reference_lprobs` is given.
        """
        if self.cfg.fastchrf:
            if reference_lprobs is not None:
                raise ValueError("fastchrf does not support model-based aggregation.")

            with timer.measure("expectation"):
                return self._fastchrf_expected_scores_reference_aggregation(
                    [hypotheses], [references]
                ).squeeze(0)

        with timer.measure("aggregate/references"):
            aggregated_reference = self._aggregate_references(
                references, reference_lprobs=reference_lprobs
            )

        expected_scores = torch.zeros((len(hypotheses),))
        for i, hypothesis in enumerate(hypotheses):
            with timer.measure("expectation"):
                hypothesis = self.scorer._preprocess_segment(hypothesis)

                # Extract character n-grams
                all_hyp_ngrams = extract_all_char_ngrams(
                    hypothesis, self.scorer.char_order, self.scorer.whitespace
                )

                # Check chrF+ mode to see if we'll add word n-grams as well
                if self.scorer.word_order > 0:
                    # Primitive tokenization: separate out punctuations
                    hwords = self.scorer._remove_punctuation(hypothesis)
                    _range = range(1, self.scorer.word_order + 1)
                    all_hyp_ngrams.extend(
                        [extract_word_ngrams(hwords, n) for n in _range]
                    )

                stats = []
                # Traverse all orders, matching against the aggregated reference
                for h, r in zip(all_hyp_ngrams, aggregated_reference.ngrams):
                    stats.extend(self.scorer._get_match_statistics(h, r))
                f_score = self.scorer._compute_f_score(stats)
                expected_scores[i] = f_score

        return expected_scores