1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Optional
5
6import torch
7from torch import Tensor
8
9from mbrs import timer
10from mbrs.metrics import Metric, Metrics, get_metric
11from mbrs.selectors import Selector, register
12
13
@register("diverse")
class SelectorDiverse(Selector):
    """Selector that returns an n-best list balancing expected score and diversity.

    A subset of hypotheses is scored by :meth:`compute_objective`: the mean
    expected score of the subset plus ``cfg.diversity_lambda`` times the mean
    pairwise dissimilarity inside the subset. :meth:`select` builds the subset
    with a greedy best-first search and then refines it with a seeded,
    randomized local search.
    """

    def __init__(self, cfg: SelectorDiverse.Config) -> None:
        """Initialize the selector and instantiate the diversity metric.

        Args:
            cfg (SelectorDiverse.Config): Configuration of this selector.
        """
        super().__init__(cfg)
        # Look up the metric class by name and build it with its own config.
        self.diversity_metric: Metric = get_metric(cfg.diversity_metric)(
            cfg.diversity_metric_config
        )

    @dataclass
    class Config(Selector.Config):
        """Configuration for the selector.

        - diversity_metric (Metrics): Metric used to compute the pairwise scores
            between hypotheses.
        - diversity_metric_config (Metric.Config, optional): Configuration of the
            diversity metric. When None, the default configuration of the chosen
            metric is instantiated.
        - diversity_lambda (float): Weight of the diversity term in the objective.
        - local_search_iterations (int): Number of local search iterations.
        - local_search_neighbors (int): Number of selected hypotheses that are
            swapped out in each local search iteration.
        - seed (int): Seed of the random number generator used by the local search.
        """

        diversity_metric: Metrics = Metrics.bleu
        diversity_metric_config: Metric.Config | None = None
        diversity_lambda: float = 0.1
        local_search_iterations: int = 20
        local_search_neighbors: int = 1
        seed: int = 0

        def __post_init__(self) -> None:
            # Fall back to the default configuration of the selected metric.
            if self.diversity_metric_config is None:
                self.diversity_metric_config = get_metric(
                    self.diversity_metric
                ).Config()

    cfg: Config

    @dataclass
    class Output(Selector.Output):
        """
        - idx (list[int]): Index numbers of the n-best hypotheses.
        - sentence (list[str]): Sentences of the n-best hypotheses.
        - score (list[float]): Scores of the n-best hypotheses.
        - nbest_objective_score: (float): Objective score for the n-best list.
        - nbest_expected_score (float): Expected score for the n-best list.
        - nbest_diversity_score: (float): Diversity score for the n-best list.
        """

        nbest_objective_score: float
        nbest_expected_score: float
        nbest_diversity_score: float

    @dataclass
    class Objective:
        """
        - score (float): The objective score.
        - expected_score (float): Expected score for the n-best list.
        - diversity_score (float): Diversity score for the n-best list.
        """

        score: float
        expected_score: float
        diversity_score: float

    def compute_objective(
        self, expected_scores: Tensor, hypothesis_dissimilarities: Tensor, mask: Tensor
    ) -> Objective:
        """Compute the objective function.

        The objective is the mean expected score of the masked hypotheses plus
        ``cfg.diversity_lambda`` times their mean pairwise dissimilarity
        (self-pairs excluded).

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            mask (Tensor): Boolean tensor of shape `(H,)`. The positions of True elements are calculated
              in the objective and the others are discarded.

        Returns:
            Objective: The objective scores that contain the expected score, diversity score, and the sum of them.
        """
        # Work on a float copy so that zeroing the diagonal never touches the
        # caller's matrix.
        dissimilarities = hypothesis_dissimilarities.clone().float()
        dissimilarities = dissimilarities.fill_diagonal_(0.0)
        k = mask.sum().item()
        expected_score = expected_scores[mask].float().sum() / k
        # With a zero diagonal the bilinear form sums over ordered pairs (i != j)
        # of selected hypotheses; `max(k - 1, 1)` avoids dividing by zero when
        # only one hypothesis is selected.
        weights = mask.float()
        hypothesis_dissimilarity = (
            weights @ dissimilarities @ weights / k / max(k - 1, 1)
        )
        objective = (
            expected_score + self.cfg.diversity_lambda * hypothesis_dissimilarity
        ).item()
        return self.Objective(
            objective, expected_score.item(), hypothesis_dissimilarity.item()
        )

    def _best_candidate(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        selections: Tensor,
        maximize: bool,
    ) -> int:
        """Find the unselected hypothesis whose addition gives the best objective.

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            selections (Tensor): Boolean tensor of shape `(H,)` with the current selection. It is not modified.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            int: Index of the hypothesis to add, or -1 when every hypothesis is
              already selected. When no candidate is strictly superior to the
              initial bound (e.g., all objective scores are NaN), the first
              unselected index is returned so that the caller never writes to
              position -1.
        """
        best_score = float("-inf") if maximize else float("inf")
        best_index = -1
        fallback_index = -1
        # Candidates are visited in ascending index order, so ties keep the
        # smallest index.
        for index in (~selections).nonzero(as_tuple=True)[0].tolist():
            if fallback_index < 0:
                fallback_index = index
            candidate = selections.clone()
            candidate[index] = True
            objective = self.compute_objective(
                expected_scores, hypothesis_dissimilarities, candidate
            )
            if self.superior(objective.score, best_score, maximize=maximize):
                best_score = objective.score
                best_index = index
        return best_index if best_index >= 0 else fallback_index

    def search_greedy_best_first(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        nbest: int = 1,
        maximize: bool = True,
    ) -> Tensor:
        """Search the solution by greedy best first search.

        Starting from the empty selection, the hypothesis that yields the best
        objective is added one at a time until `nbest` hypotheses are selected
        or no unselected hypothesis remains.

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            nbest (int): The number of final outputs.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            Tensor: Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
        """
        H = expected_scores.size(0)
        selections = torch.zeros(H, dtype=torch.bool, device=expected_scores.device)
        for _ in range(nbest):
            best_index = self._best_candidate(
                expected_scores, hypothesis_dissimilarities, selections, maximize
            )
            if best_index < 0:
                # Every hypothesis is already selected.
                break
            selections[best_index] = True

        return selections

    def search_local(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        initial_selections: Tensor,
        nbest: int = 1,
        maximize: bool = True,
    ) -> Tensor:
        """Refine the solution by randomized local search.

        In each iteration, randomly chosen selected hypotheses are removed and
        the selection is greedily refilled. The new selection replaces the
        current one unless the current objective is strictly superior.

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            initial_selections (Tensor): Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
            nbest (int): The number of final outputs.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            Tensor: Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
        """
        # The generator is re-seeded on every call, so the search is
        # reproducible for a fixed `cfg.seed`.
        rng = torch.Generator(device=expected_scores.device).manual_seed(self.cfg.seed)
        selections = initial_selections.clone()

        # Removal positions are drawn from the hypotheses that are actually
        # selected; this equals `nbest` whenever the initial selection is full.
        num_selected = int(selections.sum().item())
        num_neighbors = min(self.cfg.local_search_neighbors, nbest, num_selected)
        if num_neighbors <= 0 or self.cfg.local_search_iterations <= 0:
            return selections

        # The objective of the current selection only changes when a new
        # selection is accepted, so it is cached across iterations.
        current_objective = self.compute_objective(
            expected_scores, hypothesis_dissimilarities, selections
        )
        for _ in range(self.cfg.local_search_iterations):
            candidate = selections.clone()
            selection_indices = candidate.nonzero(as_tuple=True)[0]
            removed_positions = torch.randperm(
                num_selected, generator=rng, device=rng.device
            )[:num_neighbors]

            for position in removed_positions.tolist():
                candidate[selection_indices[position]] = False

            # Greedily refill the freed slots.
            for _ in range(num_neighbors):
                best_index = self._best_candidate(
                    expected_scores, hypothesis_dissimilarities, candidate, maximize
                )
                if best_index < 0:
                    break
                candidate[best_index] = True

            new_objective = self.compute_objective(
                expected_scores, hypothesis_dissimilarities, candidate
            )
            # Keep the current selection only when it is strictly superior.
            if not self.superior(
                current_objective.score, new_objective.score, maximize=maximize
            ):
                selections = candidate
                current_objective = new_objective

        return selections

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> SelectorDiverse.Output:
        """Select the final output list.

        Args:
            hypotheses (list[str]): Hypotheses.
            expected_scores (Tensor): The expected scores for each hypothesis.
            nbest (int): Return the n-best hypotheses based on the selection rule.
            source (str, optional): A source.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            SelectorDiverse.Output: Selected hypotheses together with the
              objective, expected, and diversity scores of the selected list.
        """
        nbest = min(len(hypotheses), nbest)
        with timer.measure("dissimilarity_calculation"):
            hypothesis_dissimilarities = self.diversity_metric.pairwise_scores(
                hypotheses,
                hypotheses,
                source=source,
            ).to(expected_scores)
            if maximize:
                # The pairwise metric scores are negated when the objective is
                # maximized and used as-is when it is minimized. The negation is
                # out-of-place because `.to()` may return the very tensor that
                # the metric produced.
                hypothesis_dissimilarities = -hypothesis_dissimilarities
        with timer.measure("search/greedy_best_first"):
            selections = self.search_greedy_best_first(
                expected_scores,
                hypothesis_dissimilarities,
                nbest=nbest,
                maximize=maximize,
            )
        with timer.measure("search/local"):
            selections = self.search_local(
                expected_scores,
                hypothesis_dissimilarities,
                selections,
                nbest=nbest,
                maximize=maximize,
            )
        objective = self.compute_objective(
            expected_scores, hypothesis_dissimilarities, selections
        )
        # Rank only the selected hypotheses, then map the ranks back to the
        # indices in the full hypothesis list.
        topk_scores, topk_order = self.topk(
            expected_scores[selections].float(), k=nbest, maximize=maximize
        )
        selected_indices = selections.nonzero(as_tuple=True)[0].tolist()
        topk_indices = [selected_indices[i] for i in topk_order]
        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
            nbest_objective_score=objective.score,
            nbest_expected_score=objective.expected_score,
            nbest_diversity_score=objective.diversity_score,
        )