# QA_Albert / app.py (author: abhilash1910)
# Commit 079a6ae, "Initial Commit" (2.24 kB)
import gradio as gr
from transformers import AutoTokenizer,AutoModelForQuestionAnswering
import torch
def inference(question, context, *, max_answer_len=15):
    """Extract the answer to ``question`` from ``context`` with the QA model.

    Relies on the module-level ``tokenizer`` and ``model``. Every (start, end)
    token pair is scored as ``start_logit + end_logit``; only spans with
    ``start <= end`` and at most ``max_answer_len`` tokens are considered, and
    the best-scoring one is decoded.

    Args:
        question: The question text.
        context: The passage that should contain the answer.
        max_answer_len: Maximum answer length in tokens (keyword-only). The
            default of 15 mirrors the transformers question-answering pipeline.

    Returns:
        The decoded answer span as a string.

    Raises:
        ValueError: If ``max_answer_len`` is smaller than 1.
    """
    if max_answer_len < 1:
        raise ValueError(f"max_answer_len must be >= 1, got {max_answer_len}")

    # Same rule as the transformers QA pipeline: right-padding tokenizers
    # expect the question as the first sequence of the pair.
    question_first = tokenizer.padding_side == 'right'
    first, second = (question, context) if question_first else (context, question)

    # token_type_ids are left at the tokenizer's default (returned) so the
    # model can distinguish the question segment from the context segment.
    encoded_text = tokenizer.encode_plus(
        first,
        second,
        truncation="longest_first",
        max_length=512,
        return_tensors="pt",
    )
    input_ids = encoded_text['input_ids'][0].tolist()

    with torch.no_grad():
        outputs = model(**encoded_text)
    # Indexing works whether the model returns a tuple or a ModelOutput;
    # [0] drops the batch dimension (a single question/context pair).
    start_logits, end_logits = (logits[0] for logits in outputs[:2])

    # Score all (start, end) pairs jointly. Two independent argmaxes can give
    # end < start, which decodes to an empty answer.
    seq_len = start_logits.size(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    span_len = positions[None, :] - positions[:, None]  # entry [i, j] == j - i
    valid = (span_len >= 0) & (span_len < max_answer_len)
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float('-inf'))
    answer_start, answer_end = divmod(int(torch.argmax(scores)), seq_len)

    answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end + 1])
    return tokenizer.convert_tokens_to_string(answer_tokens)
# Hub checkpoint that provides both the QA model and its tokenizer.
_CHECKPOINT = 'abhilash1910/albert-squad-v2'
# Loaded once at import time; inference() reads these module-level globals.
model = AutoModelForQuestionAnswering.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# For reference, the transformers pipeline API offers an equivalent one-liner:
#   nlp_QA = pipeline('question-answering', model=model, tokenizer=tokenizer)
#   result = nlp_QA({'question': question, 'context': context})

# Example pre-filled into the UI text boxes.
question = 'How many parameters does Bert large have?'
context = 'Bert large is really big... it has 24 layers, for a total of 340M parameters.Altogether it is 1.34 GB so expect it to take a couple minutes to download to your Colab instance.'
title = 'Question Answering demo with Albert QA transformer and gradio'


def answer_from_ui(context_text, question_text):
    """Bridge the UI input order (Context, Question) to inference(question, context).

    Gradio passes input values positionally in the order of ``inputs``; calling
    inference directly would swap the question and the context.
    """
    return inference(question_text, context_text)


demo = gr.Interface(
    answer_from_ui,
    inputs=[
        gr.inputs.Textbox(lines=7, default=context, label="Context"),
        gr.inputs.Textbox(lines=2, default=question, label="Question"),
    ],
    outputs=[gr.outputs.Textbox(type="auto", label="Answer")],
    title=title,
    theme="peach",
)

if __name__ == "__main__":
    demo.launch()