Source code for mbrs.metrics.cometkiwi
1from __future__ import annotations
2
3from dataclasses import dataclass
4
5import torch
6from comet import download_model, load_from_checkpoint
7
8from mbrs import utils
9
10from . import MetricReferenceless, register
11
12
[docs]
@register("cometkiwi")
class MetricCOMETkiwi(MetricReferenceless):
    """COMETkiwi metric class.

    A reference-free metric: every hypothesis is scored against its source
    sentence only, using a COMET checkpoint obtained via
    ``download_model`` / ``load_from_checkpoint``.
    """

    @dataclass
    class Config(MetricReferenceless.Config):
        """COMETkiwi metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
          Takes precedence over ``bf16`` when both are set.
        - bf16 (bool): Use bfloat16 for the forward computation.
        - cpu (bool): Use CPU for the forward computation.
          ``fp16`` / ``bf16`` are only applied when the model runs on CUDA.
        """

        model: str = "Unbabel/wmt22-cometkiwi-da"
        batch_size: int = 64
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricCOMETkiwi.Config):
        """Load the scorer, freeze it and place it on the target device.

        Args:
            cfg (MetricCOMETkiwi.Config): Metric configuration.
        """
        super().__init__(cfg)
        self.scorer = load_from_checkpoint(download_model(cfg.model))

        # The scorer is used for inference only: switch to eval mode and
        # freeze every parameter.
        self.scorer.eval()
        for param in self.scorer.parameters():
            param.requires_grad = False

        # The model is moved to the GPU (optionally in reduced precision) only
        # when the CPU was not requested and CUDA is actually available;
        # otherwise it is left exactly as loaded.
        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def score(self, hypothesis: str, source: str) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            source (str): A source.

        Returns:
            float: The score of the given hypothesis.
        """
        return self.scores([hypothesis], [source]).item()

    def scores(self, hypotheses: list[str], sources: list[str]) -> torch.Tensor:
        """Calculate the scores of hypotheses.

        The i-th hypothesis is paired with the i-th source. The pairing uses
        ``zip``, so surplus sources are ignored; having fewer sources than
        hypotheses makes the final reshape fail.

        Args:
            hypotheses (list[str]): N hypotheses.
            sources (list[str]): N sources.

        Returns:
            torch.Tensor: N scores of the given hypotheses. An empty tensor
              of shape ``(0,)`` on the model device when there are no
              hypotheses.
        """
        if not hypotheses:
            # Without any batch the list passed to torch.cat() would be empty,
            # which torch.cat() rejects; no hypotheses simply means no scores.
            return torch.empty(0, device=self.device)

        batch_size = self.cfg.batch_size
        data = [{"src": src, "mt": hyp} for hyp, src in zip(hypotheses, sources)]
        scores = []
        for i in range(0, len(data), batch_size):
            batch = self.scorer.prepare_for_inference(data[i : i + batch_size])
            batch = utils.to_device(batch, self.device)
            model_output = self.scorer.predict_step(batch)
            scores.append(model_output.scores)
        return torch.cat(scores).view(len(hypotheses))

    def corpus_score(self, hypotheses: list[str], sources: list[str]) -> float:
        """Calculate the corpus-level score.

        The corpus score is the mean of the segment-level scores. An empty
        corpus is not supported: there is nothing to average and the final
        ``torch.cat`` fails on an empty list.

        Args:
            hypotheses (list[str]): Hypotheses.
            sources (list[str]): Sources.

        Returns:
            float: The corpus score.
        """
        batch_size = self.cfg.batch_size
        scores = []
        for i in range(0, len(hypotheses), batch_size):
            # Each chunk is converted to float32 and moved to the CPU right
            # away, so the accumulated scores do not stay on the model device.
            scores.append(
                self.scores(
                    hypotheses[i : i + batch_size],
                    sources[i : i + batch_size],
                )
                .float()
                .cpu()
            )
        return torch.cat(scores).mean().item()