Spaces:
Sleeping
Sleeping
feat: update fast API
Browse files
app.py
CHANGED
@@ -1,26 +1,69 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Tải mô hình
|
5 |
-
model = SentenceTransformer(
|
6 |
-
trust_remote_code=True)
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# Tạo giao diện Gradio
|
17 |
demo = gr.Interface(
|
18 |
-
fn=
|
19 |
-
inputs=gr.
|
20 |
-
outputs=gr.
|
21 |
title="Mô hình GTE Multilingual",
|
22 |
-
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual.
|
23 |
)
|
24 |
|
25 |
-
# Khởi
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Literal
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
import gradio as gr
|
4 |
+
from fastapi import FastAPI, APIRouter, Request
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
+
import uvicorn
|
8 |
+
import requests
|
9 |
+
|
10 |
+
# Khởi tạo FastAPI
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
# Thêm middleware CORS để cho phép yêu cầu từ Gradio
|
14 |
+
app.add_middleware(
|
15 |
+
CORSMiddleware,
|
16 |
+
allow_origins=["*"], # Cho phép tất cả các nguồn
|
17 |
+
allow_credentials=True,
|
18 |
+
allow_methods=["*"],
|
19 |
+
allow_headers=["*"],
|
20 |
+
)
|
21 |
|
22 |
# Tải mô hình
|
23 |
+
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base')
|
|
|
24 |
|
25 |
+
# Định nghĩa mô hình dữ liệu cho yêu cầu
|
26 |
+
class PostEmbeddings(BaseModel):
|
27 |
+
type: Literal['default', 'disease', 'gte'] = Field(default='default')
|
28 |
+
sentences: List[str]
|
29 |
+
|
30 |
+
# Tạo router cho API
|
31 |
+
router = APIRouter(
|
32 |
+
prefix="/retrieval",
|
33 |
+
tags=["retrieval"],
|
34 |
+
responses={404: {"description": "Not found"}},
|
35 |
+
)
|
36 |
+
|
37 |
+
@app.post("/retrieval/embeddings")
|
38 |
+
def post_embeddings(data: PostEmbeddings):
|
39 |
+
embeddings = model.encode(data.sentences)
|
40 |
+
return {
|
41 |
+
'data': {
|
42 |
+
'embeddings': embeddings.tolist(),
|
43 |
+
'type': data.type
|
44 |
+
}
|
45 |
+
}
|
46 |
+
|
47 |
+
# Hàm Gradio để gọi API FastAPI
|
48 |
+
def call_api(sentences: List[str]):
|
49 |
+
response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
|
50 |
+
return response.json()["data"]
|
51 |
|
52 |
# Tạo giao diện Gradio
|
53 |
demo = gr.Interface(
|
54 |
+
fn=call_api,
|
55 |
+
inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
|
56 |
+
outputs=gr.JSON(label="Kết quả mã hóa"),
|
57 |
title="Mô hình GTE Multilingual",
|
58 |
+
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
|
59 |
)
|
60 |
|
61 |
+
# Khởi động server
|
62 |
+
if __name__ == "__main__":
|
63 |
+
import threading
|
64 |
+
|
65 |
+
# Khởi động FastAPI trong một thread riêng
|
66 |
+
threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
|
67 |
+
|
68 |
+
# Khởi động Gradio
|
69 |
+
demo.launch()
|