Source code for mbrs.metrics.metricx

  1from __future__ import annotations
  2
  3import copy
  4import enum
  5import itertools
  6import os
  7import warnings
  8from dataclasses import dataclass
  9from typing import Optional
 10
 11import torch
 12import torch.nn as nn
 13import transformers
 14from torch import Tensor
 15from transformers import AutoTokenizer
 16from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 17from transformers.models.mt5.modeling_mt5 import (
 18    __HEAD_MASK_WARNING_MSG,
 19    MT5Config,
 20    MT5PreTrainedModel,
 21    MT5Stack,
 22)
 23from transformers.tokenization_utils import BatchEncoding, EncodedInput
 24
 25from mbrs import timer
 26
 27from . import Metric, register
 28
 29transformers.logging.set_verbosity_error()
 30
 31
@dataclass
class MT5ForRegressionOutput(ModelOutput):
    """Output container returned by ``MT5ForRegression.forward``.

    Both fields are optional and default to ``None``.
    """

    # Loss value; ``MT5ForRegression.forward`` sets it only when labels are given.
    loss: Optional[Tensor] = None
    # Predicted scores read from the decoder logits.
    predictions: Optional[Tensor] = None
36 37
class MT5ForRegression(MT5PreTrainedModel):
    """MT5 model for regression.

    This implementation is copied from https://github.com/google-research/metricx

    The encoder reads the input sequence, the decoder is run for a single step
    on a dummy token, and the logit of one fixed vocabulary entry at that step
    is clipped and returned as the prediction.
    """

    def __init__(self, config: MT5Config):
        """Build the shared embedding, the encoder/decoder stacks and the LM head.

        Args:
            config (MT5Config): Model configuration.
        """
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = MT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = MT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> tuple[torch.Tensor] | MT5ForRegressionOutput:
        """Run the encoder and a single decoder step, and return the predictions.

        Args:
            input_ids: Encoder input token IDs. May be omitted when
              ``inputs_embeds`` or ``encoder_outputs`` is given.
            attention_mask: Mask for the encoder input. It is also passed to the
              decoder as the encoder attention mask.
            decoder_attention_mask: Attention mask passed to the decoder.
            head_mask: Head mask passed to the encoder. When
              ``decoder_head_mask`` is None and the encoder and decoder have the
              same number of layers, it is reused for the decoder with a
              ``FutureWarning``.
            decoder_head_mask: Head mask passed to the decoder.
            cross_attn_head_mask: Cross-attention head mask passed to the decoder.
            encoder_outputs: Precomputed encoder outputs. The encoder is skipped
              when this is given.
            past_key_values: Forwarded to the decoder.
            inputs_embeds: Encoder input embeddings, forwarded to the encoder.
            decoder_inputs_embeds: Forwarded to the decoder.
            labels: Target scores. When given, the MSE loss is computed.
            use_cache: Forwarded to the decoder. Defaults to ``config.use_cache``.
            output_attentions: Forwarded to the encoder and the decoder.
            output_hidden_states: Forwarded to the encoder and the decoder.
            return_dict: Forwarded to the encoder and the decoder. Defaults to
              ``config.use_return_dict``.

        Returns:
            MT5ForRegressionOutput: ``predictions`` has one value per batch row,
              clipped to the range 0 to 25. ``loss`` is None unless ``labels``
              is given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask,
        # decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is
                # name-mangled to `_MT5ForRegression__HEAD_MASK_WARNING_MSG` and
                # raises NameError, so the module-level name is looked up
                # explicitly.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Create 1 step of dummy input for the decoder.
        # The batch size and the device are taken from the encoder output so that
        # this works when `input_ids` is None and when the model runs on CPU on a
        # machine where CUDA is available.
        batch_size = hidden_states.size(0)
        decoder_input_ids = torch.zeros(
            (batch_size, 1), dtype=torch.long, device=hidden_states.device
        )

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device
                )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        # 250089 = <extra_id_10>
        predictions = lm_logits[:, 0, 250089]

        # Clip to 0 to 25
        predictions = torch.clamp(predictions, 0, 25)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # move labels to correct device to enable PP
            labels = labels.to(predictions.device)
            loss = loss_fct(predictions.view(-1), labels.view(-1))

        return MT5ForRegressionOutput(loss=loss, predictions=predictions)
190 191
[docs] 192@register("metricx") 193class MetricMetricX(Metric): 194 """MetricX metric class. 195 196 References: 197 - MetricX-23: https://aclanthology.org/2023.wmt-1.63 198 - MetricX-24: https://aclanthology.org/2024.wmt-1.35 199 200 Available checkpoints: 201 202 - google/metricx-24-hybrid-xxl-v2p6 203 - google/metricx-24-hybrid-xl-v2p6 204 - google/metricx-24-hybrid-large-v2p6 205 - google/metricx-23-xxl-v2p0 206 - google/metricx-23-xl-v2p0 207 - google/metricx-23-large-v2p0 208 - google/metricx-23-qe-xxl-v2p0 209 - google/metricx-23-qe-xl-v2p0 210 - google/metricx-23-qe-large-v2p0 211 """ 212 213 HIGHER_IS_BETTER: bool = False 214 215 scorer: MT5ForRegression 216
[docs] 217 @dataclass 218 class Config(Metric.Config): 219 """MetricX metric configuration. 220 221 - model (str): Model name or path. 222 - batch_size (int): Batch size. 223 - fp16 (bool): Use float16 for the forward computation. 224 - bf16 (bool): Use bfloat16 for the forward computation. 225 - cpu (bool): Use CPU for the forward computation. 226 """ 227 228 model: str = "google/metricx-24-hybrid-xxl-v2p6" 229 batch_size: int = 8 230 fp16: bool = False 231 bf16: bool = False 232 cpu: bool = False
233
[docs] 234 class MetricXVersion(str, enum.Enum): 235 metricx_24 = "metricx_24" 236 metricx_23 = "metricx_23"
237 238 METRICX_VERSION_MAP = { 239 "google/metricx-24-hybrid-xxl-v2p6": MetricXVersion.metricx_24, 240 "google/metricx-24-hybrid-xl-v2p6": MetricXVersion.metricx_24, 241 "google/metricx-24-hybrid-large-v2p6": MetricXVersion.metricx_24, 242 "google/metricx-23-xxl-v2p0": MetricXVersion.metricx_23, 243 "google/metricx-23-xl-v2p0": MetricXVersion.metricx_23, 244 "google/metricx-23-large-v2p0": MetricXVersion.metricx_23, 245 "google/metricx-23-qe-xxl-v2p0": MetricXVersion.metricx_23, 246 "google/metricx-23-qe-xl-v2p0": MetricXVersion.metricx_23, 247 "google/metricx-23-qe-large-v2p0": MetricXVersion.metricx_23, 248 } 249 METRICX_INPUT_LENGTH_MAP = { 250 MetricXVersion.metricx_24: 1536, 251 MetricXVersion.metricx_23: 1024, 252 } 253 METRICX23_QE_MODELS = { 254 "google/metricx-23-qe-xxl-v2p0", 255 "google/metricx-23-qe-xl-v2p0", 256 "google/metricx-23-qe-large-v2p0", 257 } 258
[docs] 259 @dataclass 260 class InputPrefix: 261 hypothesis: str 262 reference: str 263 source: str
264 265 METRICX_INPUT_PREFIX_MAP = { 266 MetricXVersion.metricx_24: InputPrefix( 267 " candidate: ", " reference: ", "source: " 268 ), 269 MetricXVersion.metricx_23: InputPrefix( 270 "candidate: ", " reference: ", " source: " 271 ), 272 } 273 274 def __init__(self, cfg: MetricMetricX.Config): 275 super().__init__(cfg) 276 self.scorer = MT5ForRegression.from_pretrained(cfg.model) 277 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 278 self.tokenizer = AutoTokenizer.from_pretrained( 279 "google/mt5-xl", legacy=False, use_fast=False 280 ) 281 self.metricx_version = self.METRICX_VERSION_MAP[cfg.model] 282 self.max_length = self.METRICX_INPUT_LENGTH_MAP[self.metricx_version] 283 self.input_prefix = self.METRICX_INPUT_PREFIX_MAP[self.metricx_version] 284 285 self.scorer.eval() 286 for param in self.scorer.parameters(): 287 param.requires_grad = False 288 289 if not cfg.cpu and torch.cuda.is_available(): 290 if cfg.fp16: 291 self.scorer = self.scorer.half() 292 elif cfg.bf16: 293 self.scorer = self.scorer.bfloat16() 294 self.scorer = self.scorer.cuda() 295 296 @property 297 def device(self) -> torch.device: 298 """Returns the device of the model.""" 299 return self.scorer.device 300 301 def _encode_hypothesis(self, hypothesis: str) -> list[int]: 302 """Encode a hypothesis. 303 304 Args: 305 hypothesis (str): A hypothesis. 306 307 Returns: 308 list[int]: Token IDs of a hypothesis. 309 """ 310 return self.tokenizer.encode( 311 self.input_prefix.hypothesis + hypothesis, add_special_tokens=False 312 ) 313 314 def _encode_reference(self, reference: str) -> list[int]: 315 """Encode a reference. 316 317 Args: 318 reference (str): A reference. 319 320 Returns: 321 list[int]: Token IDs of a reference. 322 """ 323 return self.tokenizer.encode( 324 self.input_prefix.reference + reference, add_special_tokens=False 325 ) 326 327 def _encode_source(self, source: str) -> list[int]: 328 """Encode a source. 329 330 Args: 331 source (str): A source. 
332 333 Returns: 334 list[int]: Token IDs of a source. 335 """ 336 return self.tokenizer.encode( 337 self.input_prefix.source + source, add_special_tokens=False 338 ) 339 340 def _concatenate_inputs( 341 self, 342 hypothesis_ids: list[int], 343 reference_ids: Optional[list[int]] = None, 344 source_ids: Optional[list[int]] = None, 345 ) -> list[int]: 346 """Prepare a model input for MetricX. 347 348 Args: 349 hypothesis_ids (str): Hypothesis token IDs. 350 reference_ids (str, optional): Reference token IDs. 351 source_ids (str, optional): Source token IDs. 352 353 Returns: 354 str: Input string. 355 """ 356 input_ids: list[int] = [] 357 match self.metricx_version: 358 case self.MetricXVersion.metricx_24: 359 if source_ids is None: 360 raise ValueError("MetricX-24 requires the source text.") 361 input_ids += source_ids + hypothesis_ids 362 if reference_ids is not None: 363 input_ids += reference_ids 364 case self.MetricXVersion.metricx_23: 365 input_ids += hypothesis_ids 366 if self.cfg.model in self.METRICX23_QE_MODELS: 367 if source_ids is None: 368 raise ValueError("MetricX-23-QE requires the source text.") 369 input_ids += source_ids 370 else: 371 if reference_ids is None: 372 raise ValueError("MetricX-23 requires the reference text.") 373 input_ids += reference_ids 374 return input_ids 375 376 def _collate(self, batch_ids: list[EncodedInput]) -> BatchEncoding: 377 """Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It 378 adds special tokens, truncates sequences if overflowing while taking into account the special tokens and 379 manages a moving window (with user defined stride) for overflowing tokens 380 381 Args: 382 batch_ids (list[EncodedInput]): List of tokenized input ids. 
383 """ 384 385 batch = {} 386 for input_ids in batch_ids: 387 example = self.tokenizer.prepare_for_model( 388 input_ids, 389 add_special_tokens=False, 390 padding=False, 391 truncation=True, 392 max_length=self.max_length, 393 pad_to_multiple_of=None, 394 return_attention_mask=False, 395 return_tensors=None, 396 ) 397 398 for key, value in example.items(): 399 if key not in batch: 400 batch[key] = [] 401 batch[key].append(value) 402 403 batch = self.tokenizer.pad(batch, padding=True, return_tensors="pt") 404 return batch 405
[docs] 406 def score( 407 self, 408 hypothesis: str, 409 reference: Optional[str] = None, 410 source: Optional[str] = None, 411 ) -> float: 412 """Calculate the score of the given hypothesis. 413 414 Args: 415 hypothesis (str): A hypothesis. 416 reference (str, optional): A reference. 417 source (str, optional): A source. 418 419 Returns: 420 float: The score of the given hypothesis. 421 """ 422 423 batch = self._collate( 424 [ 425 self._concatenate_inputs( 426 self._encode_hypothesis(hypothesis), 427 self._encode_reference(reference) 428 if reference is not None 429 else None, 430 self._encode_source(source) if source is not None else None, 431 ) 432 ] 433 ).to(self.device) 434 model_output = self.scorer(**batch) 435 return model_output.predictions.flatten().tolist()[0]
436
[docs] 437 def scores( 438 self, 439 hypotheses: list[str], 440 references: Optional[list[str]] = None, 441 sources: Optional[list[str]] = None, 442 ) -> Tensor: 443 """Calculate the scores of the given hypothesis. 444 445 Args: 446 hypotheses (list[str]): N hypotheses. 447 references (list[str], optional): N references. 448 sources (list[str], optional): N sources. 449 450 Returns: 451 Tensor: The N scores of the given hypotheses. 452 """ 453 examples: list[list[int]] = [] 454 for i, hyp in enumerate(hypotheses): 455 examples.append( 456 self._concatenate_inputs( 457 self._encode_hypothesis(hyp), 458 self._encode_reference(references[i]) 459 if references is not None 460 else None, 461 self._encode_source(sources[i]) if sources is not None else None, 462 ) 463 ) 464 465 scores = [] 466 with timer.measure("score") as t: 467 t.set_delta_ncalls(len(examples)) 468 for i in range(0, len(examples), self.cfg.batch_size): 469 batch = self._collate(examples[i : i + self.cfg.batch_size]).to( 470 self.device 471 ) 472 model_output = self.scorer(**batch) 473 scores.append(model_output.predictions.flatten()) 474 return torch.cat(scores).view(len(hypotheses))
475
[docs] 476 def pairwise_scores( 477 self, hypotheses: list[str], references: list[str], source: Optional[str] = None 478 ) -> Tensor: 479 """Calculate the pairwise scores. 480 481 Args: 482 hypotheses (list[str]): Hypotheses. 483 references (list[str]): References. 484 source (str, optional): A source. 485 486 Returns: 487 Tensor: Score matrix of shape `(H, R)`, where `H` is the number 488 of hypotheses and `R` is the number of references. 489 """ 490 scores = [] 491 hypotheses_ids = [self._encode_hypothesis(hyp) for hyp in hypotheses] 492 references_ids = [self._encode_reference(ref) for ref in references] 493 source_ids = self._encode_source(source) if source is not None else None 494 pairwise_iter = itertools.product(hypotheses_ids, references_ids) 495 496 while batch := list(itertools.islice(pairwise_iter, self.cfg.batch_size)): 497 with timer.measure("score") as t: 498 t.set_delta_ncalls(len(batch)) 499 batch = self._collate( 500 [ 501 self._concatenate_inputs(hyp_ids, ref_ids, source_ids) 502 for hyp_ids, ref_ids in batch 503 ] 504 ).to(self.device) 505 model_output = self.scorer(**batch) 506 scores.append(model_output.predictions.flatten()) 507 return torch.cat(scores).view(len(hypotheses), len(references))
508
[docs] 509 def corpus_score( 510 self, 511 hypotheses: list[str], 512 references_lists: Optional[list[list[str]]] = None, 513 sources: Optional[list[str]] = None, 514 ) -> float: 515 """Calculate the corpus-level score. 516 517 Args: 518 hypotheses (list[str]): Hypotheses. 519 references (list[list[str]], optional): Lists of references. 520 sources (list[str], optional): Sources. 521 522 Returns: 523 float: The corpus score. 524 """ 525 scores: list[Tensor] = [] 526 if references_lists is None: 527 if sources is None: 528 raise ValueError( 529 "`sources` must be given when `references_lists` is None." 530 ) 531 532 for i in range(0, len(hypotheses), self.cfg.batch_size): 533 scores.append( 534 self.scores( 535 hypotheses[i : i + self.cfg.batch_size], 536 None, 537 sources[i : i + self.cfg.batch_size], 538 ) 539 .float() 540 .cpu() 541 ) 542 else: 543 for references in references_lists: 544 for i in range(0, len(hypotheses), self.cfg.batch_size): 545 scores.append( 546 self.scores( 547 hypotheses[i : i + self.cfg.batch_size], 548 references[i : i + self.cfg.batch_size], 549 sources[i : i + self.cfg.batch_size] 550 if sources is not None 551 else None, 552 ) 553 .float() 554 .cpu() 555 ) 556 return torch.cat(scores).mean().item()