How to define a new decoder
Examples
This tutorial explains how to define a new decoder. The example below implements naive MBR decoding and extends the output object so that it returns additional features.
Inherit an abstract class defined in mbrs.decoders.base.
DecoderReferenceBased is mainly used for MBR decoding that returns the N most probable hypotheses using sets of hypotheses and pseudo-references.
from mbrs.decoders.base import DecoderReferenceBased


class DecoderMBRWithAllScores(DecoderReferenceBased):
    """Naive MBR decoder built on the reference-based decoder base class.

    This is only the class skeleton: it inherits everything from
    ``DecoderReferenceBased`` and adds nothing of its own yet.
    """
Define the configuration dataclass if you need to add options.
Configuration dataclass
DecoderMBRWithAllScores.Config should inherit that of the parent class for consistency.
from dataclasses import dataclass

from mbrs.decoders.base import DecoderReferenceBased


class DecoderMBRWithAllScores(DecoderReferenceBased):
    """Naive MBR decoder class."""

    # The parent's Config is named through DecoderReferenceBased: the name
    # DecoderMBRWithAllScores is not bound until this class body has finished
    # executing, so referring to it here would raise NameError.
    @dataclass
    class Config(DecoderReferenceBased.Config):
        """Naive MBR decoder configuration.

        Extends the parent configuration with one extra option.
        """

        # Custom option of this decoder; disabled by default.
        sort_scores: bool = False
# Child classes of DecoderReferenceBased are required to implement the
# decode() method.
from dataclasses import dataclass
from typing import Optional

from torch import Tensor

from mbrs.decoders.base import DecoderReferenceBased


class DecoderMBRWithAllScores(DecoderReferenceBased):
    """Naive MBR decoder class."""

    # The parent's Config is named through DecoderReferenceBased: the name
    # DecoderMBRWithAllScores is not bound until this class body has finished
    # executing, so referring to it here would raise NameError.
    @dataclass
    class Config(DecoderReferenceBased.Config):
        """Naive MBR decoder configuration."""

        # Custom option of this decoder; disabled by default.
        sort_scores: bool = False

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> "DecoderMBRWithAllScores.Output":
        """Select the ``nbest`` hypotheses according to their expected scores.

        Args:
            hypotheses: Candidate sentences to choose from.
            references: Reference sentences passed to the metric.
            source: Optional source sentence, forwarded to the metric.
            nbest: Number of top entries to return (passed as ``k`` to
                ``self.metric.topk``).
            reference_lprobs: Optional tensor forwarded to
                ``self.metric.expected_scores`` as ``reference_lprobs``.

        Returns:
            An ``Output`` object holding the selected indices, the matching
            hypotheses and their scores.
        """
        # The return annotation is quoted because the class name is not yet
        # bound while the class body (and hence this ``def``) is evaluated.
        expected_scores = self.metric.expected_scores(
            hypotheses, references, source, reference_lprobs=reference_lprobs
        )
        topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest)
        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
        )
In this example, we extend the output dataclass to include all expected scores.
DecoderMBRWithAllScores.Output needs to inherit the parent output dataclass.
from dataclasses import dataclass
from typing import Optional

from torch import Tensor

from mbrs.decoders.base import DecoderReferenceBased


