1from __future__ import annotations
2
3import os
4from dataclasses import dataclass
5from typing import Optional
6
7import comet.encoders
8import torch
9from comet import download_model, load_from_checkpoint
10from comet.encoders.base import Encoder
11from comet.encoders.bert import BERTEncoder
12from comet.models import XCOMETMetric
13from huggingface_hub import PyTorchModelHubMixin
14from torch import Tensor, nn
15from transformers import AutoConfig, AutoModel, AutoTokenizer
16from transformers.models.deberta_v2 import modeling_deberta_v2
17
18from mbrs import timer, utils
19
20from . import Metric, register
21
22
[docs]
class DeBERTaEncoder(BERTEncoder):
    """DeBERTa encoder.

    Args:
        pretrained_model (str): Pretrained model from hugging face.
        load_pretrained_weights (bool): If set to True loads the pretrained weights
            from Hugging Face
        local_files_only (bool): Whether or not to only look at local files.
    """

    def __init__(
        self,
        pretrained_model: str,
        load_pretrained_weights: bool = True,
        local_files_only: bool = False,
    ) -> None:
        # Start the MRO lookup *after* `Encoder`, so the constructors of
        # `BERTEncoder` and `Encoder` are deliberately bypassed and this class
        # builds its own tokenizer and model below.
        super(Encoder, self).__init__()
        # NOTE(review): presumably forces the pure-Python protobuf backend for
        # the tokenizer; this mutates the process-wide environment.
        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )

        # Load the configuration once, honouring `local_files_only`, and reuse
        # it for both the weight-less model and the rebuilt encoder layers.
        config = AutoConfig.from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )
        if load_pretrained_weights:
            self.model = AutoModel.from_pretrained(
                pretrained_model, local_files_only=local_files_only
            )
        else:
            self.model = AutoModel.from_config(config)
        self.model.encoder.output_hidden_states = True

        # Replace every encoder layer with a freshly constructed
        # `DebertaV2Layer`. Any weights loaded above for these layers are
        # discarded. NOTE(review): presumably the final weights come from the
        # metric checkpoint loaded afterwards -- confirm.
        self.model.encoder.layer = nn.ModuleList(
            [
                modeling_deberta_v2.DebertaV2Layer(config)
                for _ in range(self.model.config.num_hidden_layers)
            ]
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model: str,
        load_pretrained_weights: bool = True,
        local_files_only: bool = False,
    ) -> Encoder:
        """Function that loads a pretrained encoder from Hugging Face.

        Args:
            pretrained_model (str): Name of the pretrained model to be loaded.
            load_pretrained_weights (bool): If set to True loads the pretrained weights
                from Hugging Face
            local_files_only (bool): Whether or not to only look at local files.

        Returns:
            DeBERTaEncoder: DeBERTaEncoder object.
        """
        return DeBERTaEncoder(
            pretrained_model, load_pretrained_weights, local_files_only=local_files_only
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode a batch of token ids.

        Args:
            input_ids (torch.Tensor): Token ids fed to the model.
            attention_mask (torch.Tensor, optional): Attention mask. When
                omitted, a mask of ones shaped like `input_ids` is used.
            token_type_ids (torch.Tensor, optional): Token type ids forwarded
                to the model unchanged.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            dict[str, torch.Tensor]: A dictionary with the keys

            - ``sentemb``: the last hidden state at position 0 of each sequence.
            - ``wordemb``: the last hidden state of every position.
            - ``all_layers``: the hidden states returned by the model.
            - ``attention_mask``: the attention mask that was actually used.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        return {
            "sentemb": model_output.last_hidden_state[:, 0, :],
            "wordemb": model_output.last_hidden_state,
            "all_layers": model_output.hidden_states,
            "attention_mask": attention_mask,
        }
107
108
[docs]
class XCOMETLiteMetric(XCOMETMetric, PyTorchModelHubMixin):
    """xCOMET-Lite model.

    An `XCOMETMetric` that uses the `DeBERTaEncoder` and can be loaded through
    `PyTorchModelHubMixin`.

    Args:
        encoder_model: Encoder name looked up in `comet.encoders.str2encoder`.
        pretrained_model: Name of the pretrained encoder model.
        word_layer: Forwarded to `XCOMETMetric` unchanged.
        validation_data: Forwarded to `XCOMETMetric`. `None` (the default) is
            replaced with a new empty list for each instance.
        word_level_training: Forwarded to `XCOMETMetric` unchanged.
        hidden_sizes: Forwarded to `XCOMETMetric` unchanged.
        load_pretrained_weights: Forwarded to `XCOMETMetric` unchanged.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        encoder_model="DeBERTa",
        pretrained_model="microsoft/mdeberta-v3-base",
        word_layer=8,
        validation_data=None,
        word_level_training=True,
        hidden_sizes=(3072, 1024),
        load_pretrained_weights=False,
        *args,
        **kwargs,
    ):
        # Avoid a mutable default argument: every instance gets its own list
        # instead of sharing a single list object created at definition time.
        if validation_data is None:
            validation_data = []
        # Register the encoder under the "DeBERTa" key before the parent
        # constructor runs, so `encoder_model` can be resolved by name.
        comet.encoders.str2encoder["DeBERTa"] = DeBERTaEncoder
        super().__init__(
            encoder_model=encoder_model,
            pretrained_model=pretrained_model,
            word_layer=word_layer,
            layer_transformation="softmax",
            validation_data=validation_data,
            word_level_training=word_level_training,
            hidden_sizes=hidden_sizes,
            load_pretrained_weights=load_pretrained_weights,
        )
135
136
[docs]
@register("xcomet")
class MetricXCOMET(Metric):
    """XCOMET metric class.

    Both XCOMET (Guerreiro et al., 2024) and XCOMET-lite (Larionov et al., 2024) are supported.

    Supported models:
      - Unbabel/XCOMET-XL
      - Unbabel/XCOMET-XXL
      - myyycroft/XCOMET-lite
    """

    scorer: XCOMETMetric

    @dataclass
    class Config(Metric.Config):
        """XCOMET metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation.
        - cpu (bool): Use CPU for the forward computation.
        """

        model: str = "Unbabel/XCOMET-XL"
        batch_size: int = 8
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricXCOMET.Config):
        """Load the scorer named by `cfg.model` and prepare it for inference.

        Args:
            cfg (MetricXCOMET.Config): Metric configuration.
        """
        super().__init__(cfg)
        if cfg.model == "myyycroft/XCOMET-lite":
            self.scorer = XCOMETLiteMetric.from_pretrained(cfg.model)
        else:
            self.scorer = load_from_checkpoint(download_model(cfg.model))
        self.scorer.eval()
        for param in self.scorer.parameters():
            param.requires_grad = False

        # Reduced precision is only applied when the model is moved to a GPU;
        # `fp16` takes precedence over `bf16` when both are set.
        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def _predict_in_batches(self, inputs: list[dict[str, str]]) -> Tensor:
        """Score `inputs` in chunks of `cfg.batch_size` and concatenate the results.

        Args:
            inputs (list[dict[str, str]]): Samples keyed by "mt" and optionally
              "ref" and "src".

        Returns:
            Tensor: The concatenated scores of all chunks, or an empty tensor
              when `inputs` is empty.
        """
        batch_size = self.cfg.batch_size
        chunks: list[Tensor] = []
        with timer.measure("score") as t:
            t.set_delta_ncalls(len(inputs))
            for start in range(0, len(inputs), batch_size):
                batch = self.scorer.prepare_for_inference(
                    inputs[start : start + batch_size]
                )
                batch = utils.to_device(batch, self.device)
                model_output = self.scorer.predict_step(batch)
                chunks.append(model_output.scores)
        if not chunks:
            # `torch.cat` rejects an empty list, so return an empty result
            # explicitly when there was nothing to score.
            return torch.empty(0, device=self.device)
        return torch.cat(chunks)

    def score(
        self,
        hypothesis: str,
        reference: Optional[str] = None,
        source: Optional[str] = None,
    ) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            reference (str, optional): A reference.
            source (str, optional): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        # Only the segments that were actually given are passed to the scorer.
        inputs = {"mt": hypothesis}
        if reference is not None:
            inputs["ref"] = reference
        if source is not None:
            inputs["src"] = source

        batch = self.scorer.prepare_for_inference([inputs])
        batch = utils.to_device(batch, self.device)
        model_output = self.scorer.predict_step(batch)
        return model_output.scores.item()

    def scores(
        self,
        hypotheses: list[str],
        references: Optional[list[str]] = None,
        sources: Optional[list[str]] = None,
    ) -> Tensor:
        """Calculate the scores of the given hypothesis.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str], optional): N references.
            sources (list[str], optional): N sources.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """
        inputs = [{"mt": hyp} for hyp in hypotheses]
        if references is not None:
            for sample, ref in zip(inputs, references):
                sample["ref"] = ref
        if sources is not None:
            for sample, src in zip(inputs, sources):
                sample["src"] = src

        return self._predict_in_batches(inputs).view(len(hypotheses))

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], source: Optional[str] = None
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        # Leave the "src" key out entirely when no source is given, as `score`
        # and `scores` do; otherwise `None` would be handed to the scorer as
        # if it were the source text.
        shared = {} if source is None else {"src": source}
        data = [
            {**shared, "mt": hyp, "ref": ref}
            for hyp in hypotheses
            for ref in references
        ]
        return self._predict_in_batches(data).view(len(hypotheses), len(references))

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: Optional[list[list[str]]] = None,
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        The corpus score is the mean of all segment-level scores. With several
        reference lists, the scores obtained against every list are pooled
        before averaging.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]], optional): Lists of references.
              Each inner list is indexed in parallel with `hypotheses`.
            sources (list[str], optional): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ValueError: If both `references_lists` and `sources` are None.
        """
        scores: list[Tensor] = []
        if references_lists is None:
            if sources is None:
                raise ValueError(
                    "`sources` must be given when `references_lists` is None."
                )

            for i in range(0, len(hypotheses), self.cfg.batch_size):
                scores.append(
                    self.scores(
                        hypotheses[i : i + self.cfg.batch_size],
                        None,
                        sources[i : i + self.cfg.batch_size],
                    )
                    .float()
                    .cpu()
                )
        else:
            for references in references_lists:
                for i in range(0, len(hypotheses), self.cfg.batch_size):
                    scores.append(
                        self.scores(
                            hypotheses[i : i + self.cfg.batch_size],
                            references[i : i + self.cfg.batch_size],
                            sources[i : i + self.cfg.batch_size]
                            if sources is not None
                            else None,
                        )
                        .float()
                        .cpu()
                    )
        return torch.cat(scores).mean().item()