Source code for mbrs.metrics.ter

  1from __future__ import annotations
  2
  3import concurrent.futures
  4import itertools
  5import math
  6from dataclasses import dataclass
  7from typing import Optional
  8
  9from sacrebleu.metrics.ter import TER
 10from torch import Tensor
 11
 12from mbrs import timer
 13
 14from . import Metric, register
 15
 16
@register("ter")
class MetricTER(Metric):
    """TER metric class.

    Thin wrapper around sacrebleu's :class:`TER` scorer.  Sentence-level
    scoring of many pairs is distributed over a process pool whose size is
    given by ``cfg.num_workers``.
    """

    # Lower scores are better for this metric.
    HIGHER_IS_BETTER: bool = False

    @dataclass
    class Config(Metric.Config):
        """TER metric configuration.

        - normalized (bool): Enable character normalization.
          By default, normalizes a couple of things such as newlines being
          stripped, retrieving XML encoded characters, and fixing tokenization
          for punctuation. When 'asian_support' is enabled, also normalizes
          specific Asian (CJK) character sequences, i.e. split them down to
          the character level.
        - no_punct (bool): Remove punctuation. Can be used in conjunction with
          'asian_support' to also remove typical punctuation markers in Asian
          languages (CJK).
        - asian_support (bool): Enable special treatment of Asian characters.
          This option only has an effect when 'normalized' and/or 'no_punct'
          is enabled. If 'normalized' is also enabled, then Asian (CJK)
          characters are split down to the character level. If 'no_punct' is
          enabled alongside 'asian_support', specific unicode ranges for CJK
          and full-width punctuations are also removed.
        - case_sensitive (bool): If `True`, does not lowercase sentences.
        - num_workers (int): Number of workers for multiprocessing.
        """

        normalized: bool = False
        no_punct: bool = False
        asian_support: bool = False
        case_sensitive: bool = False
        num_workers: int = 8

    cfg: Config

    def __init__(self, cfg: MetricTER.Config):
        """Build the underlying sacrebleu scorer from the configuration.

        Args:
            cfg (MetricTER.Config): TER metric configuration.
        """
        super().__init__(cfg)
        self.scorer = TER(
            normalized=cfg.normalized,
            no_punct=cfg.no_punct,
            asian_support=cfg.asian_support,
            case_sensitive=cfg.case_sensitive,
        )

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scorer.sentence_score(hypothesis, [reference]).score

    def _parallel_scores(
        self, hypotheses: list[str], references: list[str], chunksize: int
    ) -> list[float]:
        """Score aligned hypothesis/reference pairs in a process pool.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references, aligned with `hypotheses`.
            chunksize (int): Requested number of pairs sent to a worker at
              once. Values below 1 are clamped to 1.

        Returns:
            list[float]: The scores, in input order.
        """
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return list(
                    executor.map(
                        self.score,
                        hypotheses,
                        references,
                        # `map` raises ValueError for chunksize < 1, which
                        # would otherwise happen for empty inputs.
                        chunksize=max(1, chunksize),
                    )
                )

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses. Empty inputs yield
              an empty tensor.
        """
        # Spread the pairs evenly: one chunk per worker.
        chunksize = math.ceil(len(hypotheses) / self.cfg.num_workers)
        return Tensor(self._parallel_scores(hypotheses, references, chunksize))

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        # Flatten the H x R grid in row-major order (hypothesis-major) so the
        # flat score list can be reshaped directly into an (H, R) matrix.
        pairs = list(itertools.product(hypotheses, references))
        paired_hypotheses = [hyp for hyp, _ in pairs]
        paired_references = [ref for _, ref in pairs]

        flat_scores = self._parallel_scores(
            paired_hypotheses, paired_references, chunksize=len(hypotheses)
        )
        return Tensor(flat_scores).view(len(hypotheses), len(references))

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources. Not used by this metric.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score