MCQ-Generation / app.py
rkp74's picture
app file
e99804d
import logging
import os
from pathlib import Path

from fastapi import FastAPI
from fastT5 import get_onnx_model,get_onnx_runtime_sessions,OnnxT5
from pydantic import BaseModel
from transformers import AutoTokenizer
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()
# Request body accepted by the POST /getquestion endpoint.
class QuestionRequest(BaseModel):
    # Text placed in the "context:" slot of the generation prompt.
    context: str
    # Text placed in the "answer:" slot of the generation prompt.
    answer: str
# Response body returned by the POST /getquestion endpoint.
class QuestionResponse(BaseModel):
    # The generated question text.
    question: str
# Directory holding the quantized ONNX graphs and the tokenizer files.
trained_model_path = './t5_squad_v1/'
# The ONNX file names are prefixed with the directory's stem.
pretrained_model_name = Path(trained_model_path).stem

# Build the three graph paths (encoder, decoder, init-decoder) from one
# shared file-name template.
encoder_path, decoder_path, init_decoder_path = (
    os.path.join(
        trained_model_path,
        f"{pretrained_model_name}-{component}-quantized.onnx",
    )
    for component in ("encoder", "decoder", "init-decoder")
)
model_paths = encoder_path, decoder_path, init_decoder_path

# Loaded once at import time and shared by every request.
model_sessions = get_onnx_runtime_sessions(model_paths)
model = OnnxT5(trained_model_path, model_sessions)
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
def get_question(sentence, answer, mdl, tknizer, *,
                 max_input_len=256, max_output_len=128, num_beams=5):
    """Generate a single question for ``answer`` given ``sentence``.

    The prompt ``"context: <sentence> answer: <answer>"`` is tokenized
    (truncated to ``max_input_len`` tokens), handed to ``mdl.generate``,
    and the first returned sequence is decoded. A leading ``"question:"``
    label is removed from the decoded text.

    Args:
        sentence: Context text inserted into the prompt.
        answer: Answer text inserted into the prompt.
        mdl: Object exposing ``generate(**kwargs)``.
        tknizer: Object exposing ``encode_plus(text, **kwargs)`` and
            ``decode(ids, skip_special_tokens=...)``.
        max_input_len: Token limit applied when encoding the prompt.
        max_output_len: ``max_length`` forwarded to ``mdl.generate``.
        num_beams: ``num_beams`` forwarded to ``mdl.generate``.

    Returns:
        The decoded question with the leading label and surrounding
        whitespace stripped.
    """
    text = "context: {} answer: {}".format(sentence, answer)
    # Debug level: request content should not be echoed to stdout on
    # every call in a running service.
    logging.getLogger(__name__).debug("question-generation input: %s", text)

    encoding = tknizer.encode_plus(
        text,
        max_length=max_input_len,
        padding=False,  # replaces the deprecated pad_to_max_length=False
        truncation=True,
        return_tensors="pt",
    )
    outs = mdl.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        early_stopping=True,
        num_beams=num_beams,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        max_length=max_output_len,
    )

    # Exactly one sequence is requested, so only the first one is decoded.
    question = tknizer.decode(outs[0], skip_special_tokens=True).strip()

    # Drop the label only when it is a prefix; a blanket replace() would
    # also delete "question:" appearing inside the generated sentence.
    label = "question:"
    if question.startswith(label):
        question = question[len(label):]
    return question.strip()
# Root endpoint: responds with a fixed greeting payload.
@app.get('/')
def index():
    greeting = {'message': 'hello world'}
    return greeting
# POST /getquestion: runs the module-level model and tokenizer on the
# request's context/answer pair and wraps the result in the response model.
@app.post("/getquestion", response_model=QuestionResponse)
def getquestion(request: QuestionRequest):
    generated = get_question(request.context, request.answer, model, tokenizer)
    return QuestionResponse(question=generated)