Source code for mbrs.selectors.base

  1from __future__ import annotations
  2
  3import abc
  4from dataclasses import dataclass
  5from typing import Optional
  6
  7import torch
  8from torch import Tensor
  9
 10from mbrs import registry
 11
 12
class Selector(abc.ABC):
    """Abstract base class for selectors.

    A selector is given a list of hypotheses together with their expected
    scores and decides which of them make up the final output.  Subclasses
    must implement :meth:`select`; the base class provides small helpers for
    ranking scores in either direction (maximizing or minimizing).
    """

    def __init__(self, cfg: Selector.Config) -> None:
        # Store the configuration so subclasses can read it later.
        self.cfg = cfg

    @dataclass
    class Config:
        """Selector configuration; the base version declares no fields."""

    @dataclass
    class Output:
        """Result returned by :meth:`Selector.select`.

        - idx (list[int]): Indices of the n-best hypotheses.
        - sentence (list[str]): Text of the n-best hypotheses.
        - score (list[float]): Scores of the n-best hypotheses.
        """

        idx: list[int]
        sentence: list[str]
        score: list[float]

    def topk(
        self, x: Tensor, k: int = 1, maximize: bool = True
    ) -> tuple[list[float], list[int]]:
        """Pick the k best entries of a 1-D tensor along with their positions.

        Args:
            x (Tensor): A 1-D tensor of scores.
            k (int): How many entries to return. Values larger than the
                length of ``x`` are capped to that length.
            maximize (bool): If True, larger scores are better; otherwise
                smaller scores are better.

        Returns:
            tuple[list[float], list[int]]
              - list[float]: The selected values, best first.
              - list[int]: Positions of those values within ``x``.
        """
        # torch.topk rejects k greater than the number of elements, so the
        # request is capped at len(x).
        n_items = min(k, len(x))
        best = torch.topk(x, k=n_items, largest=maximize)
        return best.values.tolist(), best.indices.tolist()

    def argbest(self, x: Tensor, maximize: bool = True) -> Tensor:
        """Locate the single best entry of a 1-D tensor.

        Args:
            x (Tensor): A 1-D tensor of scores.
            maximize (bool): If True, pick the largest score; otherwise the
                smallest.

        Returns:
            Tensor: A 0-d tensor holding the index of the best entry.
        """
        pick = torch.argmax if maximize else torch.argmin
        return pick(x)

    def superior(self, a: float, b: float, maximize: bool = True) -> bool:
        """Tell whether score ``a`` is strictly better than score ``b``.

        Args:
            a (float): The candidate score.
            b (float): The score to compare against.
            maximize (bool): If True, higher is better; otherwise lower is
                better.

        Returns:
            bool: True when ``a`` beats ``b`` in the chosen direction; ties
            are not considered superior.
        """
        return a > b if maximize else a < b

    @abc.abstractmethod
    def select(
        self,
        hypotheses: list[str],
        expected_scores: Tensor,
        nbest: int = 1,
        source: Optional[str] = None,
        maximize: bool = True,
        **kwargs,
    ) -> Selector.Output:
        """Choose the final outputs from the candidate hypotheses.

        Concrete subclasses must override this method.

        Args:
            hypotheses (list[str]): Candidate hypotheses.
            expected_scores (Tensor): Expected score of each hypothesis.
            nbest (int): Number of hypotheses to return under the
                selection rule.
            source (str, optional): The source input, if available.
            maximize (bool): If True, higher scores are better; otherwise
                lower scores are better.
            **kwargs: Additional keyword arguments accepted by the interface;
                the base signature does not use them.

        Returns:
            Selector.Output: The selected hypotheses.
        """
# Create a registry bound to the Selector base class and unpack the pair of
# callables returned by its get_closure().  Presumably `register` records a
# Selector subclass under a name and `get_selector` looks one up by name —
# confirm against mbrs.registry.
register, get_selector = registry.Registry(Selector).get_closure()