1from __future__ import annotations
2
3import copy
4import enum
5import itertools
6import os
7import warnings
8from dataclasses import dataclass
9from typing import Optional
10
11import torch
12import torch.nn as nn
13import transformers
14from torch import Tensor
15from transformers import AutoTokenizer
16from transformers.modeling_outputs import BaseModelOutput, ModelOutput
17from transformers.models.mt5.modeling_mt5 import (
18 __HEAD_MASK_WARNING_MSG,
19 MT5Config,
20 MT5PreTrainedModel,
21 MT5Stack,
22)
23from transformers.tokenization_utils import BatchEncoding, EncodedInput
24
25from mbrs import timer
26
27from . import Metric, register
28
# Module-level side effect: lower the verbosity of the `transformers` logger to
# error-only as soon as this module is imported.
transformers.logging.set_verbosity_error()
30
31
[docs]
@dataclass
class MT5ForRegressionOutput(ModelOutput):
    """Result container returned by :class:`MT5ForRegression`.

    - loss: Regression loss; stays ``None`` unless the model sets it.
    - predictions: Predicted scores produced by the model.
    """

    loss: Optional[Tensor] = None
    predictions: Optional[Tensor] = None
36
37
[docs]
class MT5ForRegression(MT5PreTrainedModel):
    """MT5 model for regression.

    The encoder reads the input sequence, the decoder is run for exactly one
    step on a dummy token, and the logit of one fixed vocabulary entry at that
    step is taken as the regression output.

    This implementation is copied from https://github.com/google-research/metricx
    """

    def __init__(self, config: MT5Config):
        """Build the shared embedding, encoder stack, decoder stack and LM head.

        Args:
            config (MT5Config): Model configuration. It is deep-copied for the
              encoder and for the decoder so each stack gets its own flags.
        """
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = MT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = MT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> tuple[torch.Tensor] | MT5ForRegressionOutput:
        """Encode the input, decode a single dummy step and read out the score.

        Args:
            input_ids: Encoder input token IDs.
            attention_mask: Encoder attention mask; also passed to the decoder
              as the encoder attention mask for cross-attention.
            decoder_attention_mask: Attention mask passed to the decoder.
            head_mask: Head mask passed to the encoder.
            decoder_head_mask: Head mask passed to the decoder. When it is
              ``None`` while `head_mask` is given and both stacks have the
              same number of layers, `head_mask` is reused with a
              ``FutureWarning``.
            cross_attn_head_mask: Cross-attention head mask for the decoder.
            encoder_outputs: Precomputed encoder outputs. When given, the
              encoder is not run.
            past_key_values: Passed through to the decoder.
            inputs_embeds: Encoder input embeddings, passed to the encoder.
            decoder_inputs_embeds: Passed through to the decoder.
            labels: Target scores. When given, an MSE loss is computed.
            use_cache: Passed to the decoder; defaults to ``config.use_cache``.
            output_attentions: Passed to both stacks.
            output_hidden_states: Passed to both stacks.
            return_dict: Passed to both stacks; defaults to
              ``config.use_return_dict``.

        Returns:
            MT5ForRegressionOutput: ``predictions`` holds one value per batch
              item, clamped to ``[0, 25]``; ``loss`` is the MSE against
              `labels` or ``None`` when `labels` is not given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask,
        # decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # The imported constant starts with two underscores, so a bare
                # reference inside a class body is name-mangled to
                # `_MT5ForRegression__HEAD_MASK_WARNING_MSG` and raises
                # NameError. A string lookup in the module globals is not
                # subject to mangling.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Create 1 step of dummy input for the decoder.
        # The batch size is taken from the encoder output so that calls with
        # only `inputs_embeds` or `encoder_outputs` work, and the tensor is
        # allocated on the same device as the encoder output instead of being
        # forced onto CUDA, which broke CPU inference on machines with a GPU.
        batch_size = hidden_states.size(0)
        decoder_input_ids = torch.zeros(
            (batch_size, 1), dtype=torch.long, device=hidden_states.device
        )

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device
                )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        # 250089 = <extra_id_10>
        # The logit of this vocabulary entry at the first (and only) decoder
        # step is used as the raw regression output.
        predictions = lm_logits[:, 0, 250089]

        # Clip to 0 to 25
        predictions = torch.clamp(predictions, 0, 25)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # move labels to correct device to enable PP
            labels = labels.to(predictions.device)
            loss = loss_fct(predictions.view(-1), labels.view(-1))

        return MT5ForRegressionOutput(loss=loss, predictions=predictions)
190
191
[docs]
@register("metricx")
class MetricMetricX(Metric):
    """MetricX metric class.

    References:
        - MetricX-23: https://aclanthology.org/2023.wmt-1.63
        - MetricX-24: https://aclanthology.org/2024.wmt-1.35

    Available checkpoints:

    - google/metricx-24-hybrid-xxl-v2p6
    - google/metricx-24-hybrid-xl-v2p6
    - google/metricx-24-hybrid-large-v2p6
    - google/metricx-23-xxl-v2p0
    - google/metricx-23-xl-v2p0
    - google/metricx-23-large-v2p0
    - google/metricx-23-qe-xxl-v2p0
    - google/metricx-23-qe-xl-v2p0
    - google/metricx-23-qe-large-v2p0
    """

    HIGHER_IS_BETTER: bool = False

    scorer: MT5ForRegression

    @dataclass
    class Config(Metric.Config):
        """MetricX metric configuration.

        - model (str): Model name or path. It must be one of the keys of
          `METRICX_VERSION_MAP`.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation. Only applied
          when the model is placed on a CUDA device.
        - bf16 (bool): Use bfloat16 for the forward computation. Only applied
          when the model is placed on a CUDA device and `fp16` is not set.
        - cpu (bool): Use CPU for the forward computation.
        """

        model: str = "google/metricx-24-hybrid-xxl-v2p6"
        batch_size: int = 8
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    class MetricXVersion(str, enum.Enum):
        """MetricX model generations handled by this class."""

        metricx_24 = "metricx_24"
        metricx_23 = "metricx_23"

    # Checkpoint name -> model generation.
    METRICX_VERSION_MAP = {
        "google/metricx-24-hybrid-xxl-v2p6": MetricXVersion.metricx_24,
        "google/metricx-24-hybrid-xl-v2p6": MetricXVersion.metricx_24,
        "google/metricx-24-hybrid-large-v2p6": MetricXVersion.metricx_24,
        "google/metricx-23-xxl-v2p0": MetricXVersion.metricx_23,
        "google/metricx-23-xl-v2p0": MetricXVersion.metricx_23,
        "google/metricx-23-large-v2p0": MetricXVersion.metricx_23,
        "google/metricx-23-qe-xxl-v2p0": MetricXVersion.metricx_23,
        "google/metricx-23-qe-xl-v2p0": MetricXVersion.metricx_23,
        "google/metricx-23-qe-large-v2p0": MetricXVersion.metricx_23,
    }
    # Model generation -> length (in tokens) that inputs are truncated to.
    METRICX_INPUT_LENGTH_MAP = {
        MetricXVersion.metricx_24: 1536,
        MetricXVersion.metricx_23: 1024,
    }
    # MetricX-23 checkpoints that take (hypothesis, source) instead of
    # (hypothesis, reference).
    METRICX23_QE_MODELS = {
        "google/metricx-23-qe-xxl-v2p0",
        "google/metricx-23-qe-xl-v2p0",
        "google/metricx-23-qe-large-v2p0",
    }

    @dataclass
    class InputPrefix:
        """Text prepended to each input segment before tokenization."""

        hypothesis: str
        reference: str
        source: str

    # The segment that comes first in the concatenated input has no leading
    # space: the source for MetricX-24 and the hypothesis for MetricX-23.
    METRICX_INPUT_PREFIX_MAP = {
        MetricXVersion.metricx_24: InputPrefix(
            " candidate: ", " reference: ", "source: "
        ),
        MetricXVersion.metricx_23: InputPrefix(
            "candidate: ", " reference: ", " source: "
        ),
    }

    def __init__(self, cfg: MetricMetricX.Config):
        """Load the scorer and tokenizer and place the scorer on its device.

        Args:
            cfg (MetricMetricX.Config): Metric configuration.

        Raises:
            KeyError: If `cfg.model` is not a key of `METRICX_VERSION_MAP`.
        """
        super().__init__(cfg)

        # Resolve the checkpoint-specific settings before loading any weights
        # so that an unsupported model name fails immediately.
        try:
            self.metricx_version = self.METRICX_VERSION_MAP[cfg.model]
        except KeyError:
            raise KeyError(
                f"Unsupported MetricX model: {cfg.model!r}. "
                f"Available models: {sorted(self.METRICX_VERSION_MAP)}"
            ) from None
        self.max_length = self.METRICX_INPUT_LENGTH_MAP[self.metricx_version]
        self.input_prefix = self.METRICX_INPUT_PREFIX_MAP[self.metricx_version]

        self.scorer = MT5ForRegression.from_pretrained(cfg.model)
        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
        self.tokenizer = AutoTokenizer.from_pretrained(
            "google/mt5-xl", legacy=False, use_fast=False
        )

        # Inference only: switch to eval mode and freeze every parameter.
        self.scorer.eval()
        for param in self.scorer.parameters():
            param.requires_grad = False

        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def _encode_hypothesis(self, hypothesis: str) -> list[int]:
        """Encode a hypothesis.

        Args:
            hypothesis (str): A hypothesis.

        Returns:
            list[int]: Token IDs of the prefixed hypothesis, without special
              tokens.
        """
        return self.tokenizer.encode(
            self.input_prefix.hypothesis + hypothesis, add_special_tokens=False
        )

    def _encode_reference(self, reference: str) -> list[int]:
        """Encode a reference.

        Args:
            reference (str): A reference.

        Returns:
            list[int]: Token IDs of the prefixed reference, without special
              tokens.
        """
        return self.tokenizer.encode(
            self.input_prefix.reference + reference, add_special_tokens=False
        )

    def _encode_source(self, source: str) -> list[int]:
        """Encode a source.

        Args:
            source (str): A source.

        Returns:
            list[int]: Token IDs of the prefixed source, without special
              tokens.
        """
        return self.tokenizer.encode(
            self.input_prefix.source + source, add_special_tokens=False
        )

    def _concatenate_inputs(
        self,
        hypothesis_ids: list[int],
        reference_ids: Optional[list[int]] = None,
        source_ids: Optional[list[int]] = None,
    ) -> list[int]:
        """Prepare a model input for MetricX.

        The segment order depends on the model:

        - MetricX-24: source, hypothesis, then the reference if given.
        - MetricX-23-QE: hypothesis, source (the reference is ignored).
        - MetricX-23: hypothesis, reference (the source is ignored).

        Args:
            hypothesis_ids (list[int]): Hypothesis token IDs.
            reference_ids (list[int], optional): Reference token IDs.
            source_ids (list[int], optional): Source token IDs.

        Returns:
            list[int]: Concatenated token IDs.

        Raises:
            ValueError: If a segment required by the model is missing.
        """
        version = self.metricx_version
        if version == self.MetricXVersion.metricx_24:
            if source_ids is None:
                raise ValueError("MetricX-24 requires the source text.")
            input_ids = source_ids + hypothesis_ids
            if reference_ids is not None:
                input_ids += reference_ids
            return input_ids

        if version == self.MetricXVersion.metricx_23:
            if self.cfg.model in self.METRICX23_QE_MODELS:
                if source_ids is None:
                    raise ValueError("MetricX-23-QE requires the source text.")
                return hypothesis_ids + source_ids
            if reference_ids is None:
                raise ValueError("MetricX-23 requires the reference text.")
            return hypothesis_ids + reference_ids

        return []

    def _collate(self, batch_ids: list[EncodedInput]) -> BatchEncoding:
        """Truncate token ID sequences and pad them into a single batch.

        Each sequence is truncated to `self.max_length` without adding special
        tokens, then all sequences are padded together and returned as PyTorch
        tensors.

        Args:
            batch_ids (list[EncodedInput]): List of tokenized input ids.

        Returns:
            BatchEncoding: The padded batch.
        """
        features: dict[str, list] = {}
        for input_ids in batch_ids:
            example = self.tokenizer.prepare_for_model(
                input_ids,
                add_special_tokens=False,
                padding=False,
                truncation=True,
                max_length=self.max_length,
                pad_to_multiple_of=None,
                return_attention_mask=False,
                return_tensors=None,
            )
            for key, value in example.items():
                features.setdefault(key, []).append(value)

        return self.tokenizer.pad(features, padding=True, return_tensors="pt")

    def _predict(self, batch_ids: list[EncodedInput]) -> Tensor:
        """Run the scorer on one mini-batch and return its flattened predictions."""
        batch = self._collate(batch_ids).to(self.device)
        return self.scorer(**batch).predictions.flatten()

    def score(
        self,
        hypothesis: str,
        reference: Optional[str] = None,
        source: Optional[str] = None,
    ) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            reference (str, optional): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        input_ids = self._concatenate_inputs(
            self._encode_hypothesis(hypothesis),
            self._encode_reference(reference) if reference is not None else None,
            self._encode_source(source) if source is not None else None,
        )
        return self._predict([input_ids]).tolist()[0]

    def scores(
        self,
        hypotheses: list[str],
        references: Optional[list[str]] = None,
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypothesis.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str], optional): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses. An empty tensor is
              returned when `hypotheses` is empty.
        """
        examples: list[list[int]] = [
            self._concatenate_inputs(
                self._encode_hypothesis(hyp),
                self._encode_reference(references[i])
                if references is not None
                else None,
                self._encode_source(sources[i]) if sources is not None else None,
            )
            for i, hyp in enumerate(hypotheses)
        ]
        if not examples:
            # `torch.cat` rejects an empty list, so return an empty result here.
            return torch.empty(0, device=self.device)

        batch_size = self.cfg.batch_size
        batch_scores: list[Tensor] = []
        with timer.measure("score") as t:
            t.set_delta_ncalls(len(examples))
            for start in range(0, len(examples), batch_size):
                batch_scores.append(
                    self._predict(examples[start : start + batch_size])
                )
        return torch.cat(batch_scores).view(len(hypotheses))

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        num_hyps, num_refs = len(hypotheses), len(references)
        if num_hyps == 0 or num_refs == 0:
            # No pairs to score; `torch.cat` would fail on an empty list.
            return torch.empty((num_hyps, num_refs), device=self.device)

        # Tokenize every segment once; only the concatenation is done per pair.
        hypotheses_ids = [self._encode_hypothesis(hyp) for hyp in hypotheses]
        references_ids = [self._encode_reference(ref) for ref in references]
        source_ids = self._encode_source(source) if source is not None else None

        # Row-major order over (hypothesis, reference) pairs, which matches the
        # final `view(H, R)`.
        pair_iter = itertools.product(hypotheses_ids, references_ids)
        batch_scores: list[Tensor] = []
        while pairs := list(itertools.islice(pair_iter, self.cfg.batch_size)):
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(pairs))
                batch_scores.append(
                    self._predict(
                        [
                            self._concatenate_inputs(hyp_ids, ref_ids, source_ids)
                            for hyp_ids, ref_ids in pairs
                        ]
                    )
                )
        return torch.cat(batch_scores).view(num_hyps, num_refs)

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: Optional[list[list[str]]] = None,
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        The corpus score is the mean of the segment-level scores over all
        hypotheses and, when given, over all reference lists.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]], optional): Lists of references.
            sources (list[str], optional): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ValueError: If both `references_lists` and `sources` are None.
        """
        reference_sets: list[Optional[list[str]]]
        if references_lists is None:
            if sources is None:
                raise ValueError(
                    "`sources` must be given when `references_lists` is None."
                )
            # Score once without references.
            reference_sets = [None]
        else:
            reference_sets = list(references_lists)

        batch_size = self.cfg.batch_size
        chunks: list[Tensor] = []
        for references in reference_sets:
            for start in range(0, len(hypotheses), batch_size):
                end = start + batch_size
                chunk = self.scores(
                    hypotheses[start:end],
                    references[start:end] if references is not None else None,
                    sources[start:end] if sources is not None else None,
                )
                # Move each chunk to CPU as float32 before accumulating.
                chunks.append(chunk.float().cpu())
        return torch.cat(chunks).mean().item()