Tonyivan committed on
Commit
65a2535
·
verified ·
1 Parent(s): e0452b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -28
app.py CHANGED
@@ -1,23 +1,15 @@
1
- import logging
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from sentence_transformers import SentenceTransformer, util
5
  from transformers import pipeline
6
 
7
- # Set up logging
8
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
9
- logger = logging.getLogger(__name__)
10
-
11
  # Initialize FastAPI app
12
  app = FastAPI()
13
 
14
- # Log model loading
15
- logger.info("Loading models...")
16
  # Load models
17
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
18
  question_model = "deepset/tinyroberta-squad2"
19
  nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
20
- logger.info("Models loaded successfully.")
21
 
22
  # Define request models
23
  class ModifyQueryRequest(BaseModel):
@@ -25,9 +17,10 @@ class ModifyQueryRequest(BaseModel):
25
 
26
  class AnswerQuestionRequest(BaseModel):
27
  question: str
28
- context: dict
 
29
 
30
- # Define response models
31
  class ModifyQueryResponse(BaseModel):
32
  embeddings: list
33
 
@@ -38,50 +31,38 @@ class AnswerQuestionResponse(BaseModel):
38
  # Define API endpoints
39
  @app.post("/modify_query", response_model=ModifyQueryResponse)
40
  async def modify_query(request: ModifyQueryRequest):
41
- logger.info(f"Received /modify_query request: {request.query_string}")
42
  try:
43
  binary_embeddings = model.encode([request.query_string], precision="binary")
44
- logger.info("Embeddings generated successfully.")
45
  return ModifyQueryResponse(embeddings=binary_embeddings[0].tolist())
46
  except Exception as e:
47
- logger.error(f"Error generating embeddings: {str(e)}")
48
  raise HTTPException(status_code=500, detail=str(e))
49
 
50
  @app.post("/answer_question", response_model=AnswerQuestionResponse)
51
  async def answer_question(request: AnswerQuestionRequest):
52
- logger.info(f"Received /answer_question request: {request.question}")
53
  try:
54
  res_locs = []
55
  context_string = ''
56
-
57
- corpus_embeddings = model.encode(request.context['context'], convert_to_tensor=True)
58
  query_embeddings = model.encode(request.question, convert_to_tensor=True)
59
  hits = util.semantic_search(query_embeddings, corpus_embeddings)
60
-
61
  for hit in hits:
62
- if hit['score'] > 0.5:
63
  loc = hit['corpus_id']
64
- res_locs.append(request.context['locations'][loc])
65
- context_string += request.context['context'][loc] + ' '
66
-
67
  if len(res_locs) == 0:
68
  ans = "Sorry, I couldn't find any results for your query."
69
- logger.info("No relevant context found.")
70
  else:
71
  QA_input = {
72
  'question': request.question,
73
- 'context': context_string.replace('\n', ' ')
74
  }
75
  result = nlp(QA_input)
76
  ans = result['answer']
77
- logger.info("Answer generated successfully.")
78
-
79
- return AnswerQuestionResponse(answer=ans, locations=res_locs)
80
  except Exception as e:
81
- logger.error(f"Error answering question: {str(e)}")
82
  raise HTTPException(status_code=500, detail=str(e))
83
 
84
  if __name__ == "__main__":
85
  import uvicorn
86
- logger.info("Starting FastAPI server...")
87
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
5
 
 
 
 
 
# Initialize FastAPI app
app = FastAPI()

# Load models once at import time so every request reuses the same instances.
# Sentence embedder: produces the vectors used both for /modify_query output
# and for the semantic search inside /answer_question.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering model (SQuAD2-style checkpoint).
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
 
13
 
14
  # Define request models
15
  class ModifyQueryRequest(BaseModel):
 
17
 
# Request body for /answer_question. `context` and `locations` are parallel
# lists: context[i] is a text passage and locations[i] is where it came from.
class AnswerQuestionRequest(BaseModel):
    question: str  # natural-language question to answer
    context: list  # candidate passages (presumably list of str — TODO confirm with callers)
    locations: list  # source identifier for each passage; same length as context
22
 
23
# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    embeddings: list  # binary embedding of the query, flattened to a plain list
26
 
 
31
# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the incoming query string into a binary embedding.

    The single query is encoded with the shared sentence-transformer model
    using binary precision; the first (only) row is returned as a plain
    Python list. Any failure is surfaced as an HTTP 500 with the error text.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        first_row = encoded[0]
        return ModifyQueryResponse(embeddings=first_row.tolist())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
39
 
40
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question against the supplied context passages.

    Embeds the passages and the question, keeps passages whose semantic-search
    score exceeds 0.5, and runs extractive QA over their concatenation.
    Returns the extracted answer together with the locations of the passages
    used, or a fallback message when nothing scores high enough.
    Raises HTTP 500 (with the error text) on any internal failure.
    """
    try:
        res_locs = []
        context_string = ''
        corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        hits = util.semantic_search(query_embeddings, corpus_embeddings)
        # BUG FIX: semantic_search returns one hit-list per query. The old code
        # iterated the outer list, so `hit` was a list and hit['score'] raised
        # TypeError, turning every request into a 500 via the broad except.
        # We have exactly one query, so iterate its hits: hits[0].
        for hit in hits[0]:
            if hit['score'] > .5:
                loc = hit['corpus_id']
                res_locs.append(request.locations[loc])
                context_string += request.context[loc] + ' '
        if not res_locs:
            ans = "Sorry, I couldn't find any results for your query."
        else:
            QA_input = {
                'question': request.question,
                'context': context_string.replace('\n', ' ')
            }
            result = nlp(QA_input)
            ans = result['answer']
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
65
 
66
# Dev entry point: run an embedded uvicorn server when executed directly,
# listening on all interfaces at port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)