Source code for mbrs.metrics.ter
1from __future__ import annotations
2
3import concurrent.futures
4import itertools
5import math
6from dataclasses import dataclass
7from typing import Optional
8
9from sacrebleu.metrics.ter import TER
10from torch import Tensor
11
12from mbrs import timer
13
14from . import Metric, register
15
16
[docs]
@register("ter")
class MetricTER(Metric):
    """TER metric class backed by ``sacrebleu.metrics.ter.TER``."""

    # TER is an edit rate, so a lower score means a better hypothesis.
    HIGHER_IS_BETTER: bool = False

    @dataclass
    class Config(Metric.Config):
        """TER metric configuration.

        - normalized (bool): Enable character normalization.
          By default, normalizes a couple of things such as newlines being stripped,
          retrieving XML encoded characters, and fixing tokenization for punctuation.
          When 'asian_support' is enabled, also normalizes specific Asian (CJK)
          character sequences, i.e. split them down to the character level.
        - no_punct (bool): Remove punctuation. Can be used in conjunction with
          'asian_support' to also remove typical punctuation markers in Asian languages
          (CJK).
        - asian_support (bool): Enable special treatment of Asian characters.
          This option only has an effect when 'normalized' and/or 'no_punct' is enabled.
          If 'normalized' is also enabled, then Asian (CJK) characters are split down to
          the character level. If 'no_punct' is enabled alongside 'asian_support',
          specific unicode ranges for CJK and full-width punctuations are also removed.
        - case_sensitive (bool): If `True`, does not lowercase sentences.
        - num_workers (int): Number of worker processes used by `scores()` and
          `pairwise_scores()`.
        """

        normalized: bool = False
        no_punct: bool = False
        asian_support: bool = False
        case_sensitive: bool = False
        num_workers: int = 8

    cfg: Config

    def __init__(self, cfg: MetricTER.Config):
        super().__init__(cfg)
        # Build the sacrebleu scorer once; every scoring method reuses it.
        self.scorer = TER(
            normalized=cfg.normalized,
            no_punct=cfg.no_punct,
            asian_support=cfg.asian_support,
            case_sensitive=cfg.case_sensitive,
        )

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Any extra positional or keyword arguments are accepted and ignored.

        Args:
            hypothesis (str): Hypothesis.
            reference (str): Reference.

        Returns:
            float: The sentence-level TER of `hypothesis` against `reference`.
        """
        return self.scorer.sentence_score(hypothesis, [reference]).score

    def _chunksize(self, num_items: int) -> int:
        """Return a chunk size that spreads `num_items` tasks over the workers.

        `ProcessPoolExecutor.map` raises `ValueError` for a chunk size below 1,
        so the result is clamped to at least 1; this keeps empty inputs working.
        """
        return max(1, math.ceil(num_items / self.cfg.num_workers))

    def _parallel_scores(
        self, hypotheses: list[str], references: list[str]
    ) -> list[float]:
        """Score aligned (hypothesis, reference) pairs in worker processes.

        Args:
            hypotheses (list[str]): Hypotheses, one per pair.
            references (list[str]): References, aligned with `hypotheses`.

        Returns:
            list[float]: One score per pair, in input order. If the two lists
            differ in length, `executor.map` stops at the shorter one.
        """
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.cfg.num_workers
        ) as executor:
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(hypotheses))
                return list(
                    executor.map(
                        self.score,
                        hypotheses,
                        references,
                        chunksize=self._chunksize(len(hypotheses)),
                    )
                )

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Any extra positional or keyword arguments are accepted and ignored.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses (empty if N == 0).
        """
        return Tensor(self._parallel_scores(hypotheses, references))

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Any extra positional or keyword arguments are accepted and ignored.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
                of hypotheses and `R` is the number of references. Entry
                `[i, j]` is the score of `hypotheses[i]` against `references[j]`.
        """
        num_hyps, num_refs = len(hypotheses), len(references)
        # Flatten the H x R grid hypothesis-major (same order as
        # itertools.product(hypotheses, references)) so the flat result can be
        # reshaped to (H, R) below.
        flat_hyps = [hyp for hyp in hypotheses for _ in range(num_refs)]
        flat_refs = [ref for _ in range(num_hyps) for ref in references]
        return Tensor(self._parallel_scores(flat_hyps, flat_refs)).view(
            num_hyps, num_refs
        )

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references, passed
                unchanged to sacrebleu's `TER.corpus_score`.
                NOTE(review): sacrebleu treats each inner list as one reference
                stream aligned with `hypotheses` — confirm callers use that layout.
            sources (list[str], optional): Sources. Unused; accepted for
                interface compatibility.

        Returns:
            float: The corpus score.
        """
        return self.scorer.corpus_score(hypotheses, references_lists).score