Update app.py
Browse files
app.py
CHANGED
@@ -2,53 +2,62 @@ from fastapi import FastAPI, HTTPException
|
|
2 |
from pydantic import BaseModel
|
3 |
from sentence_transformers import SentenceTransformer, util
|
4 |
from transformers import pipeline
|
5 |
-
import numpy as np
|
6 |
from typing import List
|
7 |
-
|
8 |
|
9 |
app = FastAPI()
|
10 |
|
|
|
11 |
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
12 |
question_model = "deepset/tinyroberta-squad2"
|
13 |
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
|
14 |
|
15 |
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
|
16 |
|
17 |
-
#
|
18 |
class ModifyQueryRequest(BaseModel):
|
19 |
query_string: str
|
20 |
|
21 |
-
# Define request models
|
22 |
class ModifyQueryRequest_v3(BaseModel):
|
23 |
-
query_string_list: [str]
|
24 |
|
25 |
class AnswerQuestionRequest(BaseModel):
|
26 |
question: str
|
27 |
-
context:
|
28 |
-
locations:
|
29 |
|
30 |
class T5QuestionRequest(BaseModel):
|
31 |
context: str
|
32 |
|
33 |
-
|
34 |
-
answer: str
|
35 |
-
|
36 |
-
# Define response models (if needed)
|
37 |
class ModifyQueryResponse(BaseModel):
|
38 |
-
embeddings:
|
39 |
|
40 |
class AnswerQuestionResponse(BaseModel):
|
41 |
answer: str
|
42 |
-
locations:
|
43 |
|
44 |
-
|
|
|
|
|
|
|
45 |
@app.post("/modify_query", response_model=ModifyQueryResponse)
|
46 |
async def modify_query(request: ModifyQueryRequest):
|
47 |
try:
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
except Exception as e:
|
51 |
-
raise HTTPException(status_code=500, detail=str(e))
|
52 |
|
53 |
@app.post("/answer_question", response_model=AnswerQuestionResponse)
|
54 |
async def answer_question(request: AnswerQuestionRequest):
|
@@ -58,42 +67,41 @@ async def answer_question(request: AnswerQuestionRequest):
|
|
58 |
corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
|
59 |
query_embeddings = model.encode(request.question, convert_to_tensor=True)
|
60 |
hits = util.semantic_search(query_embeddings, corpus_embeddings)
|
|
|
|
|
61 |
for hit in hits[0]:
|
62 |
-
if hit['score'] > .4:
|
63 |
loc = hit['corpus_id']
|
64 |
res_locs.append(request.locations[loc])
|
65 |
context_string += request.context[loc] + ' '
|
66 |
-
|
67 |
-
|
|
|
|
|
68 |
else:
|
|
|
69 |
QA_input = {
|
70 |
'question': request.question,
|
71 |
-
'context': context_string.replace('\n',' ')
|
72 |
}
|
73 |
result = nlp(QA_input)
|
74 |
-
|
75 |
-
|
|
|
76 |
except Exception as e:
|
77 |
-
raise HTTPException(status_code=500, detail=str(e))
|
78 |
|
79 |
@app.post("/t5answer", response_model=T5Response)
|
80 |
async def t5answer(request: T5QuestionRequest):
|
81 |
-
resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
|
82 |
-
return T5Response(answer = resp[0]["summary_text"])
|
83 |
-
|
84 |
-
|
85 |
-
# Define API endpoints
|
86 |
-
@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
|
87 |
-
async def modify_query2(request: ModifyQueryRequest_v3):
|
88 |
try:
|
89 |
-
|
90 |
-
|
|
|
91 |
except Exception as e:
|
92 |
-
raise HTTPException(status_code=500, detail=str(e))
|
93 |
-
|
94 |
-
|
95 |
|
96 |
if __name__ == "__main__":
|
97 |
import uvicorn
|
98 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
99 |
|
|
|
|
2 |
from pydantic import BaseModel
|
3 |
from sentence_transformers import SentenceTransformer, util
|
4 |
from transformers import pipeline
|
|
|
5 |
from typing import List
|
6 |
+
import numpy as np
|
7 |
|
8 |
app = FastAPI()

