import itertools
from typing import Dict

import nltk
from nltk import sent_tokenize

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# sent_tokenize requires the punkt sentence-tokenizer models
nltk.download("punkt")
class QAPipeline:
    """Base pipeline wrapping the muchad/idt5-qa-qg (Indonesian T5) QA/QG model."""

    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("muchad/idt5-qa-qg")
        self.tokenizer = AutoTokenizer.from_pretrained("muchad/idt5-qa-qg")
        self.qg_format = "highlight"
        # Run on GPU when available, otherwise fall back to CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        assert self.model.__class__.__name__ == "T5ForConditionalGeneration"
        self.model_type = "t5"
    def __call__(self, inputs: str):
        # Normalize whitespace before answer extraction
        inputs = " ".join(inputs.split())
        answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))
        if len(flat_answers) == 0:
            return []
        return flat_answers
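
    # NOTE: _extract_answers is referenced by __call__ but not defined in this
    # file. The sketch below is a reconstruction, assuming the highlight-format
    # answer extraction conventionally used with T5 QA/QG models (matching the
    # qg_format = "highlight" setting above): each sentence is wrapped in <hl>
    # markers in turn and the model is prompted with "extract answers:". The
    # <hl> and <sep> tokens are assumed to be in the model's vocabulary.
    def _extract_answers(self, context):
        sents = sent_tokenize(context)
        source_texts = []
        for i in range(len(sents)):
            # Highlight one sentence at a time
            highlighted = [
                f"<hl> {sent} <hl>" if i == j else sent
                for j, sent in enumerate(sents)
            ]
            source_texts.append("extract answers: " + " ".join(highlighted) + " </s>")
        inputs = self._tokenize(source_texts, padding=True)
        outs = self.model.generate(
            input_ids=inputs["input_ids"].to(self.device),
            attention_mask=inputs["attention_mask"].to(self.device),
            max_length=32,
        )
        decoded = [
            self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs
        ]
        answers = []
        for d in decoded:
            d = d.replace("<pad>", "").replace("</s>", "").strip()
            # Answers for one sentence are separated by <sep>; the fragment
            # after the final separator is empty or partial, so drop it
            answers.append([a.strip() for a in d.split("<sep>")[:-1]])
        return answers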
    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512,
    ):
        # pad_to_max_length is deprecated in transformers; the padding=
        # argument alone covers both cases
        return self.tokenizer(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )
class TaskPipeline(QAPipeline):
    """Extractive QA task: answer a question given a context passage."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, inputs: Dict[str, str]):
        return self._extract_answer(inputs["question"], inputs["context"])

    def _prepare_inputs(self, question, context):
        # T5-style QA prompt; the explicit </s> matches the model's training format
        return f"question: {question} context: {context} </s>"
    def _extract_answer(self, question, context):
        source_text = self._prepare_inputs(question, context)
        inputs = self._tokenize([source_text], padding=False)
        outs = self.model.generate(
            input_ids=inputs["input_ids"].to(self.device),
            attention_mask=inputs["attention_mask"].to(self.device),
            max_length=80,
        )
        # Decode the generated answer, dropping pad/eos special tokens
        return self.tokenizer.decode(outs[0], skip_special_tokens=True)
def pipeline():
    # Factory mirroring the transformers pipeline() convention
    return TaskPipeline()
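
if __name__ == "__main__":
    # Minimal smoke test. The Indonesian question/context pair is an
    # illustrative example, not taken from the original Space.
    qa = pipeline()
    print(qa({
        "question": "Siapa presiden pertama Indonesia?",
        "context": "Soekarno adalah presiden pertama Indonesia.",
    }))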