Tonyivan commited on
Commit
1ba0543
·
verified ·
1 Parent(s): 3266774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
@@ -17,7 +17,7 @@ class ModifyQueryRequest(BaseModel):
17
 
18
  class AnswerQuestionRequest(BaseModel):
19
  question: str
20
- context: str
21
 
22
  # Define response models (if needed)
23
  class ModifyQueryResponse(BaseModel):
@@ -25,6 +25,7 @@ class ModifyQueryResponse(BaseModel):
25
 
26
  class AnswerQuestionResponse(BaseModel):
27
  answer: str
 
28
 
29
  # Define API endpoints
30
  @app.post("/modify_query", response_model=ModifyQueryResponse)
@@ -38,12 +39,26 @@ async def modify_query(request: ModifyQueryRequest):
38
  @app.post("/answer_question", response_model=AnswerQuestionResponse)
39
  async def answer_question(request: AnswerQuestionRequest):
40
  try:
41
- QA_input = {
42
- 'question': request.question,
43
- 'context': request.context
44
- }
45
- result = nlp(QA_input)
46
- return AnswerQuestionResponse(answer=result['answer'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  except Exception as e:
48
  raise HTTPException(status_code=500, detail=str(e))
49
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer, utils
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
 
17
 
18
  class AnswerQuestionRequest(BaseModel):
19
  question: str
20
+ context: dict
21
 
22
  # Define response models (if needed)
23
  class ModifyQueryResponse(BaseModel):
 
25
 
26
  class AnswerQuestionResponse(BaseModel):
27
  answer: str
28
+ locations: list
29
 
30
  # Define API endpoints
31
  @app.post("/modify_query", response_model=ModifyQueryResponse)
 
39
  @app.post("/answer_question", response_model=AnswerQuestionResponse)
40
  async def answer_question(request: AnswerQuestionRequest):
41
  try:
42
+ res_locs = []
43
+ context_string = ''
44
+ corpus_embeddings = model.encode(request.context['context'], convert_to_tensor=True)
45
+ query_embeddings = model.encode(request.question, convert_to_tensor=True)
46
+ hits = util.semantic_search(query_embeddings, corpus_embeddings)
47
+ for hit in hits:
48
+ if hit['score'] > .5:
49
+ loc = hit['corpus_id']
50
+ res_locs.append(request.context['locations'][loc])
51
+ context_string += request.context['context'][loc] + ' '
52
+ if len(res_locs) == 0:
53
+ ans = "Sorry, I couldn't find any results for your query."
54
+ else:
55
+ QA_input = {
56
+ 'question': request.question,
57
+ 'context': context_string.replace.('\n',' ')
58
+ }
59
+ result = nlp(QA_input)
60
+ ans = result['answer']
61
+ return AnswerQuestionResponse(answer=ans, locations = res_locs)
62
  except Exception as e:
63
  raise HTTPException(status_code=500, detail=str(e))
64