1from __future__ import annotations
2
3import abc
4from dataclasses import dataclass
5from typing import Optional, Sequence
6
7import torch
8from torch import Tensor
9
10from mbrs import functional, registry, timer
11from mbrs.modules.kmeans import Kmeans
12
13
class MetricBase(abc.ABC):
    """Root of the metric class hierarchy.

    Stores the configuration object passed at construction time and exposes
    class-level metadata shared by every metric.
    """

    # Class-level flag, ``True`` by default; by its name it marks whether a
    # larger score means a better hypothesis.
    HIGHER_IS_BETTER: bool = True

    @dataclass
    class Config: ...

    def __init__(self, cfg: MetricBase.Config):
        # Keep the configuration around for subclasses to consult.
        self.cfg = cfg

    @property
    def device(self) -> torch.device:
        """Returns the device of the metric object (always CPU in the base class)."""
        cpu = torch.device("cpu")
        return cpu
30
class Metric(MetricBase, metaclass=abc.ABCMeta):
    """Base class for reference-based metrics.

    Subclasses implement :meth:`score`. The batched (:meth:`scores`), pairwise
    (:meth:`pairwise_scores`), expected (:meth:`expected_scores`) and
    corpus-level (:meth:`corpus_score`) variants are all derived from it.
    """

    @abc.abstractmethod
    def score(
        self, hypothesis: str, reference: str, source: Optional[str] = None
    ) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """

    def scores(
        self,
        hypotheses: list[str],
        references: list[str],
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Inputs are paired element-wise with ``zip``, so if the lists differ in
        length the extra items of the longer list are ignored. When
        ``sources`` is None, :meth:`score` is called without a source.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        with timer.measure("score") as t:
            t.set_delta_ncalls(len(hypotheses))
            if sources is None:
                return Tensor(
                    [self.score(hyp, ref) for hyp, ref in zip(hypotheses, references)]
                )
            else:
                return Tensor(
                    [
                        self.score(hyp, ref, src)
                        for hyp, ref, src in zip(hypotheses, references, sources)
                    ]
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        :meth:`score` is called once for every (hypothesis, reference) pair.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        H = len(hypotheses)
        R = len(references)
        with timer.measure("score") as t:
            t.set_delta_ncalls(H * R)
            matrix = Tensor(
                [
                    [self.score(hyp, ref, source) for ref in references]
                    for hyp in hypotheses
                ]
            )
            # `Tensor([])` (no hypotheses) has shape (0,), not (0, R); reshape
            # so the result always has the documented `(H, R)` layout.
            return matrix.reshape(H, R)

    def expected_scores(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        The `(H, R)` pairwise score matrix is reduced over references by
        ``functional.expectation``, which also receives ``reference_lprobs``.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("expectation"):
            return functional.expectation(
                self.pairwise_scores(hypotheses, references, source),
                lprobs=reference_lprobs,
            )

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        For every list in ``references_lists``, sentence-level scores are
        computed against ``hypotheses`` with :meth:`scores`. The result is the
        mean over all hypotheses and all reference lists.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references; each
              inner list is paired element-wise with ``hypotheses``.
            sources (list[str], optional): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ZeroDivisionError: If ``hypotheses`` or ``references_lists`` is empty.
        """
        total = sum(
            self.scores(hypotheses, references, sources).sum().item()
            for references in references_lists
        )
        return total / (len(hypotheses) * len(references_lists))
150
class MetricAggregatable(Metric, metaclass=abc.ABCMeta):
    """Base class for aggregatable metrics.

    Concrete subclasses provide an expected-score computation based on
    reference aggregation."""

    @abc.abstractmethod
    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis using reference aggregation.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
