Tonyivan commited on
Commit
d1597fa
·
verified ·
1 Parent(s): 9c1be03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -19,19 +19,10 @@ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
19
  class ModifyQueryRequest_v3(BaseModel):
20
  query_string_list: List[str]
21
 
22
- class AnswerQuestionRequest(BaseModel):
23
- question: str
24
- context: List[str]
25
- locations: List[str]
26
 
27
  class T5QuestionRequest(BaseModel):
28
  context: str
29
 
30
-
31
- class AnswerQuestionResponse(BaseModel):
32
- answer: str
33
- locations: List[str]
34
-
35
  class T5Response(BaseModel):
36
  answer: str
37
 
@@ -55,21 +46,22 @@ async def modify_query_v3(request: Request):
55
  except Exception as e:
56
  raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}")
57
 
58
- @app.post("/answer_question", response_model=AnswerQuestionResponse)
59
- async def answer_question(request: AnswerQuestionRequest):
60
  try:
 
61
  res_locs = []
62
  context_string = ''
63
- corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
64
- query_embeddings = model.encode(request.question, convert_to_tensor=True)
65
  hits = util.semantic_search(query_embeddings, corpus_embeddings)
66
 
67
  # Collect relevant contexts
68
  for hit in hits[0]:
69
  if hit['score'] > 0.4:
70
  loc = hit['corpus_id']
71
- res_locs.append(request.locations[loc])
72
- context_string += request.context[loc] + ' '
73
 
74
  # If no relevant contexts are found
75
  if not res_locs:
@@ -77,13 +69,13 @@ async def answer_question(request: AnswerQuestionRequest):
77
  else:
78
  # Use the question-answering pipeline
79
  QA_input = {
80
- 'question': request.question,
81
  'context': context_string.replace('\n', ' ')
82
  }
83
  result = nlp(QA_input)
84
  answer = result['answer']
85
 
86
- return AnswerQuestionResponse(answer=answer, locations=res_locs)
87
  except Exception as e:
88
  raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}")
89
 
 
19
  class ModifyQueryRequest_v3(BaseModel):
20
  query_string_list: List[str]
21
 
 
 
 
 
22
 
23
  class T5QuestionRequest(BaseModel):
24
  context: str
25
 
 
 
 
 
 
26
  class T5Response(BaseModel):
27
  answer: str
28
 
 
46
  except Exception as e:
47
  raise HTTPException(status_code=500, detail=f"Error in modifying query v3: {str(e)}")
48
 
49
+ @app.post("/answer_question")
50
+ async def answer_question(request: Request):
51
  try:
52
+ raw_data = await request.json()
53
  res_locs = []
54
  context_string = ''
55
+ corpus_embeddings = model.encode(raw_data['context'], convert_to_tensor=True)
56
+ query_embeddings = model.encode(raw_data['question'], convert_to_tensor=True)
57
  hits = util.semantic_search(query_embeddings, corpus_embeddings)
58
 
59
  # Collect relevant contexts
60
  for hit in hits[0]:
61
  if hit['score'] > 0.4:
62
  loc = hit['corpus_id']
63
+ res_locs.append(raw_data['locations'][loc])
64
+ context_string += raw_data['context'][loc] + ' '
65
 
66
  # If no relevant contexts are found
67
  if not res_locs:
 
69
  else:
70
  # Use the question-answering pipeline
71
  QA_input = {
72
+ 'question': raw_data['question'],
73
  'context': context_string.replace('\n', ' ')
74
  }
75
  result = nlp(QA_input)
76
  answer = result['answer']
77
 
78
+ return JSONResponse(content={'answer':answer, "location":res_locs})
79
  except Exception as e:
80
  raise HTTPException(status_code=500, detail=f"Error in answering question: {str(e)}")
81