Spaces:
Sleeping
Sleeping
feat: update
Browse files
app.py
CHANGED
@@ -1,69 +1,45 @@
|
|
1 |
from typing import List, Literal
|
2 |
from pydantic import BaseModel, Field
|
3 |
-
import gradio as gr
|
4 |
from fastapi import FastAPI, APIRouter, Request
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
import uvicorn
|
8 |
-
import requests
|
9 |
|
10 |
-
#
|
11 |
app = FastAPI()
|
12 |
|
13 |
-
#
|
14 |
app.add_middleware(
|
15 |
CORSMiddleware,
|
16 |
-
allow_origins=["*"],
|
17 |
allow_credentials=True,
|
18 |
allow_methods=["*"],
|
19 |
allow_headers=["*"],
|
20 |
)
|
21 |
|
22 |
-
#
|
23 |
-
model = SentenceTransformer(
|
24 |
|
25 |
-
#
|
26 |
class PostEmbeddings(BaseModel):
|
27 |
type: Literal['default', 'disease', 'gte'] = Field(default='default')
|
28 |
sentences: List[str]
|
29 |
|
30 |
-
#
|
31 |
-
router = APIRouter(
|
32 |
-
prefix="/retrieval",
|
33 |
-
tags=["retrieval"],
|
34 |
-
responses={404: {"description": "Not found"}},
|
35 |
-
)
|
36 |
|
37 |
-
@
|
38 |
-
def post_embeddings(data: PostEmbeddings):
|
39 |
embeddings = model.encode(data.sentences)
|
40 |
-
return {
|
41 |
-
'data': {
|
42 |
-
'embeddings': embeddings.tolist(),
|
43 |
-
'type': data.type
|
44 |
-
}
|
45 |
-
}
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
-
response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
|
50 |
-
return response.json()["data"]
|
51 |
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
|
56 |
-
outputs=gr.JSON(label="Kết quả mã hóa"),
|
57 |
-
title="Mô hình GTE Multilingual",
|
58 |
-
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
|
59 |
-
)
|
60 |
|
61 |
-
#
|
62 |
if __name__ == "__main__":
|
63 |
-
|
64 |
-
|
65 |
-
# Khởi động FastAPI trong một thread riêng
|
66 |
-
threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
|
67 |
-
|
68 |
-
# Khởi động Gradio
|
69 |
-
demo.launch()
|
|
|
1 |
from typing import List, Literal
|
2 |
from pydantic import BaseModel, Field
|
|
|
3 |
from fastapi import FastAPI, APIRouter, Request
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
import uvicorn
|
|
|
7 |
|
8 |
+
# Initialize FastAPI app
|
9 |
app = FastAPI()
|
10 |
|
11 |
+
# CORS middleware
|
12 |
app.add_middleware(
|
13 |
CORSMiddleware,
|
14 |
+
allow_origins=["*"],
|
15 |
allow_credentials=True,
|
16 |
allow_methods=["*"],
|
17 |
allow_headers=["*"],
|
18 |
)
|
19 |
|
20 |
+
# Load model
|
21 |
+
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
|
22 |
|
23 |
+
# Define data model
|
24 |
class PostEmbeddings(BaseModel):
|
25 |
type: Literal['default', 'disease', 'gte'] = Field(default='default')
|
26 |
sentences: List[str]
|
27 |
|
28 |
+
# Router for embeddings
|
29 |
+
router = APIRouter(prefix="/retrieval", tags=["retrieval"])
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
@router.post('/embeddings')
|
32 |
+
def post_embeddings(request: Request, data: PostEmbeddings):
|
33 |
embeddings = model.encode(data.sentences)
|
34 |
+
return {"embeddings": embeddings.tolist()}
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
# Include router
|
37 |
+
app.include_router(router)
|
|
|
|
|
38 |
|
39 |
+
# Define main function to run the app
|
40 |
+
def main():
|
41 |
+
uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
# Run the app if this script is the main module
|
44 |
if __name__ == "__main__":
|
45 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|