anonymous8/RPD-Demo
initial commit
4943752
raw
history blame
747 Bytes
from textattack.models.wrappers import HuggingFaceModelWrapper
class TADModelWrapper(HuggingFaceModelWrapper):
    """Adapt a model exposing ``infer`` to a TextAttack-style callable.

    Every input text is handed to ``model.infer`` and the ``'probs'`` entry
    of each result is collected, yielding one item per input, e.g.

        [[0.218262017, 0.7817379832267761]]

    NOTE(review): the values are presumably per-class probabilities, as the
    key name suggests; confirm against the wrapped model's ``infer`` output.
    """

    def __init__(self, model):
        # Only the model is stored; the parent initializer is not called.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``'probs'`` value of ``model.infer`` for each text.

        ``text_inputs`` is iterated in order. Any extra keyword arguments
        are forwarded to ``infer`` alongside ``print_result=False``.
        """
        return [
            self.model.infer(single_text, print_result=False, **kwargs)['probs']
            for single_text in text_inputs
        ]