Source code for mbrs.metrics.comet
1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Optional, Sequence
5
6import torch
7from comet import download_model, load_from_checkpoint
8from torch import Tensor
9
10from mbrs.modules.kmeans import Kmeans
11
12from . import MetricAggregatableCache, register
13
14
[docs]
@register("comet")
class MetricCOMET(MetricAggregatableCache):
    """COMET metric class.

    The scorer is loaded from a COMET checkpoint, switched to evaluation mode
    and frozen.  Sentences are encoded into sentence embeddings
    (:meth:`encode`), which are then fed to the scorer's estimator
    (:meth:`out_proj`) to obtain scores.
    """

    @dataclass
    class Config(MetricAggregatableCache.Config):
        """COMET metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation.
          Ignored when `fp16` is also set.
        - cpu (bool): Use CPU for the forward computation.
        """

        model: str = "Unbabel/wmt22-comet-da"
        batch_size: int = 64
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricCOMET.Config):
        """Load the COMET model and place it on the requested device.

        Args:
            cfg (MetricCOMET.Config): Metric configuration.
        """
        super().__init__(cfg)
        self.scorer = load_from_checkpoint(download_model(cfg.model))
        self.scorer.eval()
        # The model is only used for scoring, so every parameter is frozen.
        for param in self.scorer.parameters():
            param.requires_grad = False

        # Reduced precision is applied only when the model is moved to CUDA;
        # `fp16` takes precedence over `bf16` when both are set.
        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @dataclass
    class Cache(MetricAggregatableCache.Cache):
        """Intermediate representations of sentences.

        - embeddings (Tensor): Sentence embeddings of shape `(N, D)`, where `N`
          is the number of sentences and `D` is a size of the embedding
          dimension.
        """

        embeddings: Tensor

        def __len__(self) -> int:
            """Return the length of cache."""
            return len(self.embeddings)

        def __getitem__(
            self, key: int | Sequence[int] | slice | Tensor
        ) -> MetricCOMET.Cache:
            """Get the items."""
            return type(self)(self.embeddings[key])

        def repeat(self, n: int) -> MetricCOMET.Cache:
            """Repeat the representations by n times.

            Args:
                n (int): The number of repetition.

            Returns:
                Cache: The repeated cache.
            """
            return type(self)(self.embeddings.repeat((n, 1)))

        def aggregate(
            self, reference_lprobs: Optional[Tensor] = None
        ) -> MetricCOMET.Cache:
            """Aggregate the cached representations.

            Without `reference_lprobs` the embeddings are averaged.  Otherwise
            they are summed with weights given by the softmax of
            `reference_lprobs`.

            Args:
                reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                  The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

            Returns:
                Cache: An aggregated representation.
            """
            if reference_lprobs is not None:
                # The softmax is computed in float32 and the resulting weights
                # are cast back to the dtype/device of the embeddings.
                aggregated_embedding = (
                    self.embeddings
                    * reference_lprobs.to(self.embeddings)
                    .softmax(dim=-1, dtype=torch.float32)
                    .to(self.embeddings)[:, None]
                ).sum(dim=0, keepdim=True)
            else:
                aggregated_embedding = self.embeddings.mean(dim=0, keepdim=True)
            return type(self)(aggregated_embedding)

        def cluster(
            self, kmeans: Kmeans
        ) -> tuple[MetricAggregatableCache.Cache, Tensor]:
            """Cluster the cached representations.

            Args:
                kmeans (Kmeans): k-means class to perform clustering.

            Returns:
                tuple[Cache, Tensor]:
                  - Cache: Centroid representations.
                  - Tensor: N assigned IDs.
            """
            centroids, assigns = kmeans.train(self.embeddings)
            return type(self)(centroids), assigns

    @property
    def embed_dim(self) -> int:
        """Return the size of embedding dimension."""
        return self.scorer.encoder.output_units

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def _cast_embeddings(self, emb: Tensor) -> Tensor:
        """Cast embeddings to the dtype selected by the configuration.

        On a non-CPU device the result is float16 if `fp16` is set, bfloat16
        if `bf16` is set, and float32 otherwise.  On CPU the tensor is
        returned unchanged.
        """
        if self.scorer.device.type == "cpu":
            return emb
        if self.cfg.fp16:
            return emb.half()
        if self.cfg.bf16:
            return emb.bfloat16()
        return emb.float()

    def encode(self, sentences: list[str]) -> Cache:
        """Encode the given sentences into their intermediate representations.

        Sentences are processed in chunks of `cfg.batch_size`.  An empty input
        yields a cache holding a `(0, embed_dim)` tensor instead of failing
        while stacking zero batches.

        Args:
            sentences (list[str]): Input sentences.

        Returns:
            MetricCOMET.Cache: Intermediate representations.
        """
        if not sentences:
            # NOTE(review): assumes the sentence embeddings produced by the
            # scorer have width `embed_dim` -- confirm against the model.
            empty = torch.empty((0, self.embed_dim), device=self.scorer.device)
            return self.Cache(self._cast_embeddings(empty))

        batch_size = self.cfg.batch_size
        embeddings = []
        # Each batch is prepared right before its forward pass so that only
        # one prepared batch is alive at a time.
        for start in range(0, len(sentences), batch_size):
            batch = self.scorer.encoder.prepare_sample(
                sentences[start : start + batch_size]
            )
            emb = self.scorer.get_sentence_embedding(**batch.to(self.scorer.device))
            embeddings.append(self._cast_embeddings(emb))
        return self.Cache(torch.vstack(embeddings))

    def out_proj(
        self, hypotheses_ir: Cache, references_ir: Cache, sources_ir: Cache
    ) -> Tensor:
        """Forward the output projection layer.

        Args:
            hypotheses_ir (Cache): N intermediate representations of hypotheses.
            references_ir (Cache): N intermediate representations of references.
            sources_ir (Cache): N intermediate representations of sources.

        Returns:
            Tensor: N scores.
        """
        # The estimator takes the embeddings in (source, hypothesis,
        # reference) order, which differs from this method's argument order.
        return self.scorer.estimate(
            sources_ir.embeddings, hypotheses_ir.embeddings, references_ir.embeddings
        )["score"]

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Every list of references is scored against the hypotheses in chunks of
        `cfg.batch_size`, and the mean over all collected scores is returned.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references.
            sources (list[str], optional): Sources.

        Returns:
            float: The corpus score.

        Raises:
            ValueError: Raise this error when sources are not given.
        """
        if sources is None:
            raise ValueError("COMET requires the sources.")

        scores = []
        for references in references_lists:
            for i in range(0, len(hypotheses), self.cfg.batch_size):
                scores.append(
                    self.scores(
                        hypotheses[i : i + self.cfg.batch_size],
                        references[i : i + self.cfg.batch_size],
                        sources[i : i + self.cfg.batch_size],
                    )
                    .float()
                    .cpu()
                )
        return torch.cat(scores).mean().item()