""" | |
CoLA for Grammaticality | |
-------------------------- | |
""" | |
import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from textattack.constraints import Constraint
from textattack.models.wrappers import HuggingFaceModelWrapper
class COLA(Constraint):
    """Constrains an attack to text that has a similar number of linguistically
    acceptable sentences as the original text.

    Linguistic acceptability is determined by a model pre-trained on the
    `CoLA dataset <https://nyu-mll.github.io/CoLA/>`_. By default a BERT model
    is used; see the `pre-trained models README
    <https://github.com/QData/TextAttack/tree/master/textattack/models>`_ for a
    full list of available models, or provide your own model from the
    huggingface model hub.

    Args:
        max_diff (float or int): The absolute (if int or greater than or equal
            to 1) or percent (if float and less than 1) maximum difference
            allowed between the number of valid sentences in the reference
            text and the number of valid sentences in the attacked text.
        model_name (str): The name of the pre-trained model to use for
            classification. The model must be in huggingface model hub.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.

    Raises:
        TypeError: If ``max_diff`` is neither a float nor an int.
        ValueError: If ``max_diff`` is negative.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        # Both int (absolute count) and float (fraction) thresholds are valid.
        if not isinstance(max_diff, (float, int)):
            raise TypeError("max_diff must be a float or int")
        if max_diff < 0.0:
            raise ValueError("max_diff must be a value greater than or equal to 0.0")

        self.max_diff = max_diff
        self.model_name = model_name
        # Bounded LRU cache: the reference text is compared against many
        # candidate transformations, so its score is computed once and reused.
        self._reference_score_cache = lru.LRU(2**10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        """Empty the per-reference-text acceptability-score cache."""
        self._reference_score_cache.clear()

    def _check_constraint(self, transformed_text, reference_text):
        """Return True iff ``transformed_text`` retains enough linguistically
        acceptable sentences relative to ``reference_text``.

        Both arguments are expected to expose a ``.text`` attribute
        (TextAttack ``AttackedText`` objects).
        """
        if reference_text not in self._reference_score_cache:
            # Split the text into sentences before predicting validity.
            reference_sentences = nltk.sent_tokenize(reference_text.text)
            # A label of 1 indicates the sentence is acceptable, so summing
            # the argmax predictions counts the acceptable sentences.
            num_valid = self.model(reference_sentences).argmax(axis=1).sum()
            self._reference_score_cache[reference_text] = num_valid

        sentences = nltk.sent_tokenize(transformed_text.text)
        predictions = self.model(sentences)
        num_valid = predictions.argmax(axis=1).sum()
        reference_score = self._reference_score_cache[reference_text]

        if isinstance(self.max_diff, int) or self.max_diff >= 1:
            # Absolute threshold: allow losing at most `max_diff` sentences.
            threshold = reference_score - self.max_diff
        else:
            # Fractional threshold: allow losing at most `max_diff` percent
            # of the reference's acceptable sentences.
            threshold = reference_score - (reference_score * self.max_diff)
        return num_valid >= threshold

    def extra_repr_keys(self):
        """Attributes to include in this constraint's printable repr."""
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        # lru.LRU objects are not picklable; persist only the cache capacity
        # so an equivalent (empty) cache can be rebuilt on unpickling.
        state = self.__dict__.copy()
        state["_reference_score_cache"] = self._reference_score_cache.get_size()
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        # Rebuild an empty cache with the saved capacity.
        self._reference_score_cache = lru.LRU(state["_reference_score_cache"])