mbrs.metrics.comet module
- class mbrs.metrics.comet.MetricCOMET(cfg: Config)
Bases: MetricAggregatableCache
COMET metric class.
- class Cache(embeddings: Tensor)
Bases: Cache
Intermediate representations of sentences.
- embeddings (Tensor): Sentence embeddings of shape (N, D), where N is the number of sentences and D is the embedding dimension.
- aggregate(reference_lprobs: Tensor | None = None) → Cache
Aggregate the cached representations.
- Parameters:
reference_lprobs (Tensor, optional) – Log-probabilities for each reference sample. The shape must be (len(references),). See https://arxiv.org/abs/2311.05263.
- Returns:
An aggregated representation.
- Return type:
Cache
- embeddings: Tensor
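The exact aggregation rule is not spelled out on this page. The sketch below assumes, without confirmation from the mbrs source, that aggregate() pools the N cached embeddings into a single row: a plain mean when reference_lprobs is omitted, or a mean weighted by the softmax of reference_lprobs when it is given, in the spirit of the model-based estimate cited above. Nested Python lists stand in for Tensor, and aggregate_embeddings is a hypothetical name, not part of mbrs.

```python
import math


def aggregate_embeddings(embeddings, reference_lprobs=None):
    """Hypothetical stand-in for Cache.aggregate (assumed behavior, not mbrs code).

    Collapses an (N, D) list of embeddings into a single (1, D) row.
    """
    n = len(embeddings)
    if n == 0:
        raise ValueError("need at least one embedding")
    if reference_lprobs is None:
        # No probabilities given: every reference gets the same weight.
        weights = [1.0 / n] * n
    else:
        if len(reference_lprobs) != n:
            raise ValueError("reference_lprobs must have shape (len(references),)")
        # Softmax over log-probabilities; subtracting the max avoids overflow.
        m = max(reference_lprobs)
        exps = [math.exp(lp - m) for lp in reference_lprobs]
        total = sum(exps)
        weights = [e / total for e in exps]
    dim = len(embeddings[0])
    # Weighted sum over the N rows, computed separately for each of the D columns.
    pooled = [sum(w * row[d] for w, row in zip(weights, embeddings)) for d in range(dim)]
    return [pooled]  # shape (1, D)
```

For example, aggregate_embeddings([[1.0, 2.0], [3.0, 4.0]]) returns [[2.0, 3.0]]. Subtracting the maximum log-probability before exponentiating leaves the normalized weights unchanged but keeps math.exp from overflowing on large-magnitude inputs.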
- class Config(model: str = 'Unbabel/wmt22-comet-da', batch_size: int = 64, fp16: bool = False, bf16: bool = False, cpu: bool = False)
Bases: Config
COMET metric configuration.
- model (str): Model name or path.
- batch_size (int): Batch size for the forward computation.
- fp16 (bool): Use float16 for the forward computation.
- bf16 (bool): Use bfloat16 for the forward computation.
- cpu (bool): Use the CPU for the forward computation.
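The batch_size field bounds how many inputs go through the model in a single forward call. As a hypothetical illustration (iter_batches is not an mbrs function, and the batching inside MetricCOMET may be organized differently), splitting the inputs into consecutive chunks looks like this:

```python
from typing import Iterator, List


def iter_batches(sentences: List[str], batch_size: int = 64) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size sentences (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(sentences), batch_size):
        # The final slice may be shorter than batch_size.
        yield sentences[start:start + batch_size]
```

With the default of 64, [len(b) for b in iter_batches(["s"] * 150)] gives [64, 64, 22]. Larger batches usually raise throughput at the cost of memory; fp16 or bf16 reduce the memory used per batch.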
- corpus_score(hypotheses: list[str], references_lists: list[list[str]], sources: list[str] | None = None) → float
Calculate the corpus-level score.
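A corpus-level COMET score is conventionally the arithmetic mean of its segment-level scores. The sketch below assumes that convention and that each inner list of references_lists is one reference stream aligned with hypotheses (the sacreBLEU layout); neither assumption is confirmed on this page. corpus_score_sketch and segment_score are hypothetical names, and segment_score stands in for the COMET model's per-segment prediction.

```python
from typing import Callable, List, Optional


def corpus_score_sketch(
    hypotheses: List[str],
    references_lists: List[List[str]],
    segment_score: Callable[[str, str, Optional[str]], float],
    sources: Optional[List[str]] = None,
) -> float:
    """Average segment scores over every reference stream (assumed layout)."""
    if sources is not None and len(sources) != len(hypotheses):
        raise ValueError("sources must align with hypotheses")
    scores: List[float] = []
    for references in references_lists:
        # Assumption: references[i] is a reference for hypotheses[i].
        if len(references) != len(hypotheses):
            raise ValueError("each reference stream must align with hypotheses")
        for i, (hyp, ref) in enumerate(zip(hypotheses, references)):
            src = sources[i] if sources is not None else None
            scores.append(segment_score(hyp, ref, src))
    if not scores:
        raise ValueError("nothing to score")
    return sum(scores) / len(scores)
```

With a toy exact-match scorer such as lambda hyp, ref, src: float(hyp == ref), the call corpus_score_sketch(["a", "b"], [["a", "c"]], scorer) returns 0.5.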
- property device: device
Returns the device of the model.
- encode(sentences: list[str]) → Cache
Encode the given sentences into their intermediate representations.
- Parameters:
sentences (list[str]) – Sentences to encode.
- Returns:
Intermediate representations.
- Return type:
Cache
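encode() turns N sentences into a Cache whose embeddings have shape (N, D). The sketch below mimics only that shape contract: toy_embed is a hypothetical hash-based placeholder for the COMET encoder that carries no semantic meaning, ToyCache stands in for MetricCOMET.Cache, and nested lists replace Tensor.

```python
import hashlib
from dataclasses import dataclass
from typing import List


@dataclass
class ToyCache:
    """Hypothetical stand-in for MetricCOMET.Cache."""
    embeddings: List[List[float]]  # shape (N, D): one row per input sentence


def toy_embed(sentence: str, dim: int) -> List[float]:
    # Placeholder for a neural sentence encoder: derives dim values in [0, 1]
    # deterministically from the SHA-256 digest of the sentence.
    digest = hashlib.sha256(sentence.encode("utf-8")).digest()
    return [digest[i % len(digest)] / 255.0 for i in range(dim)]


def encode_sketch(sentences: List[str], dim: int = 8, batch_size: int = 64) -> ToyCache:
    rows: List[List[float]] = []
    # Encode batch by batch, appending rows in input order.
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        rows.extend(toy_embed(s, dim) for s in batch)
    return ToyCache(embeddings=rows)
```

In this toy version each sentence maps to its row independently of the batch it lands in, so batch_size changes only how the work is chunked, and repeated sentences produce identical rows.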