from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from typing import List
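
# The service exposes four endpoints: binary query embeddings (/modify_query),
# batched float embeddings (/modify_query_v3), retrieval-augmented extractive
# QA (/answer_question), and BART summarization (/t5answer).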
app = FastAPI()
# Load models
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
class ModifyQueryRequest_v3(BaseModel):
    query_string_list: List[str]

class T5QuestionRequest(BaseModel):
    context: str

class T5Response(BaseModel):
    answer: str
# API endpoints
@app.post("/modify_query")
async def modify_query(request: Request):
try:
raw_data = await request.json()
binary_embeddings = model.encode([raw_data['query_string']], precision="binary")
return JSONResponse(content={'embeddings':binary_embeddings[0].tolist()})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
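
# Illustrative request/response for /modify_query (payload shape inferred from
# the handler above; values are made up):
#   POST /modify_query  {"query_string": "how do I reset my password"}
#   -> {"embeddings": [12, -53, ...]}  # packed 8-bit values of the binary embedding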
@app.post("/modify_query_v3")
async def modify_query_v3(request: Request):
try:
# Generate embeddings for a list of query strings
raw_data = await request.json()
embeddings = model.encode(raw_data['query_string_list'])
return JSONResponse(content={'embeddings':[emb.tolist() for emb in embeddings]})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}")
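
# Illustrative request/response for /modify_query_v3 (values are made up):
#   POST /modify_query_v3  {"query_string_list": ["reset password", "billing help"]}
#   -> {"embeddings": [[0.01, ...], [0.02, ...]]}  # one 384-dim vector per string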
@app.post("/answer_question")
async def answer_question(request: Request):
try:
raw_data = await request.json()
res_locs = []
context_string = ''
corpus_embeddings = model.encode(raw_data['context'], convert_to_tensor=True)
query_embeddings = model.encode(raw_data['question'], convert_to_tensor=True)
hits = util.semantic_search(query_embeddings, corpus_embeddings)
# Collect relevant contexts
for hit in hits[0]:
if hit['score'] > 0.4:
loc = hit['corpus_id']
res_locs.append(raw_data['locations'][loc])
context_string += raw_data['context'][loc] + ' '
# If no relevant contexts are found
if not res_locs:
answer = "Sorry, I couldn't find any results for your query. Please try again!"
else:
# Use the question-answering pipeline
QA_input = {
'question': raw_data['question'],
'context': context_string.replace('\n', ' ')
}
result = nlp(QA_input)
answer = result['answer']
return JSONResponse(content={'answer':answer, "location":res_locs})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}")
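
# Illustrative request/response for /answer_question (payload shape inferred
# from the handler above; values are made up):
#   POST /answer_question  {"question": "When did the office open?",
#                           "context": ["The office opened in 2019.", "..."],
#                           "locations": ["doc1#p3", "doc2#p1"]}
#   -> {"answer": "2019", "location": ["doc1#p3"]}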
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
try:
# Summarize the context
response = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
return T5Response(answer=response[0]["summary_text"])
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error in T5 summarization: {str(e)}")
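
# Illustrative request/response for /t5answer (values are made up):
#   POST /t5answer  {"context": "<a few paragraphs of text>"}
#   -> {"answer": "<a 30-130 token summary>"}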
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)