1from __future__ import annotations
2
3import enum
4from collections import defaultdict
5from dataclasses import dataclass
6from typing import Optional, Sequence
7
8import bert_score
9import bert_score.utils
10import torch
11import transformers
12from bert_score import BERTScorer
13from simple_parsing.helpers.fields import choice
14from torch import Tensor
15from transformers.models.gpt2 import GPT2Tokenizer
16from transformers.models.roberta import RobertaTokenizer
17from transformers.tokenization_utils import (
18 BatchEncoding,
19 EncodedInput,
20 PreTrainedTokenizerBase,
21)
22
23from mbrs.metrics.base import MetricCacheable
24
25from . import Metric, register
26
# Lower the transformers library's log level to ERROR when this module is
# imported, hiding its info/warning messages.
# NOTE(review): this changes the library-wide verbosity, not just for this module.
transformers.logging.set_verbosity_error()
28
29
[docs]
class BERTScoreScoreType(int, enum.Enum):
    """Selects which BERTScore component a metric reports.

    Members are ints. Each value is the position of its component in a
    ``(precision, recall, f1)`` triplet, so a member can index such a triplet
    directly.
    """

    precision, recall, f1 = range(3)
34
35
[docs]
36@register("bertscore")
37class MetricBERTScore(MetricCacheable):
38 """BERTScore metric class."""
39
[docs]
40 @dataclass
41 class Config(Metric.Config):
42 """BERTScore metric configuration.
43
44 - score_type (BERTScoreScoreType): The output score type, i.e.,
45 precision, recall, or f1.
46 - model_type (str): Contexual embedding model specification, default using the
47 suggested model for the target langauge; has to specify at least one of
48 `model_type` or `lang`.
49 - num_layers (int): The layer of representation to use. Default using the number
50 of layer tuned on WMT16 correlation data.
51 - idf (bool): A booling to specify whether to use idf or not. (This should be
52 True even if `idf_sents` is given.)
53 - idf_sents (list[str]): List of sentences used to compute the idf weights.
54 - batch_size (int): Bert score processing batch size
55 - nthreads (int): Number of threads.
56 - lang (str): Language of the sentences; has to specify at least one of
57 `model_type` or `lang`. `lang` needs to be specified when
58 `rescale_with_baseline` is True.
59 - rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
60 - baseline_path (str): Customized baseline file.
61 - use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer.
62 - fp16 (bool): Use float16 for the forward computation.
63 - bf16 (bool): Use bfloat16 for the forward computation.
64 - cpu (bool): Use CPU for the forward computation.
65 """
66
67 score_type: BERTScoreScoreType = choice(
68 BERTScoreScoreType, default=BERTScoreScoreType.f1
69 )
70 model_type: Optional[str] = None
71 num_layers: Optional[int] = None
72 batch_size: int = 64
73 nthreads: int = 4
74 idf: bool = False
75 idf_sents: Optional[list[str]] = None
76 lang: Optional[str] = None
77 rescale_with_baseline: bool = False
78 baseline_path: Optional[str] = None
79 use_fast_tokenizer: bool = False
80 fp16: bool = False
81 bf16: bool = False
82 cpu: bool = False
83
[docs]
84 @dataclass
85 class Cache(MetricCacheable.Cache):
86 """Intermediate representations of sentences.
87
88 - embeddings (list[Tensor]): A list of token embeddings of shape `(T, D)`,
89 where `T` is the length of sequence, and `D` is a size of the embedding.
90 - idf_weights (list[Tensor]): A list of IDF weights of shape `(T,)`.
91 """
92
93 embeddings: list[Tensor]
94 idf_weights: list[Tensor]
95
96 def __len__(self) -> int:
97 """Return the length of cache."""
98 return len(self.embeddings)
99
100 def __getitem__(
101 self, key: int | Sequence[int] | slice | Tensor
102 ) -> MetricBERTScore.Cache:
103 """Get the items."""
104 if isinstance(key, Tensor):
105 dtype = key.dtype
106 key = key.tolist()
107 if dtype == torch.bool:
108 return type(self)(
109 [self.embeddings[k] for k in key if k],
110 [self.idf_weights[k] for k in key if k],
111 )
112 return type(self)(
113 [self.embeddings[k] for k in key],
114 [self.idf_weights[k] for k in key],
115 )
116 elif isinstance(key, Sequence):
117 return type(self)(
118 [self.embeddings[k] for k in key],
119 [self.idf_weights[k] for k in key],
120 )
121 elif isinstance(key, slice):
122 return type(self)(self.embeddings[key], self.idf_weights[key])
123 else:
124 return type(self)([self.embeddings[key]], [self.idf_weights[key]])
125
[docs]
126 def repeat(self, n: int) -> MetricBERTScore.Cache:
127 """Repeat the representations by n times.
128
129 Args:
130 n (int): The number of repetition.
131
132 Returns:
133 Cache: The repeated cache.
134 """
135 return type(self)(self.embeddings * n, self.idf_weights * n)
136
137 cfg: MetricBERTScore.Config
138
139 def __init__(self, cfg: MetricBERTScore.Config):
140 super().__init__(cfg)
141 self.scorer: BERTScorer = BERTScorer(
142 model_type=cfg.model_type,
143 num_layers=cfg.num_layers,
144 batch_size=cfg.batch_size,
145 nthreads=cfg.nthreads,
146 all_layers=False,
147 idf=cfg.idf,
148 idf_sents=cfg.idf_sents,
149 device="cpu" if cfg.cpu else None,
150 lang=cfg.lang,
151 rescale_with_baseline=cfg.rescale_with_baseline,
152 baseline_path=cfg.baseline_path,
153 use_fast_tokenizer=cfg.use_fast_tokenizer,
154 )
155 self.tokenizer: PreTrainedTokenizerBase = self.scorer._tokenizer
156 self.model = self.scorer._model
157 self.model.eval()
158 for param in self.model.parameters():
159 param.requires_grad = False
160
161 if not cfg.cpu and torch.cuda.is_available():
162 if cfg.fp16:
163 self.model = self.model.half()
164 elif cfg.bf16:
165 self.model = self.model.bfloat16()
166 self.model = self.model.cuda()
167
168 self.idf_dict: dict[int, float]
169 if cfg.idf and self.scorer._idf_dict is not None:
170 self.idf_dict = self.scorer._idf_dict
171 else:
172 self.idf_dict = defaultdict(lambda: 1.0)
173 if (
174 sep_token_id := getattr(self.tokenizer, "sep_token_id", None)
175 ) is not None:
176 self.idf_dict[sep_token_id] = 0.0
177 if (
178 cls_token_id := getattr(self.tokenizer, "cls_token_id", None)
179 ) is not None:
180 self.idf_dict[cls_token_id] = 0.0
181
182 @property
183 def device(self) -> torch.device:
184 """Returns the device of the model."""
185 return self.model.device
186
187 @property
188 def embed_dim(self) -> int:
189 """Return the size of embedding dimension."""
190 return self.model.config.hidden_size
191
192 def _tokenize(self, sentence: str) -> list[int]:
193 """Tokenize a sentence and encode it to the token IDs.
194
195 Args:
196 sentence (str): An input sentence.
197
198 Returns:
199 list[int]: The token IDs.
200 """
201 tokenizer_kwargs = {}
202 if isinstance(self.tokenizer, (GPT2Tokenizer, RobertaTokenizer)):
203 tokenizer_kwargs["add_prefix_space"] = True
204
205 return self.tokenizer.encode(
206 sentence,
207 add_special_tokens=True,
208 max_length=self.tokenizer.model_max_length,
209 truncation=True,
210 **tokenizer_kwargs,
211 )
212
213 def _collate(self, batch_ids: list[EncodedInput]) -> BatchEncoding:
214 """Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
215 adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
216 manages a moving window (with user defined stride) for overflowing tokens
217
218 Args:
219 batch_ids_pairs (list[EncodedInputPair]): List of tokenized input ids.
220
221 Returns:
222 BatchEncoding: A mini-batch.
223 """
224 batch = {}
225 for ids in batch_ids:
226 example = self.tokenizer.prepare_for_model(
227 ids,
228 add_special_tokens=False,
229 padding=False,
230 pad_to_multiple_of=None,
231 return_attention_mask=False,
232 return_tensors=None,
233 )
234
235 for key, value in example.items():
236 if key not in batch:
237 batch[key] = []
238 batch[key].append(value)
239
240 return self.tokenizer.pad(
241 batch,
242 padding=True,
243 max_length=self.tokenizer.model_max_length,
244 return_tensors="pt",
245 )
246
[docs]
247 def encode(self, sentences: list[str]) -> MetricBERTScore.Cache:
248 """Encode the given sentences into their intermediate representations.
249
250 Args:
251 sentences (list[str]): Input sentences.
252
253 Returns:
254 Tensor: Intermediate representations of shape `(N, D)` where `N` is the
255 number of hypotheses and `D` is a size of the embedding dimension.
256 """
257 sequences = [self._tokenize(sentence) for sentence in sentences]
258 embeddings = []
259 for i in range(0, len(sentences), self.cfg.batch_size):
260 batch = self._collate(sequences[i : i + self.cfg.batch_size])
261 attention_mask = batch.attention_mask.bool()
262 embs = self.model(**batch.to(self.device))[0].cpu()
263 for j in range(len(embs)):
264 embeddings.append(embs[j, attention_mask[j]])
265 idf_weights = [
266 torch.Tensor([self.idf_dict.get(token, 1.0) for token in seq])
267 for seq in sequences
268 ]
269
270 return self.Cache(embeddings, idf_weights)
271
272 def _choose_output_score(self, triplet: tuple[Tensor, Tensor, Tensor]) -> Tensor:
273 """Choose the output score from the triplet of precision, recall, and f1 scores.
274
275 Args:
276 triplet (tuple[Tensor, Tensor, Tensor]): A triplet of precision, recall, and f1 scores.
277
278 Returns:
279 Tensor: Output score.
280 """
281 return triplet[self.cfg.score_type]
282
[docs]
283 def pad_sequence(self, tensors: list[Tensor]) -> Tensor:
284 match tensors[0].dtype:
285 case torch.bool:
286 padding_value = False
287 case torch.float32:
288 padding_value = torch.finfo(torch.float32).eps
289 case torch.float16:
290 padding_value = torch.finfo(torch.float16).eps
291 case torch.bfloat16:
292 padding_value = torch.finfo(torch.bfloat16).eps
293 case _:
294 padding_value = 0.0
295
296 return torch.nn.utils.rnn.pad_sequence(
297 tensors, batch_first=True, padding_value=padding_value
298 ).to(self.device)
299
[docs]
300 def out_proj(
301 self,
302 hypotheses_ir: Cache,
303 references_ir: Cache,
304 sources_ir: Optional[Cache] = None,
305 ) -> Tensor:
306 """Forward the output projection layer.
307
308 Args:
309 hypotheses_ir (Cache): N intermediate representations of hypotheses.
310 references_ir (Cache): N intermediate representations of references.
311 sources_ir (Cache, optional): N intermediate representations of sources.
312
313 Returns:
314 Tensor: N scores.
315 """
316
317 hypotheses_embeddings = self.pad_sequence(hypotheses_ir.embeddings)
318 references_embeddings = self.pad_sequence(references_ir.embeddings)
319 hypotheses_token_masks = self.pad_sequence(
320 [torch.BoolTensor([True] * len(emb)) for emb in hypotheses_ir.embeddings]
321 )
322 references_token_masks = self.pad_sequence(
323 [torch.BoolTensor([True] * len(emb)) for emb in references_ir.embeddings]
324 )
325 hypotheses_idf_weights = self.pad_sequence(hypotheses_ir.idf_weights)
326 references_idf_weights = self.pad_sequence(references_ir.idf_weights)
327
328 scores = self._choose_output_score(
329 bert_score.utils.greedy_cos_idf(
330 references_embeddings,
331 references_token_masks,
332 references_idf_weights,
333 hypotheses_embeddings,
334 hypotheses_token_masks,
335 hypotheses_idf_weights,
336 all_layers=False,
337 )
338 )
339 if self.cfg.rescale_with_baseline:
340 scores = (scores - self.scorer.baseline_vals) / (
341 1 - self.scorer.baseline_vals
342 )
343 return scores.view(len(hypotheses_embeddings))
344
[docs]
345 def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
346 """Calculate the scores of the given hypothesis.
347
348 Args:
349 hypotheses (list[str]): N hypotheses.
350 references (list[str]): N references.
351
352 Returns:
353 Tensor: The N scores of the given hypotheses.
354 """
355 return super().scores(hypotheses, references)
356
[docs]
357 def pairwise_scores(
358 self, hypotheses: list[str], references: list[str], *_, **__
359 ) -> Tensor:
360 """Calculate the pairwise scores.
361
362 Args:
363 hypotheses (list[str]): Hypotheses.
364 references (list[str]): References.
365
366 Returns:
367 Tensor: Score matrix of shape `(H, R)`, where `H` is the number
368 of hypotheses and `R` is the number of references.
369 """
370 return super().pairwise_scores(hypotheses, references)
371
[docs]
372 def corpus_score(
373 self,
374 hypotheses: list[str],
375 references_lists: list[list[str]],
376 sources: Optional[list[str]] = None,
377 ) -> float:
378 """Calculate the corpus-level score.
379
380 Args:
381 hypotheses (list[str]): Hypotheses.
382 references_lists (list[list[str]]): Lists of references.
383 sources (list[str], optional): Sources.
384
385 Returns:
386 float: The corpus score.
387 """
388 scores: list[Tensor] = []
389 for references in references_lists:
390 scores.append(self.scores(hypotheses, references).cpu().float())
391 return torch.cat(scores).mean().item()