# File size: 1,659 Bytes
# Revision: 9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# ASGI application object. The @app.post decorators further down register
# the endpoints on it, and the __main__ guard passes it to uvicorn.
app = FastAPI()

# Both models are built once, at import time, and shared by every request:
#   model - sentence encoder used by /modify_query
#   nlp   - question-answering pipeline used by /answer_question
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline(
    task="question-answering",
    model=question_model,
    tokenizer=question_model,  # same checkpoint supplies the tokenizer
)

# Define request models
class ModifyQueryRequest(BaseModel):
    # Request body for POST /modify_query.
    # query_string: the text that the endpoint passes to model.encode().
    query_string: str

class AnswerQuestionRequest(BaseModel):
    # Request body for POST /answer_question.
    # question: the question handed to the question-answering pipeline.
    question: str
    # context: the passage handed to the pipeline alongside the question.
    context: str

# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    # Response body for POST /modify_query.
    # embeddings: the first encoded row converted with .tolist(). The bare
    # `list` annotation does not constrain the element type.
    # NOTE(review): presumably integers, since the endpoint encodes with
    # precision="binary" - confirm before tightening to a typed list.
    embeddings: list

class AnswerQuestionResponse(BaseModel):
    # Response body for POST /answer_question.
    # answer: the 'answer' entry of the question-answering pipeline result.
    answer: str

# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest) -> ModifyQueryResponse:
    """Encode ``request.query_string`` and return its binary-precision embedding.

    The string is encoded as a one-element batch with ``precision="binary"``;
    the first (only) row is converted to a plain list for the response.

    Args:
        request: Body carrying the query string to encode.

    Returns:
        ModifyQueryResponse whose ``embeddings`` is the encoded row as a list.

    Raises:
        HTTPException: Status 500, with the underlying error message as
            ``detail``, if encoding or response construction fails.
    """
    # Local import so this block needs no change to the file's import section.
    from fastapi.concurrency import run_in_threadpool

    try:
        # model.encode() is blocking, CPU-bound work. Calling it directly in an
        # ``async def`` handler would stall the event loop (and therefore every
        # other request) for the whole inference, so run it on a worker thread.
        # NOTE(review): calls may now overlap across worker threads - confirm
        # the encoder is safe for concurrent use, or serialize with a lock.
        encoded_batch = await run_in_threadpool(
            model.encode, [request.query_string], precision="binary"
        )
        return ModifyQueryResponse(embeddings=encoded_batch[0].tolist())
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest) -> AnswerQuestionResponse:
    """Answer ``request.question`` using ``request.context``.

    The question and context are passed to the question-answering pipeline and
    the ``'answer'`` entry of its result is returned.

    Args:
        request: Body carrying the question and the context passage.

    Returns:
        AnswerQuestionResponse holding the pipeline's ``'answer'`` value.

    Raises:
        HTTPException: Status 422 if ``question`` or ``context`` is empty;
            status 500, with the underlying error message as ``detail``, if
            the pipeline call or response construction fails.
    """
    # Local import so this block needs no change to the file's import section.
    from fastapi.concurrency import run_in_threadpool

    # An empty field is a client error. Check it before the try block so the
    # blanket ``except Exception`` below cannot rewrap it as a 500.
    if not request.question:
        raise HTTPException(status_code=422, detail="`question` cannot be empty")
    if not request.context:
        raise HTTPException(status_code=422, detail="`context` cannot be empty")

    qa_input = {
        'question': request.question,
        'context': request.context,
    }
    try:
        # The pipeline call is blocking, CPU-bound work. Run it on a worker
        # thread so the event loop stays free for other requests.
        # NOTE(review): calls may now overlap across worker threads - confirm
        # the pipeline is safe for concurrent use, or serialize with a lock.
        result = await run_in_threadpool(nlp, qa_input)
        return AnswerQuestionResponse(answer=result['answer'])
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Script entry point: only runs when the file is executed directly, not when
# the module is imported (e.g. by an external ASGI server).
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn
    # Serve the app on port 8000; host "0.0.0.0" listens on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8000)