Source code for mbrs.metrics.cometkiwi

  1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4
  5import torch
  6from comet import download_model, load_from_checkpoint
  7
  8from mbrs import utils
  9
 10from . import MetricReferenceless, register
 11
 12
@register("cometkiwi")
class MetricCOMETkiwi(MetricReferenceless):
    """COMETkiwi metric class.

    Scores a hypothesis against its source sentence only (no reference),
    using a checkpoint loaded through the ``comet`` package.
    """

    @dataclass
    class Config(MetricReferenceless.Config):
        """COMETkiwi metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation.
        - cpu (bool): Use CPU for the forward computation.
        """

        model: str = "Unbabel/wmt22-cometkiwi-da"
        batch_size: int = 64
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricCOMETkiwi.Config):
        """Load the scorer model and prepare it for inference.

        Args:
            cfg (MetricCOMETkiwi.Config): Metric configuration.
        """
        super().__init__(cfg)
        self.scorer = load_from_checkpoint(download_model(cfg.model))
        # Inference only: switch to eval mode and freeze every parameter.
        self.scorer.eval()
        for param in self.scorer.parameters():
            param.requires_grad = False

        # The precision flags are applied only on the CUDA path; when running
        # on CPU the model keeps the precision it was loaded with.
        if not cfg.cpu and torch.cuda.is_available():
            # `fp16` takes precedence when both `fp16` and `bf16` are set.
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def score(self, hypothesis: str, source: str) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            source (str): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scores([hypothesis], [source]).item()

    def scores(self, hypotheses: list[str], sources: list[str]) -> torch.Tensor:
        """Calculate the scores of hypotheses.

        The i-th hypothesis is paired with the i-th source, and the pairs are
        scored in chunks of ``cfg.batch_size``.

        Args:
            hypotheses (list[str]): N hypotheses.
            sources (list[str]): N sources.

        Returns:
            torch.Tensor: N scores of the given hypotheses. An empty tensor of
              shape ``(0,)`` is returned when N is zero.

        Raises:
            ValueError: If the numbers of hypotheses and sources differ.
        """
        num_hyps = len(hypotheses)
        # `zip` would silently drop unpaired items, so reject mismatches early.
        if num_hyps != len(sources):
            raise ValueError(
                f"The number of hypotheses ({num_hyps}) and sources "
                f"({len(sources)}) must be the same."
            )
        # `torch.cat` rejects an empty list, so handle the empty input here.
        if num_hyps == 0:
            return torch.empty(0, device=self.device)

        data = [{"src": src, "mt": hyp} for hyp, src in zip(hypotheses, sources)]
        batch_size = self.cfg.batch_size
        scores = []
        for i in range(0, len(data), batch_size):
            batch = self.scorer.prepare_for_inference(data[i : i + batch_size])
            batch = utils.to_device(batch, self.device)
            model_output = self.scorer.predict_step(batch)
            scores.append(model_output.scores)
        return torch.cat(scores).view(num_hyps)

    def corpus_score(self, hypotheses: list[str], sources: list[str]) -> float:
        """Calculate the corpus-level score.

        The corpus score is the mean of the segment-level scores, accumulated
        in float32 on the CPU.

        Args:
            hypotheses (list[str]): Hypotheses.
            sources (list[str]): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ValueError: If the numbers of hypotheses and sources differ.

        Note:
            An empty corpus is not supported: there are no scores to average,
            and ``torch.cat`` rejects the resulting empty list.
        """
        # Check the whole corpus up front; per-chunk slicing below could hide
        # surplus items that fall outside the last chunk.
        if len(hypotheses) != len(sources):
            raise ValueError(
                f"The number of hypotheses ({len(hypotheses)}) and sources "
                f"({len(sources)}) must be the same."
            )

        batch_size = self.cfg.batch_size
        scores = [
            self.scores(
                hypotheses[i : i + batch_size],
                sources[i : i + batch_size],
            )
            .float()
            .cpu()
            for i in range(0, len(hypotheses), batch_size)
        ]
        return torch.cat(scores).mean().item()