Update app.py
app.py (CHANGED)
@@ -2,19 +2,15 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer, util
 from transformers import pipeline
-
+import numpy as np
 
 
-# Initialize FastAPI app
 app = FastAPI()
 
-# Load models
 model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 question_model = "deepset/tinyroberta-squad2"
 nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
 
-#t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
-#t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
 summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 # Define request models
@@ -80,6 +76,31 @@ async def t5answer(request: T5QuestionRequest):
     resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
     return T5Response(answer = resp[0]["summary_text"])
 
+
+# Define API endpoints
+@app.post("/modify_query2", response_model=ModifyQueryResponse)
+async def modify_query(request: ModifyQueryRequest):
+    try:
+        embeddings = optimize_embedding([request.query_string])
+        return ModifyQueryResponse(embeddings=embeddings[0].tolist())
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+def optimize_embedding(texts, precision='uint8'):
+    # Step 1: Generate embeddings with 384 dimensions
+    embeddings = model.encode(texts)
+
+    # Step 2: Quantize embeddings to chosen precision (e.g., uint8)
+    if precision == 'uint8':
+        quantized_embeddings = np.array(embeddings, dtype='float32').astype('uint8')
+    elif precision == 'uint16':
+        quantized_embeddings = np.array(embeddings, dtype='float32').astype('uint16')
+    else:
+        raise ValueError("Unsupported precision. Use 'uint8' or 'uint16'.")
+
+    return quantized_embeddings
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
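For a quick smoke test of the new endpoint, a plain HTTP call is enough. The snippet below is an illustrative sketch, not part of the commit; it assumes the server is running locally on port 8000 and that ModifyQueryRequest exposes the query_string field the handler reads.

import requests

resp = requests.post(
    "http://localhost:8000/modify_query2",
    json={"query_string": "how are sentence embeddings quantized?"},
)
resp.raise_for_status()
embedding = resp.json()["embeddings"]  # one integer per embedding dimension (384 for all-MiniLM-L6-v2)
print(len(embedding), embedding[:8])

One caveat on the added optimize_embedding helper: all-MiniLM-L6-v2 returns small float components (unit-normalised vectors, roughly in [-1, 1]), so casting them straight to uint8 truncates most values to 0. If the goal is compact embeddings, rescaling into the 0-255 range before the cast preserves the signal; a minimal alternative sketch (a suggestion, not what the commit does):

import numpy as np

def scalar_quantize_uint8(embeddings):
    # Map components assumed to lie in [-1, 1] onto the full [0, 255] range before casting.
    emb = np.clip(np.asarray(embeddings, dtype="float32"), -1.0, 1.0)
    return np.round((emb + 1.0) * 127.5).astype("uint8")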