How to define a new metric

Examples

This tutorial explains how to define a new metric, using MetricTER as an example.

  1. Inherit from an abstract class defined in mbrs.metrics.base.

    • Metric computes the score between a hypothesis and a reference, optionally using a source.

    • If a lower score is better, set the class variable HIGHER_IS_BETTER to False.

    from mbrs.metrics.base import Metric
    
    
    class MetricTER(Metric):
        """TER metric class."""
    
        HIGHER_IS_BETTER: bool = False
    
  2. Define the configuration dataclass.

    • The configuration dataclass MetricTER.Config should inherit from the parent class's configuration (Metric.Config) for consistency.

    • __init__() receives an instance of the configuration dataclass, cfg: MetricTER.Config, and sets up the scorer.

    from __future__ import annotations
    
    from dataclasses import dataclass
    
    from sacrebleu.metrics.ter import TER
    
    from mbrs.metrics.base import Metric
    
    
    class MetricTER(Metric):
        """TER metric class."""
    
        HIGHER_IS_BETTER: bool = False
    
        @dataclass
        class Config(Metric.Config):
            """TER metric configuration."""
    
            normalized: bool = False
            no_punct: bool = False
            asian_support: bool = False
            case_sensitive: bool = False
    
        def __init__(self, cfg: MetricTER.Config):
            self.scorer = TER(
                normalized=cfg.normalized,
                no_punct=cfg.no_punct,
                asian_support=cfg.asian_support,
                case_sensitive=cfg.case_sensitive,
            )
    
  3. Subclasses of Metric must implement the score() method, which computes the score for a single example.

    • By default, score() is called iteratively during MBR decoding.

    • If the metric can compute pairwise scores between hypotheses and pseudo-references in parallel, consider overriding pairwise_scores() to enable batch computation.

    from __future__ import annotations
    
    from dataclasses import dataclass
    
    from sacrebleu.metrics.ter import TER
    
    from mbrs.metrics.base import Metric
    
    
    class MetricTER(Metric):
        """TER metric class."""
    
        HIGHER_IS_BETTER: bool = False
    
        @dataclass
        class Config(Metric.Config):
            """TER metric configuration."""
    
            normalized: bool = False
            no_punct: bool = False
            asian_support: bool = False
            case_sensitive: bool = False
    
        def __init__(self, cfg: MetricTER.Config):
            self.scorer = TER(
                normalized=cfg.normalized,
                no_punct=cfg.no_punct,
                asian_support=cfg.asian_support,
                case_sensitive=cfg.case_sensitive,
            )
    
        def score(self, hypothesis: str, reference: str, *_) -> float:
            return self.scorer.sentence_score(hypothesis, [reference]).score
    
  4. Register the class so that it can be called from the CLI.

    • Add the @register("ter") decorator to the class definition.

    from __future__ import annotations
    
    from dataclasses import dataclass
    
    from sacrebleu.metrics.ter import TER
    
    from mbrs.metrics.base import Metric, register
    
    
    @register("ter")
    class MetricTER(Metric):
        """TER metric class."""
    
        HIGHER_IS_BETTER: bool = False
    
        @dataclass
        class Config(Metric.Config):
            """TER metric configuration."""
    
            normalized: bool = False
            no_punct: bool = False
            asian_support: bool = False
            case_sensitive: bool = False
    
        def __init__(self, cfg: MetricTER.Config):
            self.scorer = TER(
                normalized=cfg.normalized,
                no_punct=cfg.no_punct,
                asian_support=cfg.asian_support,
                case_sensitive=cfg.case_sensitive,
            )
    
        def score(self, hypothesis: str, reference: str, *_) -> float:
            return self.scorer.sentence_score(hypothesis, [reference]).score
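
Step 1 only says to set HIGHER_IS_BETTER to False for error-style metrics such as TER. The sketch below illustrates why the flag matters, assuming a simplified MBR selection that averages each hypothesis's scores against the pseudo-references and then picks the best one. ToyMetric, ToyLengthDiff, and mbr_select are hypothetical names used only for illustration; they are not part of mbrs, and this is not a description of how mbrs implements decoding internally.

```python
from typing import List


class ToyMetric:
    """Hypothetical stand-in for an mbrs Metric subclass (not the real base class)."""

    # Default direction: larger scores are better (e.g. a similarity metric).
    HIGHER_IS_BETTER: bool = True

    def score(self, hypothesis: str, reference: str) -> float:
        raise NotImplementedError


class ToyLengthDiff(ToyMetric):
    """Hypothetical error-style metric: absolute difference in token counts."""

    # An error measure, like TER: smaller is better, so the flag is flipped.
    HIGHER_IS_BETTER: bool = False

    def score(self, hypothesis: str, reference: str) -> float:
        return float(abs(len(hypothesis.split()) - len(reference.split())))


def mbr_select(metric: ToyMetric, hypotheses: List[str], pseudo_refs: List[str]) -> int:
    """Return the index of the hypothesis with the best expected score."""
    # Expected score of each hypothesis = mean score against all pseudo-references.
    expected = [
        sum(metric.score(h, r) for r in pseudo_refs) / len(pseudo_refs)
        for h in hypotheses
    ]
    # HIGHER_IS_BETTER decides whether "best" means the largest or the smallest value.
    pick = max if metric.HIGHER_IS_BETTER else min
    return pick(range(len(hypotheses)), key=lambda i: expected[i])
```

For example, mbr_select(ToyLengthDiff(), ["a b", "a b c d e f"], ["a b c", "a b"]) returns 0, because the first hypothesis has the smaller mean length difference (0.5 versus 3.5). Leaving HIGHER_IS_BETTER at True would select index 1, the worst candidate.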
    
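Step 2's pattern, a nested Config dataclass that inherits from the parent's Config, can be tried out without mbrs installed. In the sketch below, BaseMetric is a hypothetical stand-in for Metric, and build_config is a hypothetical helper showing how options arriving as a dictionary (for example, parsed from the command line) might be turned into a typed config; it is not mbrs's actual loading code. The sketch uses from __future__ import annotations so that the annotation ToyTER.Config can refer to the class while its body is still being defined.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Any, Dict


class BaseMetric:
    """Hypothetical stand-in for mbrs.metrics.base.Metric."""

    @dataclass
    class Config:
        """Options shared by every metric (none in this sketch)."""


class ToyTER(BaseMetric):
    """Hypothetical metric mirroring the MetricTER options from step 2."""

    @dataclass
    class Config(BaseMetric.Config):
        normalized: bool = False
        no_punct: bool = False
        asian_support: bool = False
        case_sensitive: bool = False

    def __init__(self, cfg: ToyTER.Config):
        # Keep the config; a real metric would build its scorer from it here.
        self.cfg = cfg


def build_config(cls: type, overrides: Dict[str, Any]) -> Any:
    """Hypothetical helper: build cls.Config from a dict, rejecting unknown keys."""
    known = {f.name for f in fields(cls.Config)}
    unknown = set(overrides) - known
    if unknown:
        raise ValueError(f"unknown options: {sorted(unknown)}")
    # Unspecified fields keep the defaults declared in the dataclass.
    return cls.Config(**overrides)
```

Because ToyTER.Config inherits from BaseMetric.Config, code that only knows about the base type can still accept it, while fields(cls.Config) exposes exactly the options each metric declares.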

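Step 4 relies on a decorator-based registry: @register("ter") associates a name that the CLI can use with the class. The sketch below shows the general pattern with a hypothetical register/get_metric pair backed by a module-level dictionary; it does not reproduce mbrs's own register, whose implementation may differ.

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T", bound=type)

# Hypothetical module-level registry: CLI-facing name -> class.
_REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[T], T]:
    """Hypothetical decorator factory that records a class under `name`."""

    def decorator(cls: T) -> T:
        if name in _REGISTRY:
            raise ValueError(f"metric {name!r} is already registered")
        _REGISTRY[name] = cls
        # Return the class unchanged so it can still be imported and used directly.
        return cls

    return decorator


def get_metric(name: str) -> type:
    """Look up a registered class by name, e.g. a value taken from a command-line option."""
    if name not in _REGISTRY:
        raise ValueError(f"unknown metric {name!r}; choices: {sorted(_REGISTRY)}")
    return _REGISTRY[name]


@register("toy_ter")
class ToyTER:
    """Hypothetical metric class registered under the name "toy_ter"."""

    HIGHER_IS_BETTER: bool = False
```

Since the decorator returns the class itself, registering has no effect on how the class behaves; it only makes the class discoverable by name.
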
Note

All methods you implement or override should have the same input and output types as in the base class.
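
The note above, together with step 3's suggestion to override pairwise_scores() for batch computation, can be illustrated with a toy metric. This is a sketch under stated assumptions: ToyMetricBase stands in for the mbrs base class, and its pairwise_scores() signature and return type (a nested list of floats, one row per hypothesis) are assumptions made for illustration. The real mbrs method may take different arguments and return a different type (for example, a tensor), so check mbrs.metrics.base before overriding. ToyUnigramF1 is a hypothetical unigram-F1 metric, not part of mbrs.

```python
from collections import Counter
from typing import List, Optional


class ToyMetricBase:
    """Hypothetical stand-in for the mbrs Metric base class."""

    def score(self, hypothesis: str, reference: str, source: Optional[str] = None) -> float:
        raise NotImplementedError

    def pairwise_scores(
        self, hypotheses: List[str], references: List[str], source: Optional[str] = None
    ) -> List[List[float]]:
        # Default behaviour: call score() once per (hypothesis, reference) pair.
        return [[self.score(h, r, source) for r in references] for h in hypotheses]


def _f1(hyp: Counter, ref: Counter) -> float:
    """Unigram F1 between two bags of tokens."""
    overlap = sum((hyp & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


class ToyUnigramF1(ToyMetricBase):
    """Hypothetical unigram-F1 metric with a batched pairwise_scores()."""

    def score(self, hypothesis: str, reference: str, source: Optional[str] = None) -> float:
        return _f1(Counter(hypothesis.split()), Counter(reference.split()))

    def pairwise_scores(
        self, hypotheses: List[str], references: List[str], source: Optional[str] = None
    ) -> List[List[float]]:
        # Same argument and return types as the base class; only the strategy changes:
        # each sentence is tokenized once instead of once per pair.
        hyp_counts = [Counter(h.split()) for h in hypotheses]
        ref_counts = [Counter(r.split()) for r in references]
        return [[_f1(h, r) for r in ref_counts] for h in hyp_counts]
```

The override keeps the base class's argument and return types and only changes how the matrix is computed, so callers get the same scores either way; only the cost differs.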