Source code for mbrs.metrics.bleurt

  1from __future__ import annotations
  2
  3import itertools
  4from dataclasses import dataclass
  5from typing import Optional
  6
  7import torch
  8import transformers
  9from bleurt_pytorch import (
 10    BleurtForSequenceClassification,
 11    BleurtTokenizer,
 12)
 13from torch import Tensor
 14from transformers.tokenization_utils import BatchEncoding, EncodedInputPair
 15
 16from mbrs import timer
 17
 18from . import Metric, register
 19
 20transformers.logging.set_verbosity_error()
 21
 22
@register("bleurt")
class MetricBLEURT(Metric):
    """BLEURT metric class.

    We employ the PyTorch port version to implement this metric instead of the
    original version:
    https://github.com/lucadiliello/bleurt-pytorch
    (thanks to @lucadiliello)

    Available checkpoints:

    - lucadiliello/BLEURT-20
    - lucadiliello/BLEURT-20-D12
    - lucadiliello/BLEURT-20-D3
    - lucadiliello/BLEURT-20-D6
    - lucadiliello/bleurt-base-128
    - lucadiliello/bleurt-base-512
    - lucadiliello/bleurt-large-128
    - lucadiliello/bleurt-large-512
    - lucadiliello/bleurt-tiny-128
    - lucadiliello/bleurt-tiny-512
    """

    scorer: BleurtForSequenceClassification

    @dataclass
    class Config(Metric.Config):
        """BLEURT metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation.
        - cpu (bool): Use CPU for the forward computation.
        """

        model: str = "lucadiliello/BLEURT-20-D12"
        batch_size: int = 64
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricBLEURT.Config):
        """Load the BLEURT model and tokenizer described by ``cfg``.

        The model is put in eval mode with all parameters frozen. When CUDA is
        available and ``cfg.cpu`` is false, the model is moved to the GPU and
        optionally cast to float16 (``cfg.fp16``) or bfloat16 (``cfg.bf16``);
        ``fp16`` takes precedence when both flags are set.

        Args:
            cfg (MetricBLEURT.Config): Metric configuration.
        """
        super().__init__(cfg)
        self.scorer = BleurtForSequenceClassification.from_pretrained(cfg.model)
        self.tokenizer = BleurtTokenizer.from_pretrained(cfg.model)
        self.max_length = self._resolve_max_length()
        self.scorer.eval()
        for param in self.scorer.parameters():
            param.requires_grad = False

        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    def _resolve_max_length(self) -> int:
        """Return the maximum input length used for truncation.

        The tokenizer's per-checkpoint table is consulted first, exactly as
        before. That table is keyed by hub checkpoint name, so a local path or
        an unlisted checkpoint has no entry; in that case fall back to the
        limits advertised by the tokenizer and the model configuration instead
        of raising ``KeyError``.
        """
        known_sizes = getattr(self.tokenizer, "max_model_input_sizes", None) or {}
        max_length = known_sizes.get(self.tokenizer.name_or_path)
        if max_length is not None:
            return max_length

        # NOTE(review): ``model_max_length`` may be a very large sentinel when
        # the tokenizer config does not specify a limit, so it is bounded by the
        # model's position-embedding size when the config exposes one.
        tokenizer_limit = self.tokenizer.model_max_length
        position_limit = getattr(self.scorer.config, "max_position_embeddings", None)
        if position_limit is None:
            return tokenizer_limit
        return min(tokenizer_limit, position_limit)

    def _empty_scores(self, *shape: int) -> Tensor:
        """Return an empty score tensor on the model's device and dtype."""
        return torch.empty(shape, dtype=self.scorer.dtype, device=self.device)

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Extra positional and keyword arguments are accepted and ignored.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.

        Returns:
            float: The score of the given hypothesis.
        """
        # The reference is the first segment and the hypothesis the second.
        batch = self.tokenizer(
            [reference],
            [hypothesis],
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        model_output = self.scorer(**batch)
        return model_output.logits.flatten().tolist()[0]

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        The i-th hypothesis is scored against the i-th reference. Inputs are
        processed in chunks of ``cfg.batch_size``. Extra positional and keyword
        arguments are accepted and ignored.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses (empty when N is 0).

        Raises:
            ValueError: If the numbers of hypotheses and references differ.
        """
        num_hypotheses = len(hypotheses)
        if num_hypotheses != len(references):
            raise ValueError(
                "The numbers of hypotheses and references must be equal: "
                f"{num_hypotheses} != {len(references)}"
            )
        if num_hypotheses == 0:
            # ``torch.cat`` rejects an empty list, so answer directly.
            return self._empty_scores(0)

        batch_size = self.cfg.batch_size
        scores = []
        with timer.measure("score") as t:
            t.set_delta_ncalls(num_hypotheses)
            for i in range(0, num_hypotheses, batch_size):
                batch = self.tokenizer(
                    references[i : i + batch_size],
                    hypotheses[i : i + batch_size],
                    truncation=True,
                    padding=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                ).to(self.device)
                model_output = self.scorer(**batch)
                scores.append(model_output.logits.flatten())
        return torch.cat(scores).view(num_hypotheses)

    def __collate(self, batch_ids_pairs: list[EncodedInputPair]) -> BatchEncoding:
        """Build a padded model input batch from pre-tokenized id pairs.

        Each pair is combined into a single example with special tokens added
        and truncated to ``self.max_length``; the examples are then padded
        together and returned as PyTorch tensors.

        Args:
            batch_ids_pairs (list[EncodedInputPair]): Pairs of token id lists
              encoded without special tokens.

        Returns:
            BatchEncoding: The padded batch.
        """
        batch = {}
        for first_ids, second_ids in batch_ids_pairs:
            example = self.tokenizer.prepare_for_model(
                first_ids,
                second_ids,
                add_special_tokens=True,
                padding=False,
                truncation=True,
                max_length=self.max_length,
                pad_to_multiple_of=None,
                return_attention_mask=False,
                return_tensors=None,
            )

            # Gather each field (e.g. input ids) across examples for padding.
            for key, value in example.items():
                batch.setdefault(key, []).append(value)

        batch = self.tokenizer.pad(
            batch, padding=True, max_length=self.max_length, return_tensors="pt"
        )
        return batch

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Every hypothesis is scored against every reference. Each text is
        tokenized only once and the id sequences are paired afterwards. Extra
        positional and keyword arguments are accepted and ignored.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        if not hypotheses or not references:
            # No pairs to score; ``torch.cat`` rejects an empty list.
            return self._empty_scores(len(hypotheses), len(references))

        scores = []
        hypotheses_ids = [
            self.tokenizer.encode(h, add_special_tokens=False) for h in hypotheses
        ]
        references_ids = [
            self.tokenizer.encode(r, add_special_tokens=False) for r in references
        ]
        # References vary slowest, so the flat scores are laid out as (R, H).
        pairwise_iter = itertools.product(references_ids, hypotheses_ids)

        while batch := list(itertools.islice(pairwise_iter, self.cfg.batch_size)):
            with timer.measure("score") as t:
                t.set_delta_ncalls(len(batch))
                batch = self.__collate(batch).to(self.device)
                model_output = self.scorer(**batch)
                scores.append(model_output.logits.flatten())
        # Reshape to (R, H), then transpose to the documented (H, R).
        return torch.cat(scores).view(len(references), len(hypotheses)).transpose(0, 1)

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        The corpus score is the mean of the segment-level scores over every
        reference list.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references. Each inner
              list is indexed in parallel with `hypotheses`.
            sources (list[str], optional): Sources. Not used by this metric.

        Returns:
            float: The corpus score.
        """
        scores = []
        for references in references_lists:
            for i in range(0, len(hypotheses), self.cfg.batch_size):
                scores.append(
                    self.scores(
                        hypotheses[i : i + self.cfg.batch_size],
                        references[i : i + self.cfg.batch_size],
                    )
                    .float()
                    .cpu()
                )
        return torch.cat(scores).mean().item()