File size: 2,419 Bytes
9b74ec6
 
1ba0543
9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ba0543
9b74ec6
 
 
 
 
 
 
1ba0543
9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
1ba0543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba5ed8e
1ba0543
 
 
 
9b74ec6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, utils
from transformers import pipeline

# The ASGI application; every route below is registered on it.
app = FastAPI()

# Both models are loaded once, at import time, and shared by all requests.
# Sentence embedder: encodes queries (/modify_query) and passages/questions
# for retrieval (/answer_question).
model = SentenceTransformer(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

# Question-answering pipeline; one checkpoint id supplies both model and tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline(
    task='question-answering',
    model=question_model,
    tokenizer=question_model,
)

# Request bodies (pydantic models validated by FastAPI before the handlers run).
class ModifyQueryRequest(BaseModel):
    """Request body for POST /modify_query."""

    query_string: str  # text to encode into a binary-precision embedding

class AnswerQuestionRequest(BaseModel):
    """Request body for POST /answer_question."""

    question: str  # the question to answer from the supplied passages
    # Untyped dict; the /answer_question handler reads two keys from it:
    #   'context'   - the passages to search (expected: a list of strings)
    #   'locations' - a list indexed with the same positions as 'context';
    #                 entries for the passages used are echoed back.
    # NOTE(review): neither key is validated here, so a missing key only
    # surfaces as an error inside the handler.
    context: dict

# Response bodies (used as each route's response_model).
class ModifyQueryResponse(BaseModel):
    """Response body for POST /modify_query."""

    embeddings: list  # the query's binary-precision embedding, via ndarray.tolist()

class AnswerQuestionResponse(BaseModel):
    """Response body for POST /answer_question."""

    answer: str  # extracted answer, or an apology message when nothing matched
    locations: list  # 'locations' entries of the passages used as QA context

# API endpoints.
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode ``request.query_string`` as a binary-precision sentence embedding.

    The embedder is called with ``precision="binary"`` on a one-item batch.
    The single resulting vector is returned as a plain Python list.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            encoding fails.
    """
    # run_in_threadpool ships with FastAPI (re-exported from Starlette).
    from fastapi.concurrency import run_in_threadpool

    try:
        # model.encode is synchronous, CPU-bound work. Calling it directly in
        # an ``async def`` handler would block the event loop and stall every
        # other in-flight request, so it runs in a worker thread instead.
        binary_embeddings = await run_in_threadpool(
            model.encode, [request.query_string], precision="binary"
        )
        return ModifyQueryResponse(embeddings=binary_embeddings[0].tolist())
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Answer returned when there are no passages or none clears the threshold.
_NO_RESULTS_ANSWER = "Sorry, I couldn't find any results for your query."


def _retrieve_and_answer(question, context, min_score=0.5):
    """Retrieve relevant passages from *context* and answer *question* from them.

    This runs synchronously because model inference is blocking. The async
    endpoint calls it through a worker thread.

    Args:
        question: The user's question.
        context: Dict with ``'context'`` (the passages, expected to be a list
            of strings) and ``'locations'`` (a list indexed with the same
            positions). ``'locations'`` is read only when a passage matches.
        min_score: A passage is used only if its ``util.semantic_search``
            score is strictly greater than this value.

    Returns:
        A tuple ``(answer, locations)``. On success it holds the QA model's
        answer and the ``'locations'`` entries of the passages it was drawn
        from. Otherwise it holds the no-results message and an empty list.
    """
    # Imported here because the module-level import names ``utils``, not
    # ``util``; this keeps the function working on its own.
    from sentence_transformers import util

    passages = context['context']
    # Nothing to search: answer directly instead of asking the models to
    # encode and score an empty corpus.
    if not passages:
        return _NO_RESULTS_ANSWER, []

    corpus_embeddings = model.encode(passages, convert_to_tensor=True)
    query_embedding = model.encode(question, convert_to_tensor=True)
    # semantic_search returns one ranked hit list per query. There is exactly
    # one query here, so its hits are element [0]. Iterating the outer list
    # directly would index a list with 'score'.
    hits = util.semantic_search(query_embedding, corpus_embeddings)[0]
    matched = [hit['corpus_id'] for hit in hits if hit['score'] > min_score]
    if not matched:
        return _NO_RESULTS_ANSWER, []

    locations = context['locations']
    # Concatenate the matched passages in ranked order, each followed by a
    # space. Then flatten newlines so the QA model sees a single line.
    context_string = ''.join(passages[i] + ' ' for i in matched).replace('\n', ' ')
    result = nlp({'question': question, 'context': context_string})
    return result['answer'], [locations[i] for i in matched]


@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer ``request.question`` using the relevant passages in ``request.context``.

    Passages are ranked against the question by embedding similarity. Those
    scoring above 0.5 are concatenated and passed to the QA pipeline.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            retrieval or answering fails (including missing context keys).
    """
    # run_in_threadpool ships with FastAPI (re-exported from Starlette).
    from fastapi.concurrency import run_in_threadpool

    try:
        # Encoding and QA inference are blocking; run them off the event loop
        # so concurrent requests are not stalled.
        ans, res_locs = await run_in_threadpool(
            _retrieve_and_answer, request.question, request.context
        )
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000.
    # NOTE(review): host 0.0.0.0 listens on all network interfaces.
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="0.0.0.0")