# Load models once at import time so every request reuses them.
# Sentence embedder backing /modify_query, /modify_query_v3 and the
# semantic search inside /answer_question.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive-QA checkpoint name; used for both the model and its tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Abstractive summarizer backing the /t5answer endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
|
16 |
|
17 |
+
# Request models
class ModifyQueryRequest(BaseModel):
    # Single query string to embed via /modify_query.
    query_string: str
|
20 |
|
|
|
21 |
class ModifyQueryRequest_v3(BaseModel):
    # Batch of query strings to embed via /modify_query_v3.
    query_string_list: List[str]
|
23 |
|
24 |
class AnswerQuestionRequest(BaseModel):
    # Natural-language question to answer.
    question: str
    # Candidate passages searched for the answer.
    context: List[str]
    # Location labels for the passages; /answer_question indexes both lists
    # with the same corpus_id, so they are presumably parallel (same length
    # and order) -- TODO confirm with callers.
    locations: List[str]
|
28 |
|
29 |
class T5QuestionRequest(BaseModel):
    # Text to be summarized by /t5answer.
    context: str
|
31 |
|
32 |
+
# Response models
class ModifyQueryResponse(BaseModel):
    # One embedding vector (list of floats) per input string.
    embeddings: List[List[float]]
|
35 |
|
36 |
class AnswerQuestionResponse(BaseModel):
    # Extracted answer text, or an apology message when no passage matched.
    answer: str
    # Locations of the passages that scored above the search threshold.
    locations: List[str]
|
39 |
|
40 |
+
class T5Response(BaseModel):
    # Summary text produced by the summarization pipeline.
    answer: str
|
42 |
+
|
43 |
+
# API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Embed a single query string and return its embedding vector.

    Wraps any failure in a 500 HTTPException carrying the error text.
    """
    try:
        # NOTE(review): model.encode is blocking CPU work inside an async
        # handler; it will stall the event loop under load -- confirm whether
        # a sync handler / threadpool is preferable before changing.
        encoded = model.encode([request.query_string])
        vectors = []
        for row in encoded:
            vectors.append(row.tolist())
        return ModifyQueryResponse(embeddings=vectors)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query: {str(e)}")
|
52 |
+
|
53 |
+
@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
async def modify_query_v3(request: ModifyQueryRequest_v3):
    """Embed a batch of query strings; returns one vector per input string.

    Wraps any failure in a 500 HTTPException carrying the error text.
    """
    try:
        batch = request.query_string_list
        # Encode the whole batch in a single call and convert each row
        # from ndarray to a plain list for the JSON response.
        as_lists = [vec.tolist() for vec in model.encode(batch)]
        return ModifyQueryResponse(embeddings=as_lists)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}")
|
61 |
|
62 |
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Semantic-search the supplied passages, then extract an answer.

    Passages scoring above 0.4 cosine similarity to the question are
    concatenated and fed to the extractive-QA pipeline; their locations
    are echoed back. With no match, a fixed apology string is returned
    with an empty location list.
    """
    try:
        matched_locations = []
        matched_passages = []
        # Embed the corpus and the query, then rank passages by similarity.
        corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        hits = util.semantic_search(query_embeddings, corpus_embeddings)

        # Keep every passage clearing the relevance threshold.
        for candidate in hits[0]:
            if candidate['score'] > 0.4:
                idx = candidate['corpus_id']
                matched_locations.append(request.locations[idx])
                matched_passages.append(request.context[idx] + ' ')

        if not matched_locations:
            # Nothing cleared the threshold -- report that instead of running QA.
            answer = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # Run extractive QA over the concatenated relevant passages.
            relevant_text = ''.join(matched_passages)
            result = nlp({
                'question': request.question,
                'context': relevant_text.replace('\n', ' ')
            })
            answer = result['answer']

        return AnswerQuestionResponse(answer=answer, locations=matched_locations)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}")
|
93 |
|
94 |
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the request context with the summarization pipeline.

    Wraps any failure in a 500 HTTPException carrying the error text.
    """
    try:
        # Deterministic (no sampling) summary bounded to 30-130 tokens.
        summaries = summarizer(
            request.context, max_length=130, min_length=30, do_sample=False
        )
        first = summaries[0]
        return T5Response(answer=first["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in T5 summarization: {str(e)}")
|
|
|
|
|
102 |
|
103 |
if __name__ == "__main__":
    # Development entry point: run an in-process uvicorn server when the
    # module is executed directly (not when served by an external ASGI runner).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
106 |
|
107 |
+
|