Tonyivan committed on
Commit
5464450
·
verified ·
1 Parent(s): 74b5df4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -2,19 +2,15 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
5
- #from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
 
8
# FastAPI application instance that all endpoints below attach to.
app = FastAPI()

# Sentence-embedding model (384-dim MiniLM) used to embed query strings.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline; the same checkpoint serves as tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Abstractive summarization pipeline (BART fine-tuned on CNN/DailyMail).
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
19
 
20
  # Define request models
@@ -80,6 +76,31 @@ async def t5answer(request: T5QuestionRequest):
80
  resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
81
  return T5Response(answer = resp[0]["summary_text"])
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
if __name__ == "__main__":
    # Serve the API with uvicorn when this file is executed as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
 
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
5
+ import numpy as np
6
 
7
 
 
8
# FastAPI application instance that all endpoints below attach to.
app = FastAPI()

# Sentence-embedding model (384-dim MiniLM) used to embed query strings.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline; the same checkpoint serves as tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Abstractive summarization pipeline (BART fine-tuned on CNN/DailyMail).
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
15
 
16
  # Define request models
 
76
  resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
77
  return T5Response(answer = resp[0]["summary_text"])
78
 
79
+
80
+ # Define API endpoints
81
# Define API endpoints
@app.post("/modify_query2", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Embed the submitted query string and return it as a quantized vector."""
    try:
        vectors = optimize_embedding([request.query_string])
        return ModifyQueryResponse(embeddings=vectors[0].tolist())
    except Exception as exc:  # surface any embedding failure as an HTTP 500
        raise HTTPException(status_code=500, detail=str(exc))
88
+
89
+
90
+ def optimize_embedding(texts, precision='uint8'):
91
+ # Step 1: Generate embeddings with 384 dimensions
92
+ embeddings = model.encode(texts)
93
+
94
+ # Step 2: Quantize embeddings to chosen precision (e.g., uint8)
95
+ if precision == 'uint8':
96
+ quantized_embeddings = np.array(embeddings, dtype='float32').astype('uint8')
97
+ elif precision == 'uint16':
98
+ quantized_embeddings = np.array(embeddings, dtype='float32').astype('uint16')
99
+ else:
100
+ raise ValueError("Unsupported precision. Use 'uint8' or 'uint16'.")
101
+
102
+ return quantized_embeddings
103
+
104
if __name__ == "__main__":
    # Launch the API with uvicorn when this file is run directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)