Source code for mbrs.metrics.bertscore

  1from __future__ import annotations
  2
  3import enum
  4from collections import defaultdict
  5from dataclasses import dataclass
  6from typing import Optional, Sequence
  7
  8import bert_score
  9import bert_score.utils
 10import torch
 11import transformers
 12from bert_score import BERTScorer
 13from simple_parsing.helpers.fields import choice
 14from torch import Tensor
 15from transformers.models.gpt2 import GPT2Tokenizer
 16from transformers.models.roberta import RobertaTokenizer
 17from transformers.tokenization_utils import (
 18    BatchEncoding,
 19    EncodedInput,
 20    PreTrainedTokenizerBase,
 21)
 22
 23from mbrs.metrics.base import MetricCacheable
 24
 25from . import Metric, register
 26
 27transformers.logging.set_verbosity_error()
 28
 29
class BERTScoreScoreType(int, enum.Enum):
    """Selects which component of the BERTScore triplet is reported.

    Each member's integer value is the position of that component in the
    ``(precision, recall, f1)`` triplet, so a member can be used directly as an
    index into that triplet.
    """

    # Position 0 of the triplet.
    precision = 0
    # Position 1 of the triplet.
    recall = 1
    # Position 2 of the triplet.
    f1 = 2

[docs] 36@register("bertscore") 37class MetricBERTScore(MetricCacheable): 38 """BERTScore metric class.""" 39
[docs] 40 @dataclass 41 class Config(Metric.Config): 42 """BERTScore metric configuration. 43 44 - score_type (BERTScoreScoreType): The output score type, i.e., 45 precision, recall, or f1. 46 - model_type (str): Contexual embedding model specification, default using the 47 suggested model for the target langauge; has to specify at least one of 48 `model_type` or `lang`. 49 - num_layers (int): The layer of representation to use. Default using the number 50 of layer tuned on WMT16 correlation data. 51 - idf (bool): A booling to specify whether to use idf or not. (This should be 52 True even if `idf_sents` is given.) 53 - idf_sents (list[str]): List of sentences used to compute the idf weights. 54 - batch_size (int): Bert score processing batch size 55 - nthreads (int): Number of threads. 56 - lang (str): Language of the sentences; has to specify at least one of 57 `model_type` or `lang`. `lang` needs to be specified when 58 `rescale_with_baseline` is True. 59 - rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline. 60 - baseline_path (str): Customized baseline file. 61 - use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. 62 - fp16 (bool): Use float16 for the forward computation. 63 - bf16 (bool): Use bfloat16 for the forward computation. 64 - cpu (bool): Use CPU for the forward computation. 65 """ 66 67 score_type: BERTScoreScoreType = choice( 68 BERTScoreScoreType, default=BERTScoreScoreType.f1 69 ) 70 model_type: Optional[str] = None 71 num_layers: Optional[int] = None 72 batch_size: int = 64 73 nthreads: int = 4 74 idf: bool = False 75 idf_sents: Optional[list[str]] = None 76 lang: Optional[str] = None 77 rescale_with_baseline: bool = False 78 baseline_path: Optional[str] = None 79 use_fast_tokenizer: bool = False 80 fp16: bool = False 81 bf16: bool = False 82 cpu: bool = False
83
[docs] 84 @dataclass 85 class Cache(MetricCacheable.Cache): 86 """Intermediate representations of sentences. 87 88 - embeddings (list[Tensor]): A list of token embeddings of shape `(T, D)`, 89 where `T` is the length of sequence, and `D` is a size of the embedding. 90 - idf_weights (list[Tensor]): A list of IDF weights of shape `(T,)`. 91 """ 92 93 embeddings: list[Tensor] 94 idf_weights: list[Tensor] 95 96 def __len__(self) -> int: 97 """Return the length of cache.""" 98 return len(self.embeddings) 99 100 def __getitem__( 101 self, key: int | Sequence[int] | slice | Tensor 102 ) -> MetricBERTScore.Cache: 103 """Get the items.""" 104 if isinstance(key, Tensor): 105 dtype = key.dtype 106 key = key.tolist() 107 if dtype == torch.bool: 108 return type(self)( 109 [self.embeddings[k] for k in key if k], 110 [self.idf_weights[k] for k in key if k], 111 ) 112 return type(self)( 113 [self.embeddings[k] for k in key], 114 [self.idf_weights[k] for k in key], 115 ) 116 elif isinstance(key, Sequence): 117 return type(self)( 118 [self.embeddings[k] for k in key], 119 [self.idf_weights[k] for k in key], 120 ) 121 elif isinstance(key, slice): 122 return type(self)(self.embeddings[key], self.idf_weights[key]) 123 else: 124 return type(self)([self.embeddings[key]], [self.idf_weights[key]]) 125
[docs] 126 def repeat(self, n: int) -> MetricBERTScore.Cache: 127 """Repeat the representations by n times. 128 129 Args: 130 n (int): The number of repetition. 131 132 Returns: 133 Cache: The repeated cache. 134 """ 135 return type(self)(self.embeddings * n, self.idf_weights * n)
136 137 cfg: MetricBERTScore.Config 138 139 def __init__(self, cfg: MetricBERTScore.Config): 140 super().__init__(cfg) 141 self.scorer: BERTScorer = BERTScorer( 142 model_type=cfg.model_type, 143 num_layers=cfg.num_layers, 144 batch_size=cfg.batch_size, 145 nthreads=cfg.nthreads, 146 all_layers=False, 147 idf=cfg.idf, 148 idf_sents=cfg.idf_sents, 149 device="cpu" if cfg.cpu else None, 150 lang=cfg.lang, 151 rescale_with_baseline=cfg.rescale_with_baseline, 152 baseline_path=cfg.baseline_path, 153 use_fast_tokenizer=cfg.use_fast_tokenizer, 154 ) 155 self.tokenizer: PreTrainedTokenizerBase = self.scorer._tokenizer 156 self.model = self.scorer._model 157 self.model.eval() 158 for param in self.model.parameters(): 159 param.requires_grad = False 160 161 if not cfg.cpu and torch.cuda.is_available(): 162 if cfg.fp16: 163 self.model = self.model.half() 164 elif cfg.bf16: 165 self.model = self.model.bfloat16() 166 self.model = self.model.cuda() 167 168 self.idf_dict: dict[int, float] 169 if cfg.idf and self.scorer._idf_dict is not None: 170 self.idf_dict = self.scorer._idf_dict 171 else: 172 self.idf_dict = defaultdict(lambda: 1.0) 173 if ( 174 sep_token_id := getattr(self.tokenizer, "sep_token_id", None) 175 ) is not None: 176 self.idf_dict[sep_token_id] = 0.0 177 if ( 178 cls_token_id := getattr(self.tokenizer, "cls_token_id", None) 179 ) is not None: 180 self.idf_dict[cls_token_id] = 0.0 181 182 @property 183 def device(self) -> torch.device: 184 """Returns the device of the model.""" 185 return self.model.device 186 187 @property 188 def embed_dim(self) -> int: 189 """Return the size of embedding dimension.""" 190 return self.model.config.hidden_size 191 192 def _tokenize(self, sentence: str) -> list[int]: 193 """Tokenize a sentence and encode it to the token IDs. 194 195 Args: 196 sentence (str): An input sentence. 197 198 Returns: 199 list[int]: The token IDs. 
200 """ 201 tokenizer_kwargs = {} 202 if isinstance(self.tokenizer, (GPT2Tokenizer, RobertaTokenizer)): 203 tokenizer_kwargs["add_prefix_space"] = True 204 205 return self.tokenizer.encode( 206 sentence, 207 add_special_tokens=True, 208 max_length=self.tokenizer.model_max_length, 209 truncation=True, 210 **tokenizer_kwargs, 211 ) 212 213 def _collate(self, batch_ids: list[EncodedInput]) -> BatchEncoding: 214 """Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It 215 adds special tokens, truncates sequences if overflowing while taking into account the special tokens and 216 manages a moving window (with user defined stride) for overflowing tokens 217 218 Args: 219 batch_ids_pairs (list[EncodedInputPair]): List of tokenized input ids. 220 221 Returns: 222 BatchEncoding: A mini-batch. 223 """ 224 batch = {} 225 for ids in batch_ids: 226 example = self.tokenizer.prepare_for_model( 227 ids, 228 add_special_tokens=False, 229 padding=False, 230 pad_to_multiple_of=None, 231 return_attention_mask=False, 232 return_tensors=None, 233 ) 234 235 for key, value in example.items(): 236 if key not in batch: 237 batch[key] = [] 238 batch[key].append(value) 239 240 return self.tokenizer.pad( 241 batch, 242 padding=True, 243 max_length=self.tokenizer.model_max_length, 244 return_tensors="pt", 245 ) 246
[docs] 247 def encode(self, sentences: list[str]) -> MetricBERTScore.Cache: 248 """Encode the given sentences into their intermediate representations. 249 250 Args: 251 sentences (list[str]): Input sentences. 252 253 Returns: 254 Tensor: Intermediate representations of shape `(N, D)` where `N` is the 255 number of hypotheses and `D` is a size of the embedding dimension. 256 """ 257 sequences = [self._tokenize(sentence) for sentence in sentences] 258 embeddings = [] 259 for i in range(0, len(sentences), self.cfg.batch_size): 260 batch = self._collate(sequences[i : i + self.cfg.batch_size]) 261 attention_mask = batch.attention_mask.bool() 262 embs = self.model(**batch.to(self.device))[0].cpu() 263 for j in range(len(embs)): 264 embeddings.append(embs[j, attention_mask[j]]) 265 idf_weights = [ 266 torch.Tensor([self.idf_dict.get(token, 1.0) for token in seq]) 267 for seq in sequences 268 ] 269 270 return self.Cache(embeddings, idf_weights)
271 272 def _choose_output_score(self, triplet: tuple[Tensor, Tensor, Tensor]) -> Tensor: 273 """Choose the output score from the triplet of precision, recall, and f1 scores. 274 275 Args: 276 triplet (tuple[Tensor, Tensor, Tensor]): A triplet of precision, recall, and f1 scores. 277 278 Returns: 279 Tensor: Output score. 280 """ 281 return triplet[self.cfg.score_type] 282
[docs] 283 def pad_sequence(self, tensors: list[Tensor]) -> Tensor: 284 match tensors[0].dtype: 285 case torch.bool: 286 padding_value = False 287 case torch.float32: 288 padding_value = torch.finfo(torch.float32).eps 289 case torch.float16: 290 padding_value = torch.finfo(torch.float16).eps 291 case torch.bfloat16: 292 padding_value = torch.finfo(torch.bfloat16).eps 293 case _: 294 padding_value = 0.0 295 296 return torch.nn.utils.rnn.pad_sequence( 297 tensors, batch_first=True, padding_value=padding_value 298 ).to(self.device)
299
[docs] 300 def out_proj( 301 self, 302 hypotheses_ir: Cache, 303 references_ir: Cache, 304 sources_ir: Optional[Cache] = None, 305 ) -> Tensor: 306 """Forward the output projection layer. 307 308 Args: 309 hypotheses_ir (Cache): N intermediate representations of hypotheses. 310 references_ir (Cache): N intermediate representations of references. 311 sources_ir (Cache, optional): N intermediate representations of sources. 312 313 Returns: 314 Tensor: N scores. 315 """ 316 317 hypotheses_embeddings = self.pad_sequence(hypotheses_ir.embeddings) 318 references_embeddings = self.pad_sequence(references_ir.embeddings) 319 hypotheses_token_masks = self.pad_sequence( 320 [torch.BoolTensor([True] * len(emb)) for emb in hypotheses_ir.embeddings] 321 ) 322 references_token_masks = self.pad_sequence( 323 [torch.BoolTensor([True] * len(emb)) for emb in references_ir.embeddings] 324 ) 325 hypotheses_idf_weights = self.pad_sequence(hypotheses_ir.idf_weights) 326 references_idf_weights = self.pad_sequence(references_ir.idf_weights) 327 328 scores = self._choose_output_score( 329 bert_score.utils.greedy_cos_idf( 330 references_embeddings, 331 references_token_masks, 332 references_idf_weights, 333 hypotheses_embeddings, 334 hypotheses_token_masks, 335 hypotheses_idf_weights, 336 all_layers=False, 337 ) 338 ) 339 if self.cfg.rescale_with_baseline: 340 scores = (scores - self.scorer.baseline_vals) / ( 341 1 - self.scorer.baseline_vals 342 ) 343 return scores.view(len(hypotheses_embeddings))
344
[docs] 345 def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor: 346 """Calculate the scores of the given hypothesis. 347 348 Args: 349 hypotheses (list[str]): N hypotheses. 350 references (list[str]): N references. 351 352 Returns: 353 Tensor: The N scores of the given hypotheses. 354 """ 355 return super().scores(hypotheses, references)
356
[docs] 357 def pairwise_scores( 358 self, hypotheses: list[str], references: list[str], *_, **__ 359 ) -> Tensor: 360 """Calculate the pairwise scores. 361 362 Args: 363 hypotheses (list[str]): Hypotheses. 364 references (list[str]): References. 365 366 Returns: 367 Tensor: Score matrix of shape `(H, R)`, where `H` is the number 368 of hypotheses and `R` is the number of references. 369 """ 370 return super().pairwise_scores(hypotheses, references)
371
[docs] 372 def corpus_score( 373 self, 374 hypotheses: list[str], 375 references_lists: list[list[str]], 376 sources: Optional[list[str]] = None, 377 ) -> float: 378 """Calculate the corpus-level score. 379 380 Args: 381 hypotheses (list[str]): Hypotheses. 382 references_lists (list[list[str]]): Lists of references. 383 sources (list[str], optional): Sources. 384 385 Returns: 386 float: The corpus score. 387 """ 388 scores: list[Tensor] = [] 389 for references in references_lists: 390 scores.append(self.scores(hypotheses, references).cpu().float()) 391 return torch.cat(scores).mean().item()