Source code for simuleval.evaluator.scorers.quality_scorer

import logging
import sacrebleu
from pathlib import Path
from typing import List, Dict

QUALITY_SCORERS_DICT = {}


def register_quality_scorer(name):
    def register(cls):
        QUALITY_SCORERS_DICT[name] = cls
        return cls

    return register


class QualityScorer:
    def __init__(self) -> None:
        pass

    def __call__(self, references, translations) -> float:
        raise NotImplementedError

    @staticmethod
    def add_args(parser):
        pass


def add_sacrebleu_args(parser):
    parser.add_argument("--sacrebleu-tokenizer", type=str, default="13a", help="Tokenizer in sacrebleu")


[docs]@register_quality_scorer("BLEU") class SacreBLEUScorer(QualityScorer): """ SacreBLEU Scorer Usage: :code:`--quality-metrics BLEU` Additional command line arguments: .. argparse:: :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args :passparser: :prog: """ def __init__(self, tokenizer: str = "13a") -> None: super().__init__() self.logger = logging.getLogger("simuleval.scorer.bleu") self.tokenizer = tokenizer def __call__(self, instances: Dict) -> float: try: return sacrebleu.corpus_bleu( [ins.prediction for ins in instances.values()], [[ins.reference for ins in instances.values()]], tokenize=self.tokenizer, ).score except: return 0 @staticmethod def add_args(parser): add_sacrebleu_args(parser) @classmethod def from_args(cls, args): return cls(args.sacrebleu_tokenizer)
[docs]@register_quality_scorer("ASR_BLEU") class ASRSacreBLEUScorer(QualityScorer): """ ASR + SacreBLEU Scorer (BETA version) Usage: :code:`--quality-metrics ASR_BLEU` Additional command line arguments: .. argparse:: :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args :passparser: :prog: """ def __init__(self, tokenizer: str = "13a") -> None: super().__init__() self.logger = logging.getLogger("simuleval.scorer.asr_bleu") self.tokenizer = "13a" # todo make it configurable def __call__(self, instances) -> float: return sacrebleu.corpus_bleu( self.asr_transcribe(instances), [[ins.reference for ins in instances.values()]], tokenize=self.tokenizer, ).score def asr_transcribe(self, instances): self.logger.warn("Beta feature: Evaluating speech output") try: from ust_common.evaluation import prepare_w2v_audio_finetuning_data from ust_common.evaluation import fairseq_w2v_ctc_infer except: self.logger.warn("Please install ust_common.") return ["" for _ in range(len(self))] from simuleval.utils.fairseq import load_fairseq_manifest wav_dir = Path(instances[0].prediction).absolute().parent root_dir = wav_dir.parent # TODO make it configurable if not (root_dir / "asr_prep_data").exists(): prepare_w2v_audio_finetuning_data( wav_dir, root_dir / "asr_prep_data", output_subset_name="eval", waveform_filename_pattern="*", ) if not (root_dir / "asr_out").exists(): fairseq_w2v_ctc_infer( root_dir / "asr_prep_data", "/checkpoint/annl/s2st/eval/asr/model/wav2vec2/wav2vec_vox_960h_pl.pt", "eval", root_dir / "asr_out", ) translations_w_id = load_fairseq_manifest( root_dir / "asr_out" / "eval_asr_predictions.tsv" ) translations_w_id = sorted( translations_w_id, key=lambda x: int(x["id"].split("_")[-1]) ) translation_list = [] for idx, item in enumerate(translations_w_id): with open(wav_dir / f"{idx}_pred.txt", "w") as f: f.write(item["transcription"].lower() + "\n") translation_list.append(item["transcription"].lower()) return translation_list @staticmethod def add_args(parser): add_sacrebleu_args(parser) @classmethod def from_args(cls, args): return cls(args.sacrebleu_tokenizer)