Source code for mbrs.decoders.centroid_mbr
1from __future__ import annotations
2
3from dataclasses import dataclass, field
4from typing import Optional
5
6import torch
7from torch import Tensor
8
9from mbrs import functional, timer
10from mbrs.metrics import MetricAggregatableCache
11from mbrs.modules.kmeans import Kmeans
12from mbrs.selectors import SELECTOR_NBEST, Selector
13
14from . import register
15from .mbr import DecoderMBR
16
17
[docs]
@register("centroid_mbr")
class DecoderCentroidMBR(DecoderMBR):
    """Centroid-Based MBR decoder class.

    The encoded references are clustered with the configured k-means module
    and every hypothesis is scored against the resulting centroids rather
    than against each individual reference.

    - Time complexity: O(Nk)
    - Space complexity: O(Nk)

    where k << N.

    References:
        H. Deguchi et al., 2024.
        "Centroid-Based Efficient Minimum Bayes Risk Decoding".
        https://aclanthology.org/2024.findings-acl.654
    """

    def __init__(
        self,
        cfg: DecoderCentroidMBR.Config,
        metric: MetricAggregatableCache,
        selector: Selector = SELECTOR_NBEST,
    ) -> None:
        """Initialize the decoder and build its k-means module.

        Args:
            cfg (DecoderCentroidMBR.Config): Decoder configuration.
            metric (MetricAggregatableCache): Metric used to encode sentences
              and to score hypotheses against centroids.
            selector (Selector): Strategy used to pick the final hypotheses.
        """
        super().__init__(cfg, metric, selector=selector)
        self.kmeans = Kmeans(cfg.kmeans)

    cfg: Config

    @dataclass
    class Config(DecoderMBR.Config):
        """Configuration for the decoder.

        - kmeans (Kmeans.Config): Configuration for k-means.
        - count_weight: (bool) Weight the scores with counts.
        """

        kmeans: Kmeans.Config = field(default_factory=Kmeans.Config)
        count_weight: bool = False

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> DecoderCentroidMBR.Output:
        """Select the n-best hypotheses based on the strategy.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.
            source (str, optional): A source.
            nbest (int): Return the n-best hypotheses.
            reference_lprobs (Tensor, optional): Log-probabilities for each reference sample.
              The shape must be `(len(references),)`. See `https://arxiv.org/abs/2311.05263`.
              Ignored when `cfg.count_weight` is enabled.

        Returns:
            DecoderCentroidMBR.Output: The n-best hypotheses.
        """
        assert isinstance(self.metric, MetricAggregatableCache)

        with timer.measure("encode/hypotheses"):
            hypotheses_ir = self.metric.encode(hypotheses)

        references_ir: MetricAggregatableCache.Cache
        if hypotheses == references:
            # Identical lists: reuse the hypothesis encodings instead of
            # encoding the same sentences a second time.
            references_ir = hypotheses_ir
        else:
            with timer.measure("encode/references"):
                references_ir = self.metric.encode(references)

        if source is None:
            source_ir = None
        else:
            with timer.measure("encode/source"):
                source_ir = self.metric.encode([source])

        # `assigns` holds centroid indices for the references (presumably one
        # entry per reference, in reference order -- confirm against
        # `Cache.cluster`).
        centroids, assigns = references_ir.cluster(self.kmeans)

        lprobs = None
        if self.cfg.count_weight:
            # Weight each centroid by the number of references assigned to it.
            # Centroids without any assignment keep a count of zero, whose
            # log is -inf.
            centroid_ids, counts_nonzero = assigns.unique(return_counts=True)
            counts = centroid_ids.new_zeros(len(centroids))
            counts[centroid_ids] = counts_nonzero
            lprobs = counts.log()
        elif reference_lprobs is not None:
            # Accumulate the log-probabilities for each centroid by logsumexp:
            # normalize to probabilities, sum them per centroid, then go back
            # to log space.
            reference_probs = reference_lprobs.to(assigns.device).softmax(
                dim=-1, dtype=torch.float32
            )
            lprobs = (
                torch.zeros(len(centroids), dtype=torch.float32, device=assigns.device)
                .scatter_add(
                    dim=-1,
                    # The index must map every reference to its own centroid.
                    # Using `assigns.unique()` here would consume only the
                    # first few probabilities and add them to unrelated
                    # centroids.
                    index=assigns,
                    src=reference_probs,
                )
                .log()
            )

        with timer.measure("expectation"):
            pairwise_scores = self.metric.pairwise_scores_from_ir(
                hypotheses_ir, centroids, source_ir
            )
            if lprobs is not None:
                # Match the dtype and device of the score matrix.
                lprobs = lprobs.to(pairwise_scores)
            expected_scores = functional.expectation(pairwise_scores, lprobs=lprobs)

        selector_outputs = self.select(
            hypotheses, expected_scores, nbest=nbest, source=source
        )
        return (
            self.Output(
                idx=selector_outputs.idx,
                sentence=selector_outputs.sentence,
                score=selector_outputs.score,
            )
            | selector_outputs
        )