rkp74 commited on
Commit
e99804d
·
1 Parent(s): bb19677
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastT5 import get_onnx_model,get_onnx_runtime_sessions,OnnxT5
2
+ from transformers import AutoTokenizer
3
+ from pathlib import Path
4
+ import os
5
+ from fastapi import FastAPI
6
+ from pydantic import BaseModel
7
+
8
+ app = FastAPI()
9
+
10
+
11
+ class QuestionRequest(BaseModel):
12
+ context: str
13
+ answer: str
14
+
15
+ class QuestionResponse(BaseModel):
16
+ question: str
17
+
18
+ trained_model_path = './t5_squad_v1/'
19
+
20
+ pretrained_model_name = Path(trained_model_path).stem
21
+
22
+
23
+ encoder_path = os.path.join(trained_model_path,f"{pretrained_model_name}-encoder-quantized.onnx")
24
+ decoder_path = os.path.join(trained_model_path,f"{pretrained_model_name}-decoder-quantized.onnx")
25
+ init_decoder_path = os.path.join(trained_model_path,f"{pretrained_model_name}-init-decoder-quantized.onnx")
26
+
27
+ model_paths = encoder_path, decoder_path, init_decoder_path
28
+ model_sessions = get_onnx_runtime_sessions(model_paths)
29
+ model = OnnxT5(trained_model_path, model_sessions)
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
32
+
33
+
34
+ def get_question(sentence,answer,mdl,tknizer):
35
+ text = "context: {} answer: {}".format(sentence,answer)
36
+ print (text)
37
+ max_len = 256
38
+ encoding = tknizer.encode_plus(text,max_length=max_len, pad_to_max_length=False,truncation=True, return_tensors="pt")
39
+
40
+ input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
41
+
42
+ outs = mdl.generate(input_ids=input_ids,
43
+ attention_mask=attention_mask,
44
+ early_stopping=True,
45
+ num_beams=5,
46
+ num_return_sequences=1,
47
+ no_repeat_ngram_size=2,
48
+ max_length=128)
49
+
50
+
51
+ dec = [tknizer.decode(ids,skip_special_tokens=True) for ids in outs]
52
+
53
+
54
+ Question = dec[0].replace("question:","")
55
+ Question= Question.strip()
56
+ return Question
57
+
58
+
59
+
60
+ @app.get('/')
61
+ def index():
62
+ return {'message':'hello world'}
63
+
64
+ @app.post("/getquestion", response_model=QuestionResponse)
65
+ def getquestion(request: QuestionRequest):
66
+ context = request.context
67
+ answer = request.answer
68
+ ques = get_question(context,answer,model,tokenizer)
69
+ return QuestionResponse(question=ques)
70
+