Source code for mbrs.metrics.base

  1from __future__ import annotations
  2
  3import abc
  4from dataclasses import dataclass
  5from typing import Optional, Sequence
  6
  7import torch
  8from torch import Tensor
  9
 10from mbrs import functional, registry, timer
 11from mbrs.modules.kmeans import Kmeans
 12
 13
class MetricBase(abc.ABC):
    """Root class shared by every metric kind.

    Holds the configuration object passed at construction time, the
    ``HIGHER_IS_BETTER`` class flag and a ``device`` property, which this
    base implementation fixes to the CPU.
    """

    # Class-level flag; this base class sets it to True.
    HIGHER_IS_BETTER: bool = True

    @dataclass
    class Config:
        """Metric configuration; the base version declares no fields."""

    def __init__(self, cfg: MetricBase.Config):
        # Store the configuration as-is on the instance.
        self.cfg = cfg

    @property
    def device(self) -> torch.device:
        """Return the device the metric lives on (always CPU here)."""
        cpu_device = torch.device("cpu")
        return cpu_device
class Metric(MetricBase, metaclass=abc.ABCMeta):
    """Base class for reference-based metrics.

    Subclasses implement :meth:`score` for a single (hypothesis, reference,
    optional source) triple. The batched, pairwise, expected and corpus-level
    methods below are all built on top of it and can be overridden.
    """

    @abc.abstractmethod
    def score(
        self, hypothesis: str, reference: str, source: Optional[str] = None
    ) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """

    def scores(
        self,
        hypotheses: list[str],
        references: list[str],
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses.

        The i-th hypothesis is scored against the i-th reference (and the
        i-th source, if given) by calling :meth:`score` once per item.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        with timer.measure("score") as t:
            t.set_delta_ncalls(len(hypotheses))
            if sources is None:
                return Tensor(
                    [self.score(hyp, ref) for hyp, ref in zip(hypotheses, references)]
                )
            else:
                return Tensor(
                    [
                        self.score(hyp, ref, src)
                        for hyp, ref, src in zip(hypotheses, references, sources)
                    ]
                )

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        Every hypothesis is scored against every reference, using the same
        (optional) source for all pairs.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        num_hypotheses = len(hypotheses)
        num_references = len(references)
        with timer.measure("score") as t:
            t.set_delta_ncalls(num_hypotheses * num_references)
            matrix = Tensor(
                [
                    [self.score(hyp, ref, source) for ref in references]
                    for hyp in hypotheses
                ]
            )
            # Tensor([]) is 1-D, so reshape explicitly to keep the documented
            # (H, R) shape when there are no hypotheses.
            return matrix.reshape(num_hypotheses, num_references)

    def expected_scores(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Calculate the expected scores for each hypothesis.

        The `(H, R)` pairwise score matrix is reduced over references by
        `functional.expectation`, optionally weighted by `reference_lprobs`.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("expectation"):
            return functional.expectation(
                self.pairwise_scores(hypotheses, references, source),
                lprobs=reference_lprobs,
            )

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Each element of `references_lists` is a list of references aligned
        with `hypotheses` and is scored with :meth:`scores`. The result is the
        mean of all sentence-level scores over every reference list.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ZeroDivisionError: If `hypotheses` or `references_lists` is empty.
        """
        return sum(
            [
                self.scores(hypotheses, references, sources).sum().item()
                for references in references_lists
            ]
        ) / (len(hypotheses) * len(references_lists))
class MetricAggregatable(Metric, metaclass=abc.ABCMeta):
    """Abstract base for metrics that support reference aggregation.

    In addition to the regular :class:`Metric` interface, subclasses must
    provide :meth:`expected_scores_reference_aggregation`.
    """

    @abc.abstractmethod
    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute one expected score per hypothesis via reference aggregation.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
