Tonyivan commited on
Commit
a3a9074
·
verified ·
1 Parent(s): 18416fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -2,6 +2,8 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
 
 
5
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
@@ -11,6 +13,9 @@ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
11
  question_model = "deepset/tinyroberta-squad2"
12
  nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
13
 
 
 
 
14
  # Define request models
15
  class ModifyQueryRequest(BaseModel):
16
  query_string: str
@@ -20,6 +25,13 @@ class AnswerQuestionRequest(BaseModel):
20
  context: list
21
  locations: list
22
 
 
 
 
 
 
 
 
23
  # Define response models (if needed)
24
  class ModifyQueryResponse(BaseModel):
25
  embeddings: list
@@ -63,6 +75,15 @@ async def answer_question(request: AnswerQuestionRequest):
63
  except Exception as e:
64
  raise HTTPException(status_code=500, detail=str(e))
65
 
 
 
 
 
 
 
 
 
66
  if __name__ == "__main__":
67
  import uvicorn
68
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
5
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
6
+
7
 
8
  # Initialize FastAPI app
9
  app = FastAPI()
 
13
  question_model = "deepset/tinyroberta-squad2"
14
  nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
15
 
16
+ t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
17
+ t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
18
+
19
  # Define request models
20
  class ModifyQueryRequest(BaseModel):
21
  query_string: str
 
25
  context: list
26
  locations: list
27
 
28
+ class T5QuestionRequest(BaseModel):
29
+ question: str
30
+ context: list
31
+
32
+ class T5Response(BaseModel):
33
+ answer: str
34
+
35
  # Define response models (if needed)
36
  class ModifyQueryResponse(BaseModel):
37
  embeddings: list
 
75
  except Exception as e:
76
  raise HTTPException(status_code=500, detail=str(e))
77
 
78
+ @app.post("/t5answer", response_model=T5Response)
79
+ async def t5answer(request: T5QuestionRequest):
80
+ input_text = request.question + ":" + request.context
81
+ input_ids = t5tokenizer(input_text, return_tensors="pt").input_ids
82
+ outputs = t5model.generate(input_ids)
83
+ resp = t5tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+ return T5Response(answer = resp)
85
+
86
  if __name__ == "__main__":
87
  import uvicorn
88
  uvicorn.run(app, host="0.0.0.0", port=8000)
89
+