Source code for mbrs.decoders.centroid_mbr

  1from __future__ import annotations
  2
  3from dataclasses import dataclass, field
  4from typing import Optional
  5
  6import torch
  7from torch import Tensor
  8
  9from mbrs import functional, timer
 10from mbrs.metrics import MetricAggregatableCache
 11from mbrs.modules.kmeans import Kmeans
 12from mbrs.selectors import SELECTOR_NBEST, Selector
 13
 14from . import register
 15from .mbr import DecoderMBR
 16
 17
@register("centroid_mbr")
class DecoderCentroidMBR(DecoderMBR):
    """Centroid-Based MBR decoder class.

    - Time complexity: O(Nk)
    - Space complexity: O(Nk)

    where k << N.

    References:
        H. Deguchi et al., 2024.
        "Centroid-Based Efficient Minimum Bayes Risk Decoding".
        https://aclanthology.org/2024.findings-acl.654
    """

    def __init__(
        self,
        cfg: DecoderCentroidMBR.Config,
        metric: MetricAggregatableCache,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Initialize the decoder and build the k-means module from `cfg.kmeans`.

        Args:
            cfg (DecoderCentroidMBR.Config): Configuration for the decoder.
            metric (MetricAggregatableCache): Metric used to encode sentences
              and to score hypotheses against centroids.
            selector (Selector): Selector passed through to the parent decoder.
        """
        super().__init__(cfg, metric, selector=selector)
        self.kmeans = Kmeans(cfg.kmeans)

    cfg: Config

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - kmeans (Kmeans.Config): Configuration for k-means.
        - count_weight: (bool) Weight the scores with counts.
        """

        kmeans: Kmeans.Config = field(default_factory=Kmeans.Config)
        count_weight: bool = False

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderCentroidMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        The references are encoded and clustered with k-means, and each
        hypothesis is scored against the resulting centroids instead of
        against every reference.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.
              Ignored when `cfg.count_weight` is enabled.

        Returns:
            DecoderCentroidMBR.Output: The n-best hypotheses.
        """
        assert isinstance(self.metric, MetricAggregatableCache)

        with timer.measure("encode/hypotheses"):
            hypotheses_ir = self.metric.encode(hypotheses)

        # Reuse the hypothesis encodings when both lists are identical so the
        # same sentences are not encoded twice.
        if hypotheses == references:
            references_ir: MetricAggregatableCache.Cache = hypotheses_ir
        else:
            with timer.measure("encode/references"):
                references_ir = self.metric.encode(references)

        if source is None:
            source_ir = None
        else:
            with timer.measure("encode/source"):
                source_ir = self.metric.encode([source])

        # `assigns` holds the centroid id assigned to each reference.
        centroids, assigns = references_ir.cluster(self.kmeans)
        num_centroids = len(centroids)

        lprobs = None
        if self.cfg.count_weight:
            # Weight each centroid by the number of references assigned to it.
            # Centroids that received no reference keep a count of zero, whose
            # log is -inf.
            centroid_ids, counts_nonzero = assigns.unique(return_counts=True)
            counts = centroid_ids.new_zeros(num_centroids)
            counts[centroid_ids] = counts_nonzero
            lprobs = counts.log()
        elif reference_lprobs is not None:
            # Accumulate the log-probabilities for each centroid by logsumexp:
            # normalize to probabilities, sum them per centroid, then go back
            # to log space.
            reference_probs = reference_lprobs.to(assigns.device).softmax(
                dim=-1, dtype=torch.float32
            )
            # The index must be the per-reference assignment itself:
            # `scatter_add` pairs `index[i]` with `src[i]`, so a deduplicated
            # index would drop every reference beyond the first few and add
            # the remaining ones to unrelated centroids.
            lprobs = (
                torch.zeros(num_centroids, dtype=torch.float32, device=assigns.device)
                .scatter_add(dim=-1, index=assigns, src=reference_probs)
                .log()
            )

        with timer.measure("expectation"):
            pairwise_scores = self.metric.pairwise_scores_from_ir(
                hypotheses_ir, centroids, source_ir
            )
            if lprobs is not None:
                # Match the dtype and device of the score matrix.
                lprobs = lprobs.to(pairwise_scores)
            expected_scores = functional.expectation(pairwise_scores, lprobs=lprobs)

        selector_outputs = self.select(
            hypotheses, expected_scores, nbest=nbest, source=source
        )
        return (
            self.Output(
                idx=selector_outputs.idx,
                sentence=selector_outputs.sentence,
                score=selector_outputs.score,
            )
            | selector_outputs
        )