Tonyivan committed
Commit a7d6d41 · verified · 1 Parent(s): 74c6866

Update app.py

Files changed (1)
  1. app.py +6 -9
app.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer, util
 from transformers import pipeline
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+#from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 
 # Initialize FastAPI app
@@ -13,8 +13,9 @@ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 question_model = "deepset/tinyroberta-squad2"
 nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
 
-t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
-t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
+#t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
+#t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
+summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 # Define request models
 class ModifyQueryRequest(BaseModel):
@@ -26,7 +27,6 @@ class AnswerQuestionRequest(BaseModel):
     locations: list
 
 class T5QuestionRequest(BaseModel):
-    question: str
     context: str
 
 class T5Response(BaseModel):
@@ -77,11 +77,8 @@ async def answer_question(request: AnswerQuestionRequest):
 
 @app.post("/t5answer", response_model=T5Response)
 async def t5answer(request: T5QuestionRequest):
-    input_text = request.question + ":" + request.context
-    input_ids = t5tokenizer(input_text, return_tensors="pt").input_ids
-    outputs = t5model.generate(input_ids)
-    resp = t5tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return T5Response(answer = resp)
+    resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
+    return T5Response(answer = resp[0]["summary_text"])
 
 if __name__ == "__main__":
     import uvicorn
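
With this change, /t5answer summarizes the supplied context with the BART pipeline instead of generating an answer with flan-t5, and the request body no longer needs a question field. Below is a minimal client-side sketch of how the updated endpoint could be exercised; the localhost:8000 address and the sample context string are assumptions for illustration (uvicorn's default port), not part of the commit.

import requests

# Hypothetical smoke test for the updated /t5answer endpoint (not part of this commit).
# Assumes the app is running locally, e.g. started via the if __name__ == "__main__" block.
payload = {"context": "FastAPI is a modern web framework for building APIs with Python."}

resp = requests.post("http://localhost:8000/t5answer", json=payload)
resp.raise_for_status()

# The server now fills T5Response.answer with the first "summary_text" entry
# returned by the summarization pipeline.
print(resp.json()["answer"])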