from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from typing import List

app = FastAPI()

# Load models: MiniLM for sentence embeddings, tinyroberta-squad2 for
# extractive question answering, and BART for summarization
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline("question-answering", model=question_model, tokenizer=question_model)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Request models
class ModifyQueryRequest(BaseModel):
    query_string: str

class ModifyQueryRequest_v3(BaseModel):
    query_string_list: List[str]

class AnswerQuestionRequest(BaseModel):
    question: str
    context: List[str]
    locations: List[str]

class T5QuestionRequest(BaseModel):
    context: str

# Response models
class ModifyQueryResponse(BaseModel):
    embeddings: List[List[float]]

class AnswerQuestionResponse(BaseModel):
    answer: str
    locations: List[str]

class T5Response(BaseModel):
    answer: str

# API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    try:
        # Generate embeddings
        embeddings = model.encode([request.query_string])
        return ModifyQueryResponse(embeddings=[emb.tolist() for emb in embeddings])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query: {str(e)}")
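
# A minimal usage sketch (assumes the server is running on the host/port
# configured at the bottom of this file; the query text is illustrative):
#
#   curl -X POST http://localhost:8000/modify_query \
#        -H "Content-Type: application/json" \
#        -d '{"query_string": "how do I reset my password?"}'
#
# all-MiniLM-L6-v2 produces 384-dimensional vectors, so the response has
# the shape {"embeddings": [[<384 floats>]]}.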
@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
async def modify_query_v3(request: ModifyQueryRequest_v3):
try:
# Generate embeddings for a list of query strings
embeddings = model.encode(request.query_string_list)
return ModifyQueryResponse(embeddings=[emb.tolist() for emb in embeddings])
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}")
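
# Usage sketch for the batched variant (payload is illustrative); one
# embedding is returned per input string, in the same order:
#
#   curl -X POST http://localhost:8000/modify_query_v3 \
#        -H "Content-Type: application/json" \
#        -d '{"query_string_list": ["first query", "second query"]}'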
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
try:
res_locs = []
context_string = ''
corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
query_embeddings = model.encode(request.question, convert_to_tensor=True)
hits = util.semantic_search(query_embeddings, corpus_embeddings)
# Collect relevant contexts
for hit in hits[0]:
if hit['score'] > 0.4:
loc = hit['corpus_id']
res_locs.append(request.locations[loc])
context_string += request.context[loc] + ' '
# If no relevant contexts are found
if not res_locs:
answer = "Sorry, I couldn't find any results for your query. Please try again!"
else:
# Use the question-answering pipeline
QA_input = {
'question': request.question,
'context': context_string.replace('\n', ' ')
}
result = nlp(QA_input)
answer = result['answer']
return AnswerQuestionResponse(answer=answer, locations=res_locs)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}")
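
# Usage sketch (payload is illustrative; `context` and `locations` must be
# parallel lists of the same length):
#
#   curl -X POST http://localhost:8000/answer_question \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Where is the API documented?",
#             "context": ["The API is documented in the wiki.", "Unrelated note."],
#             "locations": ["wiki/api.md", "notes.txt"]}'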
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
try:
# Summarize the context
response = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
return T5Response(answer=response[0]["summary_text"])
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in T5 summarization: {str(e)}")
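
# Usage sketch (the input should be long enough to summarize, given
# min_length=30; the text below is a stand-in):
#
#   curl -X POST http://localhost:8000/t5answer \
#        -H "Content-Type: application/json" \
#        -d '{"context": "<several paragraphs of text to summarize>"}'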

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)