176
177
class MetricCacheable(Metric, metaclass=abc.ABCMeta):
    """Base class for cacheable metrics.

    This class supports caching intermediate representations of sentences:
    inputs are encoded with :meth:`encode`, and scores are then computed from
    those representations with :meth:`out_proj`."""

    @dataclass
    class Cache(metaclass=abc.ABCMeta):
        """Intermediate representations of sentences."""

        @abc.abstractmethod
        def __len__(self) -> int:
            """Return the length of cache."""

        @abc.abstractmethod
        def __getitem__(
            self, key: int | Sequence[int] | slice | Tensor
        ) -> MetricCacheable.Cache:
            """Get the items."""

        @abc.abstractmethod
        def repeat(self, n: int) -> MetricCacheable.Cache:
            """Repeat the representations by n times.

            Args:
                n (int): The number of repetition.

            Returns:
                Cache: The repeated cache.
            """

    @property
    @abc.abstractmethod
    def embed_dim(self) -> int:
        """Return the size of embedding dimension."""

    @abc.abstractmethod
    def encode(self, sentences: list[str]) -> Cache:
        """Encode the given sentences into their intermediate representations.

        Args:
            sentences (list[str]): Input sentences.

        Returns:
            MetricCacheable.Cache: Intermediate representations.
        """

    @abc.abstractmethod
    def out_proj(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        sources_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Forward the output projection layer.

        Args:
            hypotheses_ir (Cache): N intermediate representations of hypotheses.
            references_ir (Cache): N intermediate representations of references.
            sources_ir (Cache, optional): N intermediate representations of sources.

        Returns:
            Tensor: N scores.
        """

    def scores_from_ir(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        sources_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses from the intermediate representations.

        The work is recorded under the "score" timer as N calls.

        Args:
            hypotheses_ir (Cache): N hypotheses.
            references_ir (Cache): N references.
            sources_ir (Cache, optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        H = len(hypotheses_ir)
        with timer.measure("score") as t:
            t.set_delta_ncalls(H)
            if sources_ir is None:
                return self.out_proj(hypotheses_ir, references_ir)
            else:
                return self.out_proj(hypotheses_ir, references_ir, sources_ir)

    def score(
        self,
        hypothesis: str,
        reference: str,
        source: Optional[str] = None,
    ) -> float:
        """Calculate the score of the given hypothesis.

        Delegates to :meth:`scores` with single-element batches.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scores(
            [hypothesis],
            [reference],
            [source] if source is not None else None,
        ).item()

    def scores(
        self,
        hypotheses: list[str],
        references: list[str],
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Each input list is encoded with :meth:`encode` and then scored with
        :meth:`scores_from_ir`.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        return self.scores_from_ir(
            self.encode(hypotheses),
            self.encode(references),
            self.encode(sources) if sources is not None else None,
        )

    def pairwise_scores_from_ir(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        source_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Calculate the pairwise scores from the intermediate representations.

        For each reference, the reference (and the single source, if given)
        is repeated H times and scored against all hypotheses in one
        :meth:`scores_from_ir` call. That call produces one column of the
        result.

        Args:
            hypotheses_ir (Cache): Hypotheses.
            references_ir (Cache): References.
            source_ir (Cache, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        H = len(hypotheses_ir)
        R = len(references_ir)
        if R == 0:
            # `torch.cat` rejects an empty list; keep the `(H, R)` contract.
            return torch.empty((H, 0), device=self.device)
        if source_ir is not None:
            source_ir = source_ir.repeat(H)

        scores = []
        for i in range(R):
            # `scores_from_ir` already records this work under the "score"
            # timer (with H calls); wrapping it in a second "score"
            # measurement would count the same work twice.
            scores.append(
                self.scores_from_ir(
                    hypotheses_ir, references_ir[i].repeat(H), source_ir
                )[:, None]
            )
        return torch.cat(scores, dim=-1)

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        Hypotheses are encoded once. If ``references`` equals ``hypotheses``,
        those encodings are reused instead of encoding the references again.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        with timer.measure("encode/hypotheses"):
            hypotheses_ir = self.encode(hypotheses)
        if hypotheses == references:
            references_ir = hypotheses_ir
        else:
            with timer.measure("encode/references"):
                references_ir = self.encode(references)
        if source is None:
            source_ir = None
        else:
            with timer.measure("encode/source"):
                source_ir = self.encode([source])
        return self.pairwise_scores_from_ir(hypotheses_ir, references_ir, source_ir)
371
372
class MetricAggregatableCache(
    MetricAggregatable, MetricCacheable, metaclass=abc.ABCMeta
):
    """Base class for metrics whose cached representations can be aggregated.

    In addition to caching, this class can aggregate the intermediate
    representations of sentences (the references, in
    :meth:`expected_scores_reference_aggregation`) before scoring."""

    @dataclass
    class Cache(MetricCacheable.Cache, metaclass=abc.ABCMeta):
        """Intermediate representations of sentences."""

        @abc.abstractmethod
        def aggregate(
            self, reference_lprobs: Optional[Tensor] = None
        ) -> MetricAggregatableCache.Cache:
            """Aggregate the cached representations.

            Args:
                reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                  The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

            Returns:
                Cache: An aggregated representation.
            """

        def cluster(
            self, kmeans: Kmeans
        ) -> tuple[MetricAggregatableCache.Cache, Tensor]:
            """Cluster the cached representations.

            The base implementation is not available and always raises.

            Args:
                kmeans (Kmeans): k-means class to perform clustering.

            Returns:
                tuple[Cache, Tensor]:
                  - Cache: Centroid representations.
                  - Tensor: N assigned IDs.

            Raises:
                NotImplementedError: Unless overridden by a subclass.
            """
            raise NotImplementedError(type(self).__name__)

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        The references are encoded and aggregated with ``Cache.aggregate``.
        Each hypothesis is then scored against the aggregated representation
        and the scores are averaged over the last dimension.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("encode/hypotheses"):
            hyp_ir = self.encode(hypotheses)

        if hypotheses != references:
            with timer.measure("encode/references"):
                ref_ir = self.encode(references)
        else:
            # Identical inputs: reuse the hypothesis encodings as references.
            ref_ir = hyp_ir

        src_ir = None
        if source is not None:
            with timer.measure("encode/source"):
                src_ir = self.encode([source])

        with timer.measure("aggregate/references"):
            ref_aggregated = ref_ir.aggregate(reference_lprobs)

        with timer.measure("expectation"):
            pairwise = self.pairwise_scores_from_ir(
                hyp_ir, ref_aggregated, source_ir=src_ir
            )
            return pairwise.mean(dim=-1)
452
453
class MetricReferenceless(MetricBase, metaclass=abc.ABCMeta):
    """Base class for metrics that score without references, such as quality estimation."""

    @abc.abstractmethod
    def score(self, hypothesis: str, source: str) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            source (str): A source.

        Returns:
            float: The score of the given hypothesis.
        """

    def scores(self, hypotheses: list[str], sources: list[str]) -> Tensor:
        """Calculate the scores of hypotheses.

        Each hypothesis is paired element-wise (via ``zip``) with its source
        and scored by :meth:`score`.

        Args:
            hypotheses (list[str]): N hypotheses.
            sources (list[str]): N sources.

        Returns:
            Tensor: The scores of hypotheses.
        """
        values = [self.score(h, s) for h, s in zip(hypotheses, sources)]
        return Tensor(values)

    def corpus_score(self, hypotheses: list[str], sources: list[str]) -> float:
        """Calculate the corpus-level score.

        The result is the mean of the sentence-level scores, moved to CPU and
        cast to float before conversion to a Python number.

        Args:
            hypotheses (list[str]): Hypotheses.
            sources (list[str]): Sources.

        Returns:
            float: The corpus score.
        """
        per_sentence = self.scores(hypotheses, sources=sources)
        return per_sentence.mean().cpu().float().item()
492
493
# `register` and `get_metric` are the two closures returned by
# `Registry.get_closure()`. The registry is constructed with the union
# `Metric | MetricReferenceless` (presumably the accepted base classes —
# confirm in `mbrs.registry`).
register, get_metric = registry.Registry(
    Metric | MetricReferenceless
).get_closure()