"""FastAPI service exposing query-embedding and extractive QA endpoints.

Endpoints:

- ``POST /modify_query``: encode a query string into a binary-quantized
  embedding with a SentenceTransformer model.
- ``POST /answer_question``: rank the supplied context passages against a
  question, keep those whose similarity score exceeds ``SIMILARITY_THRESHOLD``,
  and run an extractive question-answering model over the retained passages.
"""

import logging
from typing import List, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Load models once at import time so every request reuses them.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline("question-answering", model=question_model, tokenizer=question_model)

# A passage is forwarded to the QA model only if its semantic_search score
# is strictly greater than this value.
SIMILARITY_THRESHOLD = 0.5

# Answer returned when no passage clears SIMILARITY_THRESHOLD.
NO_RESULTS_MESSAGE = "Sorry, I couldn't find any results for your query."


# Define request models
class ModifyQueryRequest(BaseModel):
    query_string: str


class AnswerQuestionRequest(BaseModel):
    question: str
    # Candidate passages; context[i] is paired with locations[i].
    context: List[str]
    # Opaque per-passage identifiers echoed back for the passages that matched.
    locations: list


# Define response models
class ModifyQueryResponse(BaseModel):
    embeddings: list


class AnswerQuestionResponse(BaseModel):
    answer: str
    locations: list


def _retrieve(question: str, passages: List[str], locations: list) -> Tuple[list, List[str]]:
    """Return ``(locations, passages)`` for passages scoring above the threshold.

    Results follow the order produced by ``util.semantic_search`` (highest
    score first). An empty ``passages`` list yields two empty lists without
    touching the encoder.
    """
    if not passages:
        return [], []
    corpus_embeddings = model.encode(passages, convert_to_tensor=True)
    query_embedding = model.encode(question, convert_to_tensor=True)
    # semantic_search returns one hit list per query; we issue exactly one
    # query, so its ranked hits are at index 0.
    hits = util.semantic_search(query_embedding, corpus_embeddings)[0]
    kept = [hit["corpus_id"] for hit in hits if hit["score"] > SIMILARITY_THRESHOLD]
    return [locations[i] for i in kept], [passages[i] for i in kept]


# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
def modify_query(request: ModifyQueryRequest):
    """Return the binary-quantized embedding of ``request.query_string``.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs the
    blocking model call in its threadpool instead of stalling the event loop.

    Raises:
        HTTPException: status 500 if encoding fails.
    """
    try:
        binary_embeddings = model.encode([request.query_string], precision="binary")
        return ModifyQueryResponse(embeddings=binary_embeddings[0].tolist())
    except Exception as e:
        logger.exception("Failed to encode query")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/answer_question", response_model=AnswerQuestionResponse)
def answer_question(request: AnswerQuestionRequest):
    """Answer ``request.question`` from the most relevant ``request.context`` passages.

    Passages scoring above ``SIMILARITY_THRESHOLD`` are joined with spaces
    (newlines flattened) and passed to the extractive QA pipeline. The
    response carries the extracted answer and the locations of the passages
    used. If no passage qualifies, ``NO_RESULTS_MESSAGE`` is returned with an
    empty location list.

    Declared with plain ``def`` so FastAPI runs the blocking model calls in
    its threadpool.

    Raises:
        HTTPException: status 422 if ``locations`` has fewer entries than
            ``context``; status 500 if retrieval or QA inference fails.
    """
    # Validate before the try block so this 4xx is not rewrapped as a 500.
    if len(request.locations) < len(request.context):
        raise HTTPException(
            status_code=422,
            detail="`locations` must have an entry for every `context` passage.",
        )
    try:
        res_locs, passages = _retrieve(request.question, request.context, request.locations)
        if not res_locs:
            return AnswerQuestionResponse(answer=NO_RESULTS_MESSAGE, locations=[])
        context_string = " ".join(passages).replace("\n", " ")
        result = nlp({"question": request.question, "context": context_string})
        return AnswerQuestionResponse(answer=result["answer"], locations=res_locs)
    except Exception as e:
        logger.exception("Failed to answer question")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)