thecuong commited on
Commit
ebb0dae
·
1 Parent(s): 466646a

feat: update fast API

Browse files
Files changed (1) hide show
  1. app.py +58 -15
app.py CHANGED
@@ -1,26 +1,69 @@
 
 
1
  import gradio as gr
 
 
2
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Tải mô hình
5
- model = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-multilingual-base',
6
- trust_remote_code=True)
7
 
8
- def gte_model(sentences: list):
9
- try:
10
- # hóa các câu
11
- embeddings = model.encode(sentences)
12
- return embeddings.tolist() # Chuyển đổi numpy array sang danh sách
13
- except Exception as e:
14
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Tạo giao diện Gradio
17
  demo = gr.Interface(
18
- fn=gte_model,
19
- inputs=gr.inputs.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
20
- outputs=gr.outputs.JSON(label="Kết quả mã hóa"),
21
  title="Mô hình GTE Multilingual",
22
- description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual. Kết quả sẽ được trả về dưới dạng danh sách mã hóa."
23
  )
24
 
25
- # Khởi chạy giao diện
26
- demo.launch()
 
 
 
 
 
 
 
 
1
+ from typing import List, Literal
2
+ from pydantic import BaseModel, Field
3
  import gradio as gr
4
+ from fastapi import FastAPI, APIRouter, Request
5
+ from fastapi.middleware.cors import CORSMiddleware
6
  from sentence_transformers import SentenceTransformer
7
+ import uvicorn
8
+ import requests
9
+
10
+ # Khởi tạo FastAPI
11
+ app = FastAPI()
12
+
13
+ # Thêm middleware CORS để cho phép yêu cầu từ Gradio
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"], # Cho phép tất cả các nguồn
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
 
22
  # Tải mô hình
23
+ model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base')
 
24
 
25
+ # Định nghĩa mô hình dữ liệu cho yêu cầu
26
+ class PostEmbeddings(BaseModel):
27
+ type: Literal['default', 'disease', 'gte'] = Field(default='default')
28
+ sentences: List[str]
29
+
30
+ # Tạo router cho API
31
+ router = APIRouter(
32
+ prefix="/retrieval",
33
+ tags=["retrieval"],
34
+ responses={404: {"description": "Not found"}},
35
+ )
36
+
37
+ @app.post("/retrieval/embeddings")
38
+ def post_embeddings(data: PostEmbeddings):
39
+ embeddings = model.encode(data.sentences)
40
+ return {
41
+ 'data': {
42
+ 'embeddings': embeddings.tolist(),
43
+ 'type': data.type
44
+ }
45
+ }
46
+
47
+ # Hàm Gradio để gọi API FastAPI
48
+ def call_api(sentences: List[str]):
49
+ response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
50
+ return response.json()["data"]
51
 
52
  # Tạo giao diện Gradio
53
  demo = gr.Interface(
54
+ fn=call_api,
55
+ inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
56
+ outputs=gr.JSON(label="Kết quả mã hóa"),
57
  title="Mô hình GTE Multilingual",
58
+ description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
59
  )
60
 
61
+ # Khởi động server
62
+ if __name__ == "__main__":
63
+ import threading
64
+
65
+ # Khởi động FastAPI trong một thread riêng
66
+ threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
67
+
68
+ # Khởi động Gradio
69
+ demo.launch()