import itertools
from typing import Dict

import nltk
import torch
from nltk import sent_tokenize
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

nltk.download("punkt")

class QAPipeline:
    # Base pipeline around the muchad/idt5-qa-qg checkpoint (idT5
    # fine-tuned for question answering and question generation).
    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("muchad/idt5-qa-qg")
        self.tokenizer = AutoTokenizer.from_pretrained("muchad/idt5-qa-qg")
        self.qg_format = "highlight"
        # Prefer GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # The checkpoint is expected to be a T5 conditional-generation model.
        assert self.model.__class__.__name__ == "T5ForConditionalGeneration"
        self.model_type = "t5"

    def __call__(self, inputs: str):
        # Normalize whitespace, then extract candidate answers from the text.
        inputs = " ".join(inputs.split())
        answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))
        if len(flat_answers) == 0:
            return []
        # Return the flattened list of extracted answers.
        return flat_answers

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512,
    ):
        # Batch-encode a list of strings into padded/truncated tensors.
        # The deprecated pad_to_max_length argument is dropped; padding
        # alone controls the padding strategy.
        return self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )
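
    def _extract_answers(self, context):
        # NOTE: not part of the original file. __call__ above depends on an
        # _extract_answers method that the snippet omits; this is a minimal
        # sketch of it, assuming the answer-extraction format of the
        # question_generation-style pipelines this code follows (an
        # "extract answers:" prefix, one <hl>-highlighted sentence per
        # input, answers joined by <sep> in the output). Verify against the
        # full upstream source before relying on it.
        sents = sent_tokenize(context)
        inputs = []
        for i in range(len(sents)):
            # Highlight one sentence at a time so the model extracts
            # answers from that sentence only.
            source_text = "extract answers: "
            for j, sent in enumerate(sents):
                if i == j:
                    sent = f"<hl> {sent} <hl>"
                source_text += f"{sent} "
            inputs.append(source_text.strip() + " </s>")
        encoded = self._tokenize(inputs, padding=True, truncation=True)
        outs = self.model.generate(
            input_ids=encoded["input_ids"].to(self.device),
            attention_mask=encoded["attention_mask"].to(self.device),
            max_length=32,
        )
        # Keep <sep> markers while stripping pad/eos tokens, then split
        # each decoded string into its individual answers.
        dec = [
            self.tokenizer.decode(ids, skip_special_tokens=False)
            .replace(self.tokenizer.pad_token, "")
            .replace(self.tokenizer.eos_token, "")
            for ids in outs
        ]
        return [[a.strip() for a in item.split("<sep>")[:-1]] for item in dec]
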
class TaskPipeline(QAPipeline):
    # Question-answering pipeline: maps a question/context pair to an answer.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, inputs: Dict[str, str]):
        # Expects a mapping with "question" and "context" keys.
        return self._extract_answer(inputs["question"], inputs["context"])

    def _prepare_inputs(self, question, context):
        # Build the T5 task prompt; the trailing </s> mirrors the format
        # this checkpoint was trained with.
        source_text = f"question: {question} context: {context}"
        return source_text + " </s>"

    def _extract_answer(self, question, context):
        source_text = self._prepare_inputs(question, context)
        inputs = self._tokenize([source_text], padding=False)
        # Generate the answer (up to 80 tokens) and decode it to plain text.
        outs = self.model.generate(
            input_ids=inputs["input_ids"].to(self.device),
            attention_mask=inputs["attention_mask"].to(self.device),
            max_length=80,
        )
        return self.tokenizer.decode(outs[0], skip_special_tokens=True)

def pipeline():
    # Factory returning a ready-to-use question-answering pipeline.
    return TaskPipeline()
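
# Example usage (a sketch, not from the original file): muchad/idt5-qa-qg
# is an Indonesian T5 checkpoint, so the sample inputs below are
# illustrative Indonesian text.
if __name__ == "__main__":
    qa = pipeline()
    answer = qa(
        {
            "question": "Siapa presiden pertama Indonesia?",
            "context": "Presiden pertama Indonesia adalah Soekarno.",
        }
    )
    print(answer)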