Source code for mbrs.metrics.comet

  1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4from typing import Optional, Sequence
  5
  6import torch
  7from comet import download_model, load_from_checkpoint
  8from torch import Tensor
  9
 10from mbrs.modules.kmeans import Kmeans
 11
 12from . import MetricAggregatableCache, register
 13
 14
@register("comet")
class MetricCOMET(MetricAggregatableCache):
    """COMET metric class.

    Sentence embeddings are computed by the COMET encoder (`encode`), cached
    as :class:`MetricCOMET.Cache`, optionally aggregated or clustered, and
    finally scored by the COMET estimator head (`out_proj`).
    """

    @dataclass
    class Config(MetricAggregatableCache.Config):
        """COMET metric configuration.

        - model (str): Model name or path.
        - batch_size (int): Batch size.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation. Ignored when
          `fp16` is also set.
        - cpu (bool): Use CPU for the forward computation.

        `fp16`/`bf16` only take effect when the model is moved to CUDA, i.e.,
        when `cpu` is False and CUDA is available.
        """

        model: str = "Unbabel/wmt22-comet-da"
        batch_size: int = 64
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricCOMET.Config):
        super().__init__(cfg)
        self.scorer = load_from_checkpoint(download_model(cfg.model))
        self.scorer.eval()
        # The scorer is used for inference only; frozen parameters mean no
        # autograd graph is recorded during encoding/scoring.
        for param in self.scorer.parameters():
            param.requires_grad = False

        if not cfg.cpu and torch.cuda.is_available():
            # Reduced precision is applied on GPU only; fp16 takes precedence.
            if cfg.fp16:
                self.scorer = self.scorer.half()
            elif cfg.bf16:
                self.scorer = self.scorer.bfloat16()
            self.scorer = self.scorer.cuda()

    @dataclass
    class Cache(MetricAggregatableCache.Cache):
        """Intermediate representations of sentences.

        - embeddings (Tensor): Sentence embeddings of shape `(N, D)`, where `N`
          is the number of sentences and `D` is a size of the embedding
          dimension.
        """

        embeddings: Tensor

        def __len__(self) -> int:
            """Return the length of cache."""
            return len(self.embeddings)

        def __getitem__(
            self, key: int | Sequence[int] | slice | Tensor
        ) -> MetricCOMET.Cache:
            """Get the items.

            The key is forwarded to tensor indexing, so an integer key yields
            a cache holding a single 1-D embedding of shape `(D,)`.
            """
            return type(self)(self.embeddings[key])

        def repeat(self, n: int) -> MetricCOMET.Cache:
            """Repeat the representations by n times.

            The embeddings are tiled along the first dimension: `(N, D)`
            becomes `(n * N, D)`, and a single 1-D embedding `(D,)` becomes
            `(n, D)`.

            Args:
                n (int): The number of repetition.

            Returns:
                Cache: The repeated cache.
            """
            return type(self)(self.embeddings.repeat((n, 1)))

        def aggregate(
            self, reference_lprobs: Optional[Tensor] = None
        ) -> MetricCOMET.Cache:
            """Aggregate the cached representations.

            Without `reference_lprobs`, the embeddings are averaged uniformly.
            Otherwise they are averaged with weights
            `softmax(reference_lprobs)`.

            Args:
                reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
                    The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.

            Returns:
                Cache: An aggregated representation of shape `(1, D)`.

            Raises:
                ValueError: When `reference_lprobs` does not have the shape
                    `(len(self),)`.
            """
            if reference_lprobs is None:
                return type(self)(self.embeddings.mean(dim=0, keepdim=True))

            # Without this check, a length-1 tensor would silently broadcast
            # and produce the *sum* of the embeddings instead of a mean.
            if reference_lprobs.shape != (len(self),):
                raise ValueError(
                    "The shape of reference_lprobs must be "
                    f"({len(self)},), but got {tuple(reference_lprobs.shape)}."
                )

            # Normalize in float32 *before* any down-cast: log-probabilities
            # are often large negative values, and converting them to
            # fp16/bf16 first would discard most of their precision.
            weights = reference_lprobs.to(
                device=self.embeddings.device, dtype=torch.float32
            ).softmax(dim=-1)
            # Accumulate the weighted sum in float32 and cast back at the end.
            aggregated_embedding = (
                self.embeddings.float() * weights[:, None]
            ).sum(dim=0, keepdim=True)
            return type(self)(aggregated_embedding.to(self.embeddings.dtype))

        def cluster(self, kmeans: Kmeans) -> tuple[MetricCOMET.Cache, Tensor]:
            """Cluster the cached representations.

            Args:
                kmeans (Kmeans): k-means class to perform clustering.

            Returns:
                tuple[Cache, Tensor]:
                  - Cache: Centroid representations.
                  - Tensor: N assigned IDs.
            """
            centroids, assigns = kmeans.train(self.embeddings)
            return type(self)(centroids), assigns

    @property
    def embed_dim(self) -> int:
        """Return the size of embedding dimension."""
        return self.scorer.encoder.output_units

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer.device

    def _cast_embeddings(self, emb: Tensor) -> Tensor:
        """Cast encoder outputs to the dtype selected by the configuration.

        On CPU the tensor is returned unchanged. On other devices it is cast
        to float16 (`fp16`), bfloat16 (`bf16`), or float32 (default).
        """
        if self.device.type == "cpu":
            return emb
        if self.cfg.fp16:
            return emb.half()
        if self.cfg.bf16:
            return emb.bfloat16()
        return emb.float()

    def encode(self, sentences: list[str]) -> Cache:
        """Encode the given sentences into their intermediate representations.

        Sentences are tokenized and embedded in chunks of `cfg.batch_size`.

        Args:
            sentences (list[str]): Input sentences.

        Returns:
            MetricCOMET.Cache: Intermediate representations of shape
              `(len(sentences), D)`. An empty input yields a `(0, D)` cache.
        """
        batch_size = self.cfg.batch_size
        embeddings = []
        # Tokenize one batch at a time rather than materializing every
        # tokenized batch up front.
        for i in range(0, len(sentences), batch_size):
            batch = self.scorer.encoder.prepare_sample(sentences[i : i + batch_size])
            emb = self.scorer.get_sentence_embedding(**batch.to(self.device))
            embeddings.append(self._cast_embeddings(emb))

        if not embeddings:
            # `torch.vstack` rejects an empty list; build the empty cache directly.
            empty = torch.empty((0, self.embed_dim), device=self.device)
            return self.Cache(self._cast_embeddings(empty))
        return self.Cache(torch.vstack(embeddings))

    def out_proj(
        self, hypotheses_ir: Cache, references_ir: Cache, sources_ir: Cache
    ) -> Tensor:
        """Forward the output projection layer.

        Args:
            hypotheses_ir (Cache): N intermediate representations of hypotheses.
            references_ir (Cache): N intermediate representations of references.
            sources_ir (Cache): N intermediate representations of sources.
              Required: COMET always scores with the sources.

        Returns:
            Tensor: N scores.
        """
        return self.scorer.estimate(
            sources_ir.embeddings, hypotheses_ir.embeddings, references_ir.embeddings
        )["score"]

    def corpus_score(
        self,
        hypotheses: list[str],
        references_lists: list[list[str]],
        sources: Optional[list[str]] = None,
    ) -> float:
        """Calculate the corpus-level score.

        Segment scores are computed for every reference set and averaged over
        all (reference set, segment) pairs.

        Args:
            hypotheses (list[str]): Hypotheses.
            references_lists (list[list[str]]): Lists of references. Each inner
              list must be aligned with (have the same length as) `hypotheses`.
            sources (list[str], optional): Sources, aligned with `hypotheses`.

        Returns:
            float: The corpus score.

        Raises:
            ValueError: Raise this error when sources are not given, or when
              the sources or a reference list are not aligned with the
              hypotheses.
        """
        if sources is None:
            raise ValueError("COMET requires the sources.")

        num_hypotheses = len(hypotheses)
        # Validate up front so misaligned inputs fail before any model calls,
        # instead of being silently truncated by the batch slicing below.
        if len(sources) != num_hypotheses:
            raise ValueError(
                f"The number of sources ({len(sources)}) must match "
                f"the number of hypotheses ({num_hypotheses})."
            )
        for references in references_lists:
            if len(references) != num_hypotheses:
                raise ValueError(
                    f"The number of references ({len(references)}) must match "
                    f"the number of hypotheses ({num_hypotheses})."
                )

        batch_size = self.cfg.batch_size
        scores = []
        for references in references_lists:
            for start in range(0, num_hypotheses, batch_size):
                end = start + batch_size
                batch_scores = self.scores(
                    hypotheses[start:end], references[start:end], sources[start:end]
                )
                scores.append(batch_scores.float().cpu())
        return torch.cat(scores).mean().item()