File size: 3,798 Bytes
9b74ec6
 
06f0356
9b74ec6
409504b
7c6c308
9b74ec6
 
 
7c6c308
9b74ec6
 
 
 
a7d6d41
a3a9074
7c6c308
9b74ec6
 
 
bd2f642
7c6c308
bd2f642
9b74ec6
 
7c6c308
 
9b74ec6
a3a9074
74c6866
a3a9074
7c6c308
9b74ec6
7c6c308
9b74ec6
 
 
7c6c308
9b74ec6
7c6c308
 
 
 
9b74ec6
 
 
7c6c308
 
 
 
 
 
 
 
 
 
 
 
9b74ec6
7c6c308
9b74ec6
 
 
 
1ba0543
 
65a2535
1ba0543
 
7c6c308
 
4113730
7c6c308
1ba0543
65a2535
 
7c6c308
 
 
 
1ba0543
7c6c308
1ba0543
 
7c6c308
1ba0543
 
7c6c308
 
 
9b74ec6
7c6c308
9b74ec6
a3a9074
 
5464450
7c6c308
 
 
5464450
7c6c308
5464450
9b74ec6
 
 
a3a9074
7c6c308
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from typing import List
import numpy as np

app = FastAPI()

# Load models
# All models below are created once, at module import time, and shared by every
# request handled by this process.
# NOTE(review): because loading is a module-level side effect, merely importing
# this module loads (and, if not cached, presumably downloads) the weights.

# Sentence-embedding model: used by /modify_query, /modify_query_v3, and for
# ranking context passages in /answer_question.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline used by /answer_question; the same
# checkpoint name supplies both the model and its tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Summarization pipeline used by /t5answer.
# NOTE(review): this is a BART checkpoint, not T5, despite the endpoint's name.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Request models
# Request body of POST /modify_query: a single string to embed.
class ModifyQueryRequest(BaseModel):
    query_string: str  # encoded by `model` as a one-item batch

# Request body of POST /modify_query_v3: a batch of strings to embed.
class ModifyQueryRequest_v3(BaseModel):
    query_string_list: List[str]  # one embedding is returned per entry

# Request body of POST /answer_question.
class AnswerQuestionRequest(BaseModel):
    question: str  # the question to answer
    context: List[str]  # candidate passages, ranked by embedding similarity to `question`
    # locations[i] is reported as the location of context[i], so the two lists
    # are expected to be parallel; this model itself does not enforce their lengths.
    locations: List[str]

# Request body of POST /t5answer: the text to summarize.
class T5QuestionRequest(BaseModel):
    context: str  # passed unchanged to the `summarizer` pipeline

# Response models
# Response body shared by /modify_query and /modify_query_v3.
class ModifyQueryResponse(BaseModel):
    embeddings: List[List[float]]  # one embedding vector per input string, in input order

# Response body of POST /answer_question.
class AnswerQuestionResponse(BaseModel):
    answer: str  # extracted answer, or a fixed apology message when no passage qualifies
    locations: List[str]  # locations of the passages used as QA context

# Response body of POST /t5answer.
class T5Response(BaseModel):
    answer: str  # the generated summary text

# API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Embed a single query string with the shared sentence-embedding ``model``.

    Returns:
        ModifyQueryResponse whose ``embeddings`` list holds one vector, because
        the query is encoded as a one-item batch.

    Raises:
        HTTPException: status 500 if encoding fails. The original error is
            chained as ``__cause__`` so server-side tracebacks keep it.
    """
    try:
        # Encode as a one-element batch so iteration yields one vector per input.
        embeddings = model.encode([request.query_string])
        return ModifyQueryResponse(embeddings=[emb.tolist() for emb in embeddings])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query: {e}") from e

@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
async def modify_query_v3(request: ModifyQueryRequest_v3):
    """Embed a batch of query strings with the shared sentence-embedding ``model``.

    Returns:
        ModifyQueryResponse with one embedding vector per entry of
        ``request.query_string_list``, in the same order. An empty input list
        yields an empty ``embeddings`` list.

    Raises:
        HTTPException: status 500 if encoding fails. The original error is
            chained as ``__cause__`` so server-side tracebacks keep it.
    """
    # An empty batch has no embeddings. Answer directly instead of depending on
    # how model.encode treats an empty input list.
    if not request.query_string_list:
        return ModifyQueryResponse(embeddings=[])

    try:
        embeddings = model.encode(request.query_string_list)
        return ModifyQueryResponse(embeddings=[emb.tolist() for emb in embeddings])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {e}") from e

@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer ``request.question`` from the most relevant ``request.context`` passages.

    Every context passage and the question are embedded with ``model``, and the
    passages are ranked against the question with ``util.semantic_search``.
    Passages scoring above 0.4 are joined, in the order semantic_search returns
    them, and handed to the extractive QA pipeline ``nlp``. The matching
    ``request.locations`` entries are returned alongside the answer. If no
    passage qualifies, or no context is supplied, a fixed apology message is
    returned with an empty location list.

    Raises:
        HTTPException: status 422 if ``request.locations`` has fewer entries
            than ``request.context``, so that some passage has no location.
        HTTPException: status 500 if embedding, search, or QA inference fails.
            The original error is chained as ``__cause__``.
    """
    no_result = "Sorry, I couldn't find any results for your query. Please try again!"

    # locations[i] labels context[i]. A shorter locations list used to surface
    # as an IndexError reported as a 500. Reject it as a client error up front,
    # outside the try, so the generic handler below does not rewrap it as a 500.
    if len(request.locations) < len(request.context):
        raise HTTPException(
            status_code=422,
            detail="Each context passage needs a matching entry in 'locations'.",
        )

    # Nothing to search: skip encoding an empty corpus and answer like the
    # "no relevant passage" case.
    if not request.context:
        return AnswerQuestionResponse(answer=no_result, locations=[])

    try:
        corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        # A single query is searched, so only the first hit list is relevant.
        hits = util.semantic_search(query_embeddings, corpus_embeddings)[0]

        # Relevance cut-off on the similarity score reported by semantic_search.
        relevant_ids = [hit['corpus_id'] for hit in hits if hit['score'] > 0.4]
        res_locs = [request.locations[i] for i in relevant_ids]

        if not relevant_ids:
            answer = no_result
        else:
            context_string = ' '.join(request.context[i] for i in relevant_ids)
            result = nlp({
                'question': request.question,
                # Newlines are flattened to spaces before QA, as before.
                'context': context_string.replace('\n', ' '),
            })
            answer = result['answer']

        return AnswerQuestionResponse(answer=answer, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in answering question: {e}") from e

@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize ``request.context`` with the shared ``summarizer`` pipeline.

    Generation is deterministic (``do_sample=False``), with summary length
    bounded by ``min_length=30`` and ``max_length=130``. Input longer than the
    model accepts is truncated rather than rejected.

    NOTE(review): despite the route and response names, ``summarizer`` is
    created from the "facebook/bart-large-cnn" checkpoint, not a T5 model.

    Raises:
        HTTPException: status 500 if summarization fails. The original error is
            chained as ``__cause__``.
    """
    try:
        response = summarizer(
            request.context,
            max_length=130,
            min_length=30,
            do_sample=False,
            # Without truncation, inputs longer than the model's maximum input
            # length fail inside the model and surface as a 500.
            truncation=True,
        )
        return T5Response(answer=response[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in T5 summarization: {e}") from e

if __name__ == "__main__":
    # Only runs when this file is executed as a script: serve `app` with
    # uvicorn, listening on all interfaces at port 8000.
    from uvicorn import run

    run(app, port=8000, host="0.0.0.0")