Source code for mbrs.selectors.diverse

  1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4from typing import Optional
  5
  6import torch
  7from torch import Tensor
  8
  9from mbrs import timer
 10from mbrs.metrics import Metric, Metrics, get_metric
 11from mbrs.selectors import Selector, register
 12
 13
@register("diverse")
class SelectorDiverse(Selector):
    """Selector that returns a diverse n-best list.

    The n-best list is chosen to optimize the objective computed by
    :meth:`compute_objective`: the mean expected score of the selected
    hypotheses plus ``diversity_lambda`` times their mean pairwise
    dissimilarity. The search starts from a greedy best-first solution
    (:meth:`search_greedy_best_first`) which is then refined by a randomized
    local search (:meth:`search_local`).
    """

    def __init__(self, cfg: SelectorDiverse.Config) -> None:
        super().__init__(cfg)
        self.diversity_metric: Metric = get_metric(cfg.diversity_metric)(
            cfg.diversity_metric_config
        )

    @dataclass
    class Config(Selector.Config):
        """Configuration for the selector.

        - diversity_metric (Metrics): Metric used to score hypotheses against
          each other.
        - diversity_metric_config (Metric.Config, optional): Configuration of
          the diversity metric. When None, the metric's default ``Config()``
          is used.
        - diversity_lambda (float): Weight of the diversity term in the
          objective.
        - local_search_iterations (int): Number of local-search iterations.
        - local_search_neighbors (int): Number of selected hypotheses swapped
          out per local-search iteration (capped at the n-best size).
        - seed (int): Seed of the random generator used by the local search.
        """

        diversity_metric: Metrics = Metrics.bleu
        diversity_metric_config: Metric.Config | None = None
        diversity_lambda: float = 0.1
        local_search_iterations: int = 20
        local_search_neighbors: int = 1
        seed: int = 0

        def __post_init__(self):
            # Fall back to the default configuration of the chosen metric.
            if self.diversity_metric_config is None:
                self.diversity_metric_config = get_metric(
                    self.diversity_metric
                ).Config()

    cfg: Config

    @dataclass
    class Output(Selector.Output):
        """
        - idx (list[int]): Index numbers of the n-best hypotheses.
        - sentence (list[str]): Sentences of the n-best hypotheses.
        - score (list[float]): Scores of the n-best hypotheses.
        - nbest_objective_score (float): Objective score for the n-best list.
        - nbest_expected_score (float): Expected score for the n-best list.
        - nbest_diversity_score (float): Diversity score for the n-best list.
        """

        nbest_objective_score: float
        nbest_expected_score: float
        nbest_diversity_score: float

    @dataclass
    class Objective:
        """
        - score (float): The objective score.
        - expected_score (float): Expected score for the n-best list.
        - diversity_score (float): Diversity score for the n-best list.
        """

        score: float
        expected_score: float
        diversity_score: float

    def compute_objective(
        self, expected_scores: Tensor, hypothesis_dissimilarities: Tensor, mask: Tensor
    ) -> Objective:
        """Compute the objective function.

        The objective is ``expected + diversity_lambda * diversity`` where
        ``expected`` is the mean expected score of the selected hypotheses and
        ``diversity`` is the mean dissimilarity over ordered pairs of distinct
        selected hypotheses (self-pairs are excluded by zeroing the diagonal).

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            mask (Tensor): Boolean tensor of shape `(H,)`. The positions of True elements are calculated
              in the objective and the others are discarded.

        Returns:
            Objective: The objective scores that contain the expected score, diversity score, and the sum of them.
              An empty mask yields all-zero scores.
        """
        k = int(mask.sum().item())
        if k == 0:
            # Nothing selected: avoid a 0/0 division that would produce NaN.
            return self.Objective(0.0, 0.0, 0.0)

        # Work on a float copy so the caller's matrix is left untouched, and
        # drop self-pairs so they never contribute to the diversity term.
        dissimilarities = hypothesis_dissimilarities.clone().float()
        dissimilarities.fill_diagonal_(0.0)
        mask_f = mask.float()

        expected_score = expected_scores[mask].float().sum() / k
        # Average over the k*(k-1) ordered pairs; max(..., 1) keeps the
        # single-selection case finite (its pair sum is 0 anyway).
        hypothesis_dissimilarity = (
            mask_f @ dissimilarities @ mask_f / k / max(k - 1, 1)
        )
        objective = (
            expected_score + self.cfg.diversity_lambda * hypothesis_dissimilarity
        ).item()
        return self.Objective(
            objective, expected_score.item(), hypothesis_dissimilarity.item()
        )

    def _greedy_insert(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        selections: Tensor,
        maximize: bool,
    ) -> None:
        """Select, in place, the unselected hypothesis giving the best objective.

        Candidates are scanned in index order; the incumbent is replaced only
        when ``self.superior`` reports the new objective as better. If no
        candidate beats the initial ±inf sentinel (e.g. every objective is NaN
        or ∓inf), the first unselected hypothesis is chosen so that exactly one
        hypothesis is always added while any remain unselected.
        """
        best_score = float("-inf") if maximize else float("inf")
        best_i = None
        first_unselected = None
        for i in range(selections.size(0)):
            if selections[i]:
                continue
            if first_unselected is None:
                first_unselected = i
            candidate = selections.clone()
            candidate[i] = True
            score = self.compute_objective(
                expected_scores, hypothesis_dissimilarities, candidate
            ).score
            if self.superior(score, best_score, maximize=maximize):
                best_score = score
                best_i = i

        if best_i is None:
            # Previously this fell through to `selections[-1] = True`, which
            # could re-mark an already-selected last hypothesis and shrink the
            # n-best list.
            best_i = first_unselected
        if best_i is not None:
            selections[best_i] = True

    def search_greedy_best_first(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        nbest: int = 1,
        maximize: bool = True,
    ) -> Tensor:
        """Search the solution by greedy best first search.

        Starting from an empty selection, repeatedly adds the hypothesis whose
        inclusion gives the best objective, `nbest` times.

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            nbest (int): The number of final outputs.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            Tensor: Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
        """
        selections = torch.zeros(
            expected_scores.size(0), dtype=torch.bool, device=expected_scores.device
        )
        for _ in range(nbest):
            self._greedy_insert(
                expected_scores, hypothesis_dissimilarities, selections, maximize
            )
        return selections

    def search_local(
        self,
        expected_scores: Tensor,
        hypothesis_dissimilarities: Tensor,
        initial_selections: Tensor,
        nbest: int = 1,
        maximize: bool = True,
    ) -> Tensor:
        """Refine a selection by randomized local search.

        Each of the `local_search_iterations` iterations deselects
        `min(local_search_neighbors, nbest)` randomly chosen selected
        hypotheses, greedily re-adds the same number of hypotheses, and keeps
        the previous selection if its objective is superior to the new one.
        The random generator is seeded with `cfg.seed` on every call, so the
        search is deterministic for a given configuration.

        Args:
            expected_scores (Tensor): The expected scores for each hypothesis. The shape is `(H,)`.
            hypothesis_dissimilarities (Tensor): The pairwise dissimilarities for all hypotheses. The shape is `(H, H)`.
            initial_selections (Tensor): Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
              The random removal indexes into the selected positions with a permutation of `range(nbest)`,
              so it should contain at least `nbest` True entries.
            nbest (int): The number of final outputs.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            Tensor: Boolean tensor of shape `(H,)` where True positions indicate that they are selected.
        """
        rng = torch.Generator(device=expected_scores.device).manual_seed(self.cfg.seed)
        selections = initial_selections.clone()
        # Clamped at 0: a negative setting removes and re-adds nothing.
        num_neighbors = max(0, min(self.cfg.local_search_neighbors, nbest))

        for _ in range(self.cfg.local_search_iterations):
            prev_selections = selections.clone()

            # Deselect `num_neighbors` randomly chosen currently-selected hypotheses.
            selected_indices = selections.nonzero(as_tuple=True)[0]
            removed = torch.randperm(nbest, generator=rng, device=rng.device)[
                :num_neighbors
            ]
            selections[selected_indices[removed]] = False

            # Refill the freed slots greedily.
            for _ in range(num_neighbors):
                self._greedy_insert(
                    expected_scores, hypothesis_dissimilarities, selections, maximize
                )

            # Keep the move only if it is not worse than the previous selection.
            prev_objective = self.compute_objective(
                expected_scores, hypothesis_dissimilarities, prev_selections
            )
            new_objective = self.compute_objective(
                expected_scores, hypothesis_dissimilarities, selections
            )
            if self.superior(
                prev_objective.score, new_objective.score, maximize=maximize
            ):
                selections = prev_selections

        return selections

    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> SelectorDiverse.Output:
        """Select the final output list.

        Args:
            hypotheses (list[str]): Hypotheses.
            expected_scores (Tensor): The expected scores for each hypothesis.
            nbest (int): Return the n-best hypotheses based on the selection rule.
            source (str, optional): A source.
            maximize (bool): Whether maximize the scores or not.

        Returns:
            SelectorDiverse.Output: Selected hypotheses, ranked by expected score,
              together with the objective scores of the selected n-best list.
        """
        nbest = min(len(hypotheses), nbest)
        with timer.measure("dissimilarity_calculation"):
            hypothesis_dissimilarities = self.diversity_metric.pairwise_scores(
                hypotheses,
                hypotheses,
                source=source,
            ).to(expected_scores)
        # The objective always adds `diversity_lambda * dissimilarity`. Assuming
        # the diversity metric is a similarity (as with the default BLEU),
        # negating it when maximizing rewards diverse lists, while when
        # minimizing the raw similarity acts as a penalty. Negation builds a new
        # tensor so the metric's output (which `.to` may return unchanged) is
        # never mutated in place.
        if maximize:
            hypothesis_dissimilarities = -hypothesis_dissimilarities

        with timer.measure("search/greedy_best_first"):
            selections = self.search_greedy_best_first(
                expected_scores,
                hypothesis_dissimilarities,
                nbest=nbest,
                maximize=maximize,
            )
        with timer.measure("search/local"):
            selections = self.search_local(
                expected_scores,
                hypothesis_dissimilarities,
                selections,
                nbest=nbest,
                maximize=maximize,
            )

        objective = self.compute_objective(
            expected_scores, hypothesis_dissimilarities, selections
        )
        topk_scores, topk_order = self.topk(
            expected_scores[selections].float(), k=nbest, maximize=maximize
        )
        # Ascending positions of the selected hypotheses; this matches the
        # element order of `expected_scores[selections]` ranked above.
        selected_idx = selections.nonzero(as_tuple=True)[0].tolist()
        topk_indices = [selected_idx[i] for i in topk_order]
        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
            nbest_objective_score=objective.score,
            nbest_expected_score=objective.expected_score,
            nbest_diversity_score=objective.diversity_score,
        )