class MetricCacheable(Metric, metaclass=abc.ABCMeta):
    """Base class for cacheable metrics.

    Sentences are first encoded into intermediate representations
    (:class:`MetricCacheable.Cache`) by :meth:`encode`. Scores are then
    computed from those representations by :meth:`out_proj`, so an encoding
    can be reused. For example, :meth:`pairwise_scores` encodes only once
    when hypotheses and references are identical.
    """

    @dataclass
    class Cache(metaclass=abc.ABCMeta):
        """Intermediate representations of sentences."""

        @abc.abstractmethod
        def __len__(self) -> int:
            """Return the length of cache."""

        @abc.abstractmethod
        def __getitem__(
            self, key: int | Sequence[int] | slice | Tensor
        ) -> MetricCacheable.Cache:
            """Get the items."""

        @abc.abstractmethod
        def repeat(self, n: int) -> MetricCacheable.Cache:
            """Repeat the representations by n times.

            Args:
                n (int): The number of repetition.

            Returns:
                Cache: The repeated cache.
            """

    @property
    @abc.abstractmethod
    def embed_dim(self) -> int:
        """Return the size of embedding dimension."""

    @abc.abstractmethod
    def encode(self, sentences: list[str]) -> Cache:
        """Encode the given sentences into their intermediate representations.

        Args:
            sentences (list[str]): Input sentences.

        Returns:
            MetricCacheable.Cache: Intermediate representations.
        """

    @abc.abstractmethod
    def out_proj(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        sources_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Forward the output projection layer.

        Args:
            hypotheses_ir (Cache): N intermediate representations of hypotheses.
            references_ir (Cache): N intermediate representations of references.
            sources_ir (Cache, optional): N intermediate representations of sources.

        Returns:
            Tensor: N scores.
        """

    def scores_from_ir(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        sources_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses from the intermediate representations.

        This is the single place where the "score" timer is updated for
        cache-based scoring (one call per hypothesis).

        Args:
            hypotheses_ir (Cache): N hypotheses.
            references_ir (Cache): N references.
            sources_ir (Cache, optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        H = len(hypotheses_ir)
        with timer.measure("score") as t:
            t.set_delta_ncalls(H)
            if sources_ir is None:
                return self.out_proj(hypotheses_ir, references_ir)
            else:
                return self.out_proj(hypotheses_ir, references_ir, sources_ir)

    def score(
        self,
        hypothesis: str,
        reference: str,
        source: Optional[str] = None,
    ) -> float:
        """Calculate the score of the given hypothesis.

        Delegates to :meth:`scores` with single-element batches.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scores(
            [hypothesis],
            [reference],
            [source] if source is not None else None,
        ).item()

    def scores(
        self,
        hypotheses: list[str],
        references: list[str],
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Each input list is encoded with :meth:`encode` and then scored by
        :meth:`scores_from_ir`.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        return self.scores_from_ir(
            self.encode(hypotheses),
            self.encode(references),
            self.encode(sources) if sources is not None else None,
        )

    def pairwise_scores_from_ir(
        self,
        hypotheses_ir: Cache,
        references_ir: Cache,
        source_ir: Optional[Cache] = None,
    ) -> Tensor:
        """Calculate the pairwise scores from the intermediate representations.

        Args:
            hypotheses_ir (Cache): Hypotheses.
            references_ir (Cache): References.
            source_ir (Cache, optional): A source. It is repeated once per
              hypothesis, so it presumably holds a single sentence (as
              produced by :meth:`pairwise_scores`).

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        H = len(hypotheses_ir)
        R = len(references_ir)
        if R == 0:
            # torch.cat() rejects an empty list; return the documented (H, 0)
            # matrix instead of raising.
            return torch.empty((H, 0), device=self.device)
        if source_ir is not None:
            source_ir = source_ir.repeat(H)

        columns = []
        for i in range(R):
            # scores_from_ir() already records this work under the "score"
            # timer with H calls. Wrapping it in a second "score" timer here
            # would record the same H calls twice.
            columns.append(
                self.scores_from_ir(
                    hypotheses_ir, references_ir[i].repeat(H), source_ir
                )[:, None]
            )
        return torch.cat(columns, dim=-1)

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        When `hypotheses == references`, the hypothesis encodings are reused
        for the references instead of encoding them again.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        with timer.measure("encode/hypotheses"):
            hypotheses_ir = self.encode(hypotheses)
        if hypotheses == references:
            references_ir = hypotheses_ir
        else:
            with timer.measure("encode/references"):
                references_ir = self.encode(references)
        if source is None:
            source_ir = None
        else:
            with timer.measure("encode/source"):
                source_ir = self.encode([source])
        return self.pairwise_scores_from_ir(hypotheses_ir, references_ir, source_ir)
class MetricAggregatableCache(
    MetricAggregatable, MetricCacheable, metaclass=abc.ABCMeta
):
    """Cacheable metric whose cached representations can be aggregated.

    Reference aggregation is implemented by encoding the references, merging
    their representations with :meth:`Cache.aggregate`, and scoring every
    hypothesis against the merged representation.
    """

    @dataclass
    class Cache(MetricCacheable.Cache, metaclass=abc.ABCMeta):
        """Intermediate representations of sentences."""

        @abc.abstractmethod
        def aggregate(
            self, reference_lprobs: Optional[Tensor] = None
        ) -> MetricAggregatableCache.Cache:
            """Merge the cached representations into an aggregated one.

            Args:
                reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                  The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

            Returns:
                Cache: An aggregated representation.
            """

        def cluster(
            self, kmeans: Kmeans
        ) -> tuple[MetricAggregatableCache.Cache, Tensor]:
            """Group the cached representations with k-means.

            Not implemented in this base class; always raises.

            Args:
                kmeans (Kmeans): k-means class to perform clustering.

            Returns:
                tuple[Cache, Tensor]:
                  - Cache: Centroid representations.
                  - Tensor: N assigned IDs.

            Raises:
                NotImplementedError: Always, with the concrete class name.
            """
            raise NotImplementedError(type(self).__name__)

    def expected_scores_reference_aggregation(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        reference_lprobs: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute one expected score per hypothesis via reference aggregation.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

        Returns:
            Tensor: The expected scores for each hypothesis.
        """
        with timer.measure("encode/hypotheses"):
            hyps_ir = self.encode(hypotheses)

        if hypotheses != references:
            with timer.measure("encode/references"):
                refs_ir = self.encode(references)
        else:
            # Same sentences on both sides: reuse the hypothesis encodings.
            refs_ir = hyps_ir

        src_ir = None
        if source is not None:
            with timer.measure("encode/source"):
                src_ir = self.encode([source])

        with timer.measure("aggregate/references"):
            merged_refs_ir = refs_ir.aggregate(reference_lprobs)

        with timer.measure("expectation"):
            # Average over the aggregated-reference axis; presumably a single
            # column, in which case the mean is that column itself.
            pairwise = self.pairwise_scores_from_ir(
                hyps_ir, merged_refs_ir, source_ir=src_ir
            )
            return pairwise.mean(dim=-1)
class MetricReferenceless(MetricBase, metaclass=abc.ABCMeta):
    """Abstract base for metrics that need no reference (e.g. quality estimation).

    Subclasses implement :meth:`score` on a (hypothesis, source) pair; the
    batched and corpus-level methods here are derived from it.
    """

    @abc.abstractmethod
    def score(self, hypothesis: str, source: str) -> float:
        """Score one hypothesis given its source.

        Args:
            hypothesis (str): A hypothesis.
            source (str): A source.

        Returns:
            float: The score of the given hypothesis.
        """

    def scores(self, hypotheses: list[str], sources: list[str]) -> Tensor:
        """Score each hypothesis against the source at the same position.

        Args:
            hypotheses (list[str]): N hypotheses.
            sources (list[str]): N sources.

        Returns:
            Tensor: The scores of hypotheses.
        """
        results = []
        for hyp, src in zip(hypotheses, sources):
            results.append(self.score(hyp, src))
        return Tensor(results)

    def corpus_score(self, hypotheses: list[str], sources: list[str]) -> float:
        """Return the mean sentence-level score over the corpus.

        The mean is moved to CPU and cast to float32 before conversion to a
        Python float.

        Args:
            hypotheses (list[str]): Hypotheses.
            sources (list[str]): Sources.

        Returns:
            float: The corpus score.
        """
        per_sentence = self.scores(hypotheses, sources=sources)
        return per_sentence.mean().cpu().float().item()


# NOTE(review): `register` / `get_metric` come from `registry.Registry(...)
# .get_closure()`; presumably a class-registration decorator and a lookup
# helper for Metric / MetricReferenceless subclasses. Confirm in mbrs.registry.
register, get_metric = registry.Registry(Metric | MetricReferenceless).get_closure()