File size: 3,482 Bytes
9c1be03
 
9b74ec6
06f0356
9b74ec6
409504b
7c6c308
9b74ec6
 
 
7c6c308
9b74ec6
 
 
 
a7d6d41
a3a9074
9b74ec6
bd2f642
7c6c308
bd2f642
9b74ec6
a3a9074
74c6866
a3a9074
7c6c308
 
 
 
bbd40ae
 
9b74ec6
bbd40ae
 
 
7c6c308
f57466f
7c6c308
bbd40ae
 
7c6c308
 
bbd40ae
 
 
9b74ec6
7c6c308
9b74ec6
d1597fa
 
9b74ec6
d1597fa
1ba0543
 
d1597fa
 
1ba0543
7c6c308
 
4113730
7c6c308
1ba0543
d1597fa
 
7c6c308
 
 
 
1ba0543
7c6c308
1ba0543
d1597fa
7c6c308
1ba0543
 
7c6c308
 
d1597fa
9b74ec6
7c6c308
9b74ec6
a3a9074
 
5464450
7c6c308
 
 
5464450
7c6c308
5464450
9b74ec6
 
 
a3a9074
7c6c308
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from typing import List
import numpy as np

app = FastAPI()

# --- Models (loaded once, when this module is imported) ----------------------

# Sentence-embedding model used by /modify_query, /modify_query_v3 and
# /answer_question.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Question-answering pipeline used by /answer_question; the same checkpoint
# name supplies both the model weights and the tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline(task='question-answering', model=question_model, tokenizer=question_model)

# Summarization pipeline backing /t5answer (note: a BART checkpoint, despite
# the endpoint's "t5" name).
summarizer = pipeline(task="summarization", model="facebook/bart-large-cnn")


class ModifyQueryRequest_v3(BaseModel):
    """Batch-embedding request body: a list of query strings.

    Mirrors the JSON shape ``{"query_string_list": [...]}`` read by the
    /modify_query_v3 endpoint.
    """

    # The query strings to embed.
    query_string_list: List[str]


class T5QuestionRequest(BaseModel):
    """Request body for /t5answer."""

    # Free text handed to the summarization pipeline.
    context: str

class T5Response(BaseModel):
    """Response body for /t5answer."""

    # The generated summary text.
    answer: str

# API endpoints
@app.post("/modify_query")
async def modify_query(request: Request):
    """Return the binary-precision embedding of a single query string.

    Reads ``query_string`` from the JSON body, encodes it (as a one-item
    batch) with ``model`` using ``precision="binary"``, and responds with
    ``{"embeddings": [...]}`` holding that single vector as a JSON list.

    Raises:
        HTTPException: status 500 whose detail is the underlying error
            message, for any failure (bad JSON, missing key, encoding error).
    """
    try:
        payload = await request.json()
        batch = [payload['query_string']]
        packed = model.encode(batch, precision="binary")
        only_vector = packed[0]
        return JSONResponse(content={'embeddings': only_vector.tolist()})
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@app.post("/modify_query_v3")
async def modify_query_v3(request: ModifyQueryRequest_v3):
    """Embed a batch of query strings with the sentence-transformer ``model``.

    The body is validated against ``ModifyQueryRequest_v3``, so it must be
    ``{"query_string_list": [str, ...]}``. Validating up front means a bare
    string can no longer slip through and be encoded as one sentence, which
    would yield a flat vector instead of a list of vectors. Malformed bodies
    are rejected by FastAPI with a 422 instead of surfacing as a 500.

    Returns:
        JSONResponse ``{"embeddings": [[...], ...]}`` with one vector per
        input string, in input order.

    Raises:
        HTTPException: status 500 if encoding fails.
    """
    try:
        embeddings = model.encode(request.query_string_list)
        return JSONResponse(content={'embeddings': [emb.tolist() for emb in embeddings]})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}") from e

@app.post("/answer_question")
async def answer_question(request: Request):
    """Answer a question from the most relevant passages of a caller-supplied corpus.

    The JSON body is read for:

    * ``context`` -- the passages to search. Presumably this is a list of
      strings (the code indexes it by the corpus ids that semantic search
      returns); confirm against callers.
    * ``question`` -- the question text.
    * ``locations`` -- indexed in parallel with ``context``. The entry for
      each selected passage is echoed back. It is only read when at least one
      passage qualifies.

    Passages and question are embedded with ``model``. Every semantic-search
    hit scoring above 0.4 contributes its passage to the context handed to
    the ``nlp`` question-answering pipeline.

    Returns:
        JSONResponse ``{"answer": str, "location": list}``. When no passage
        qualifies (including when ``context`` is empty), ``answer`` is a fixed
        apology message and ``location`` is ``[]``.

    Raises:
        HTTPException: status 500 wrapping any failure (malformed JSON,
            missing keys, model errors).
    """
    try:
        raw_data = await request.json()
        passages = raw_data['context']
        question = raw_data['question']
        res_locs = []
        selected = []

        # An empty corpus encodes to an empty tensor that util.semantic_search
        # cannot score against (it fails with a matrix-shape error). Skip
        # retrieval so we fall through to the "no results" answer instead of
        # a 500. len() rather than truthiness keeps a null/unsized ``context``
        # an error, as before.
        if len(passages) > 0:
            corpus_embeddings = model.encode(passages, convert_to_tensor=True)
            query_embeddings = model.encode(question, convert_to_tensor=True)
            hits = util.semantic_search(query_embeddings, corpus_embeddings)

            # hits[0] holds the hits for the single query. Keep those above
            # the similarity cut-off, in the order semantic_search returns them.
            for hit in hits[0]:
                if hit['score'] > 0.4:
                    loc = hit['corpus_id']
                    res_locs.append(raw_data['locations'][loc])
                    selected.append(passages[loc])

        if not res_locs:
            answer = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # Identical text to appending "<passage> " per hit, built in one
            # pass. Newlines are flattened before the QA pipeline sees it.
            context_string = ' '.join(selected) + ' '
            QA_input = {
                'question': question,
                'context': context_string.replace('\n', ' ')
            }
            result = nlp(QA_input)
            answer = result['answer']

        return JSONResponse(content={'answer': answer, "location": res_locs})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}") from e

@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize ``request.context`` and return the summary as ``answer``.

    Despite the endpoint and model names, this calls the module-level
    ``summarizer`` pipeline, which is a BART summarization model, not T5.
    Generation is deterministic (``do_sample=False``), with the summary
    length bounded by ``min_length=30`` / ``max_length=130``.

    Raises:
        HTTPException: status 500 wrapping any summarization failure.
    """
    try:
        # Without truncation=True, an input longer than the model's maximum
        # input length overflows its position embeddings and raises inside
        # the model. Truncating summarizes the leading part of the text
        # instead of failing the request.
        response = summarizer(
            request.context,
            max_length=130,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        return T5Response(answer=response[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in T5 summarization: {str(e)}") from e

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every interface
    # (0.0.0.0), port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)