mbrs.metrics.metricx module
- class mbrs.metrics.metricx.MT5ForRegression(config: MT5Config)
Bases: MT5PreTrainedModel
MT5 model for regression.
This implementation is copied from google-research/metricx.
- config_class
alias of MT5Config
- forward(input_ids: LongTensor | None = None, attention_mask: FloatTensor | None = None, decoder_attention_mask: BoolTensor | None = None, head_mask: FloatTensor | None = None, decoder_head_mask: FloatTensor | None = None, cross_attn_head_mask: Tensor | None = None, encoder_outputs: tuple[tuple[Tensor]] | None = None, past_key_values: tuple[tuple[Tensor]] | None = None, inputs_embeds: FloatTensor | None = None, decoder_inputs_embeds: FloatTensor | None = None, labels: FloatTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None) -> tuple[Tensor] | MT5ForRegressionOutput
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass is defined within this function, you should call the Module instance itself rather than forward() directly. Calling the instance runs the registered hooks, whereas calling forward() silently ignores them.
- class mbrs.metrics.metricx.MT5ForRegressionOutput(loss: Optional[torch.Tensor] = None, predictions: Optional[torch.Tensor] = None)
Bases: ModelOutput
Output container for MT5ForRegression, holding an optional loss and the predictions.
- class mbrs.metrics.metricx.MetricMetricX(cfg: Config)
Bases: Metric
MetricX metric class.
References:
MetricX-23: https://aclanthology.org/2023.wmt-1.63
MetricX-24: https://aclanthology.org/2024.wmt-1.35
Available checkpoints:
google/metricx-24-hybrid-xxl-v2p6
google/metricx-24-hybrid-xl-v2p6
google/metricx-24-hybrid-large-v2p6
google/metricx-23-xxl-v2p0
google/metricx-23-xl-v2p0
google/metricx-23-large-v2p0
google/metricx-23-qe-xxl-v2p0
google/metricx-23-qe-xl-v2p0
google/metricx-23-qe-large-v2p0
- class Config(model: str = 'google/metricx-24-hybrid-xxl-v2p6', batch_size: int = 8, fp16: bool = False, bf16: bool = False, cpu: bool = False)
Bases: Config
MetricX metric configuration.
model (str): Model name or path.
batch_size (int): Batch size.
fp16 (bool): Use float16 for the forward computation.
bf16 (bool): Use bfloat16 for the forward computation.
cpu (bool): Use CPU for the forward computation.
- METRICX23_QE_MODELS = {'google/metricx-23-qe-large-v2p0', 'google/metricx-23-qe-xl-v2p0', 'google/metricx-23-qe-xxl-v2p0'}
- METRICX_INPUT_LENGTH_MAP = {MetricXVersion.metricx_23: 1024, MetricXVersion.metricx_24: 1536}
- METRICX_INPUT_PREFIX_MAP = {MetricXVersion.metricx_23: MetricMetricX.InputPrefix(hypothesis='candidate: ', reference=' reference: ', source=' source: '), MetricXVersion.metricx_24: MetricMetricX.InputPrefix(hypothesis=' candidate: ', reference=' reference: ', source='source: ')}
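The prefixes above hint at how a model input string can be assembled: in each version exactly one prefix has no leading space, which suggests that field comes first. The sketch below assumes that ordering (candidate first for MetricX-23, source first for MetricX-24) and assumes that absent fields, such as the reference in QE mode, are simply skipped. `InputPrefix`, `FIELD_ORDER`, and `build_input` are hypothetical stand-ins, not the mbrs API, and MetricMetricX itself may handle absent fields differently.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class MetricXVersion(Enum):
    # Mirrors the members documented for MetricMetricX.MetricXVersion.
    metricx_23 = "metricx_23"
    metricx_24 = "metricx_24"


@dataclass(frozen=True)
class InputPrefix:
    # Hypothetical stand-in for MetricMetricX.InputPrefix.
    hypothesis: str
    reference: str
    source: str


# Prefix values copied from METRICX_INPUT_PREFIX_MAP.
PREFIX_MAP = {
    MetricXVersion.metricx_23: InputPrefix(
        hypothesis="candidate: ", reference=" reference: ", source=" source: "
    ),
    MetricXVersion.metricx_24: InputPrefix(
        hypothesis=" candidate: ", reference=" reference: ", source="source: "
    ),
}

# Assumed field order: the field whose prefix lacks a leading space goes first.
FIELD_ORDER = {
    MetricXVersion.metricx_23: ("hypothesis", "reference", "source"),
    MetricXVersion.metricx_24: ("source", "hypothesis", "reference"),
}


def build_input(
    version: MetricXVersion,
    hypothesis: str,
    reference: Optional[str] = None,
    source: Optional[str] = None,
) -> str:
    """Concatenate prefixed fields in the assumed order, skipping None fields."""
    prefixes = PREFIX_MAP[version]
    values = {"hypothesis": hypothesis, "reference": reference, "source": source}
    parts = [
        getattr(prefixes, field) + values[field]
        for field in FIELD_ORDER[version]
        if values[field] is not None
    ]
    # Drop the leading space left over when the first field in the order is absent.
    return "".join(parts).lstrip()
```

For example, `build_input(MetricXVersion.metricx_24, "Hello world", source="Hallo Welt")` yields a reference-free (QE-style) input that starts with the source segment.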
- METRICX_VERSION_MAP = {'google/metricx-23-large-v2p0': MetricXVersion.metricx_23, 'google/metricx-23-qe-large-v2p0': MetricXVersion.metricx_23, 'google/metricx-23-qe-xl-v2p0': MetricXVersion.metricx_23, 'google/metricx-23-qe-xxl-v2p0': MetricXVersion.metricx_23, 'google/metricx-23-xl-v2p0': MetricXVersion.metricx_23, 'google/metricx-23-xxl-v2p0': MetricXVersion.metricx_23, 'google/metricx-24-hybrid-large-v2p6': MetricXVersion.metricx_24, 'google/metricx-24-hybrid-xl-v2p6': MetricXVersion.metricx_24, 'google/metricx-24-hybrid-xxl-v2p6': MetricXVersion.metricx_24}
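To show how these class attributes fit together, the following sketch resolves a checkpoint name to its MetricX version, its maximum input length (presumably in tokens), and whether it is one of the reference-free MetricX-23 QE models. The dictionaries are rebuilt from the values listed above; `describe_checkpoint` is a hypothetical helper, not part of mbrs.

```python
from enum import Enum


class MetricXVersion(Enum):
    # Mirrors the members documented for MetricMetricX.MetricXVersion.
    metricx_23 = "metricx_23"
    metricx_24 = "metricx_24"


# Rebuilt from METRICX23_QE_MODELS, METRICX_INPUT_LENGTH_MAP and METRICX_VERSION_MAP.
QE_MODELS = {
    "google/metricx-23-qe-large-v2p0",
    "google/metricx-23-qe-xl-v2p0",
    "google/metricx-23-qe-xxl-v2p0",
}
INPUT_LENGTH = {MetricXVersion.metricx_23: 1024, MetricXVersion.metricx_24: 1536}
VERSION_MAP = {
    f"google/metricx-23-{variant}-v2p0": MetricXVersion.metricx_23
    for variant in ("large", "xl", "xxl", "qe-large", "qe-xl", "qe-xxl")
}
VERSION_MAP.update(
    {
        f"google/metricx-24-hybrid-{size}-v2p6": MetricXVersion.metricx_24
        for size in ("large", "xl", "xxl")
    }
)


def describe_checkpoint(model: str) -> tuple[MetricXVersion, int, bool]:
    """Return (version, max input length, is QE-only) for a listed checkpoint."""
    version = VERSION_MAP[model]  # raises KeyError for unlisted checkpoints
    return version, INPUT_LENGTH[version], model in QE_MODELS
```

A lookup table rather than name parsing keeps the sketch limited to the nine checkpoints the class documents.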
- class MetricXVersion(value)
An enumeration of the supported MetricX versions.
- metricx_23 = 'metricx_23'
- metricx_24 = 'metricx_24'
- corpus_score(hypotheses: list[str], references_lists: list[list[str]] | None = None, sources: list[str] | None = None) -> float
Calculate the corpus-level score.
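One plausible way to turn segment scores into a corpus-level score is a plain mean over hypotheses, averaged again over reference sets. This sketch assumes that aggregation and assumes each entry of `references_lists` is a reference set aligned one-to-one with `hypotheses`; the actual mbrs behavior may differ. `corpus_score_sketch` and `toy_score` are hypothetical, and the toy scorer only stands in for a real MetricX call.

```python
from statistics import mean
from typing import Callable, Optional


def corpus_score_sketch(
    segment_score: Callable[[str, Optional[str], Optional[str]], float],
    hypotheses: list[str],
    references_lists: Optional[list[list[str]]] = None,
    sources: Optional[list[str]] = None,
) -> float:
    """Average segment scores per reference set, then across sets (assumed aggregation)."""
    srcs = sources if sources is not None else [None] * len(hypotheses)
    # With no references, score once against a set of None placeholders.
    ref_sets = references_lists if references_lists is not None else [[None] * len(hypotheses)]
    per_set = [
        mean(segment_score(h, r, s) for h, r, s in zip(hypotheses, refs, srcs))
        for refs in ref_sets
    ]
    return mean(per_set)


def toy_score(hyp: str, ref: Optional[str], src: Optional[str]) -> float:
    # Hypothetical stand-in for MetricX: length difference as an "error".
    return abs(len(hyp) - len(ref or ""))
```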
- property device: device
Returns the device of the model.
- pairwise_scores(hypotheses: list[str], references: list[str], source: str | None = None) -> Tensor
Calculate the score of every hypothesis against every reference.
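In minimum Bayes risk (MBR) decoding, which mbrs is built for, a pairwise score matrix is typically reduced to one expected score per hypothesis before picking the best one. MetricX predicts an error score between 0 and 25 where lower is better, so the selection is an argmin. The sketch below uses nested lists in place of a `Tensor` and assumes rows index hypotheses and columns index (pseudo-)references; `mbr_select` is a hypothetical helper, not the mbrs decoder.

```python
def mbr_select(pairwise: list[list[float]], lower_is_better: bool = True) -> int:
    """Return the index of the hypothesis with the best mean score over references."""
    # Expected score of each hypothesis = mean of its row.
    expected = [sum(row) / len(row) for row in pairwise]
    indices = range(len(expected))
    if lower_is_better:
        # Error metrics such as MetricX: the smallest expected error wins.
        return min(indices, key=lambda i: expected[i])
    return max(indices, key=lambda i: expected[i])


scores = [[3.0, 5.0, 4.0], [1.0, 2.0, 6.0], [2.0, 2.0, 2.5]]
best = mbr_select(scores)  # row means are 4.0, 3.0 and about 2.17, so best == 2
```

Passing `lower_is_better=False` covers similarity metrics where higher scores are better.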
- score(hypothesis: str, reference: str | None = None, source: str | None = None) -> float
Calculate the score of the given hypothesis.
- scorer: MT5ForRegression