class DecoderMBRWithAllScores(DecoderReferenceBased):
    """Naive MBR decoder class."""

    # The parent's Config is named through DecoderReferenceBased: the name
    # DecoderMBRWithAllScores is not bound until this class body has finished
    # executing, so referring to it here would raise NameError.
    @dataclass
    class Config(DecoderReferenceBased.Config):
        """Naive MBR decoder configuration."""

        # When True, ``decode`` returns ``all_scores`` in sorted order.
        sort_scores: bool = False

    @dataclass
    class Output(DecoderReferenceBased.Output):
        """Parent output extended with the expected scores of all hypotheses."""

        # Expected scores of every hypothesis; None when not provided.
        all_scores: Optional[Tensor] = None

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> "DecoderMBRWithAllScores.Output":
        """Select the ``nbest`` hypotheses and also report all expected scores.

        Args:
            hypotheses: Candidate sentences to choose from.
            references: Reference sentences passed to the metric.
            source: Optional source sentence, forwarded to the metric.
            nbest: Number of top entries to return (passed as ``k`` to
                ``self.metric.topk``).
            reference_lprobs: Optional tensor forwarded to
                ``self.metric.expected_scores`` as ``reference_lprobs``.

        Returns:
            An ``Output`` object holding the selected indices, the matching
            hypotheses, their scores, and ``all_scores`` (sorted when
            ``self.cfg.sort_scores`` is set).
        """
        # The return annotation is quoted because the class name is not yet
        # bound while the class body (and hence this ``def``) is evaluated.
        expected_scores = self.metric.expected_scores(
            hypotheses, references, source, reference_lprobs=reference_lprobs
        )
        topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest)

        if self.cfg.sort_scores:
            # Tensor.sort returns a (values, indices) pair; keep only the
            # values so that ``all_scores`` is a Tensor as the field declares.
            all_scores = expected_scores.sort(
                dim=-1, descending=self.metric.HIGH_IS_BETTER
            ).values
        else:
            all_scores = expected_scores

        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
            all_scores=all_scores,
        )
Finally, register the class so that it can be called from the CLI.
Just add @register("mbr_with_all_scores") to the class definition.
from dataclasses import dataclass
from typing import Optional

from torch import Tensor

# NOTE(review): confirm that ``register`` is exported by mbrs.decoders.base;
# it may need to be imported from the mbrs.decoders package instead.
from mbrs.decoders.base import DecoderReferenceBased, register


@register("mbr_with_all_scores")
class DecoderMBRWithAllScores(DecoderReferenceBased):
    """Naive MBR decoder class, registered as ``"mbr_with_all_scores"``."""

    # The parent's Config is named through DecoderReferenceBased: the name
    # DecoderMBRWithAllScores is not bound until this class body has finished
    # executing, so referring to it here would raise NameError.
    @dataclass
    class Config(DecoderReferenceBased.Config):
        """Naive MBR decoder configuration."""

        # When True, ``decode`` returns ``all_scores`` in sorted order.
        sort_scores: bool = False

    @dataclass
    class Output(DecoderReferenceBased.Output):
        """Parent output extended with the expected scores of all hypotheses."""

        # Expected scores of every hypothesis; None when not provided.
        all_scores: Optional[Tensor] = None

    def decode(
        self,
        hypotheses: list[str],
        references: list[str],
        source: Optional[str] = None,
        nbest: int = 1,
        reference_lprobs: Optional[Tensor] = None,
    ) -> "DecoderMBRWithAllScores.Output":
        """Select the ``nbest`` hypotheses and also report all expected scores.

        Args:
            hypotheses: Candidate sentences to choose from.
            references: Reference sentences passed to the metric.
            source: Optional source sentence, forwarded to the metric.
            nbest: Number of top entries to return (passed as ``k`` to
                ``self.metric.topk``).
            reference_lprobs: Optional tensor forwarded to
                ``self.metric.expected_scores`` as ``reference_lprobs``.

        Returns:
            An ``Output`` object holding the selected indices, the matching
            hypotheses, their scores, and ``all_scores`` (sorted when
            ``self.cfg.sort_scores`` is set).
        """
        # The return annotation is quoted because the class name is not yet
        # bound while the class body (and hence this ``def``) is evaluated.
        expected_scores = self.metric.expected_scores(
            hypotheses, references, source, reference_lprobs=reference_lprobs
        )
        topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest)

        if self.cfg.sort_scores:
            # Tensor.sort returns a (values, indices) pair; keep only the
            # values so that ``all_scores`` is a Tensor as the field declares.
            all_scores = expected_scores.sort(
                dim=-1, descending=self.metric.HIGH_IS_BETTER
            ).values
        else:
            all_scores = expected_scores

        return self.Output(
            idx=topk_indices,
            sentence=[hypotheses[idx] for idx in topk_indices],
            score=topk_scores,
            all_scores=all_scores,
        )
Note
Every method you override should keep the same input and output types as the corresponding method of the base class.