Source code for textattack.goal_functions.text.minimize_bleu

Goal function that attempts to minimize the BLEU score


import functools

import nltk

import textattack

from .text_to_text_goal_function import TextToTextGoalFunction

class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation and the reference translation.

    BLEU score was defined in "BLEU: a Method for Automatic Evaluation of
    Machine Translation" (Papineni et al., 2002).

    This goal function is defined in "It's Morphin' Time! Combating
    Linguistic Discrimination with Inflectional Perturbations"
    (Tan et al., 2020).

    Args:
        target_bleu (float): BLEU threshold. The goal is considered complete
            once the BLEU score of the model output is at or below this
            value (plus ``EPS``). Defaults to ``0.0``.
    """

    # Tolerance added to ``target_bleu`` in the completion check. The BLEU
    # value compared there is recovered as ``1 - (1 - bleu)``, so it may differ
    # from the raw score by floating-point round-off.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        # Assigned before delegating to the parent constructor, so the
        # attribute already exists while the parent initializer runs.
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Clear the model-call cache (if caching is enabled) and the
        module-level BLEU cache."""
        if self.use_cache:
            self._call_model_cache.clear()
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return ``True`` when the BLEU score of ``model_output`` is at or
        below ``target_bleu`` (within ``EPS``)."""
        # ``_get_score`` returns ``1 - BLEU``; invert it to get BLEU back.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1 - BLEU`` for ``model_output`` scored against
        ``self.ground_truth_output``, so a lower BLEU gives a higher score."""
        model_output_at = textattack.shared.AttackedText(model_output)
        ground_truth_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu(a, b)`` treats ``a`` as the reference and ``b`` as the
        # hypothesis. BLEU is asymmetric (n-gram precision and the brevity
        # penalty are computed on the hypothesis), so the ground truth must be
        # passed as the reference and the model output as the hypothesis.
        bleu_score = get_bleu(ground_truth_at, model_output_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Return the attribute names to include in this object's repr.

        ``target_bleu`` is listed only when ``self.maximizable`` is falsy.
        """
        keys = ["maximizable"]
        if not self.maximizable:
            keys.append("target_bleu")
        return keys
@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score of ``b`` measured against ``a``.

    Results are memoized with an LRU cache of up to ``2**12`` entries, so
    both arguments must be hashable.

    Args:
        a: Object whose ``words`` attribute is used as the single reference
            sentence (presumably a token sequence -- confirm against the
            caller's type).
        b: Object whose ``words`` attribute is used as the hypothesis
            sentence.

    Returns:
        The value returned by ``nltk.translate.bleu_score.sentence_bleu``.
    """
    # sentence_bleu expects a list of references, hence the one-element list.
    return nltk.translate.bleu_score.sentence_bleu([a.words], b.words)