thecuong commited on
Commit
1cfb5d9
·
1 Parent(s): 7373e61

feat: update

Browse files
Files changed (1) hide show
  1. app.py +18 -42
app.py CHANGED
@@ -1,69 +1,45 @@
1
  from typing import List, Literal
2
  from pydantic import BaseModel, Field
3
- import gradio as gr
4
  from fastapi import FastAPI, APIRouter, Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from sentence_transformers import SentenceTransformer
7
  import uvicorn
8
- import requests
9
 
10
- # Khởi tạo FastAPI
11
  app = FastAPI()
12
 
13
- # Thêm middleware CORS để cho phép yêu cầu từ Gradio
14
  app.add_middleware(
15
  CORSMiddleware,
16
- allow_origins=["*"], # Cho phép tất cả các nguồn
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
- # Tải mô hình
23
- model = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
24
 
25
- # Định nghĩa mô hình dữ liệu cho yêu cầu
26
  class PostEmbeddings(BaseModel):
27
  type: Literal['default', 'disease', 'gte'] = Field(default='default')
28
  sentences: List[str]
29
 
30
- # Tạo router cho API
31
- router = APIRouter(
32
- prefix="/retrieval",
33
- tags=["retrieval"],
34
- responses={404: {"description": "Not found"}},
35
- )
36
 
37
- @app.post("/retrieval/embeddings")
38
- def post_embeddings(data: PostEmbeddings):
39
  embeddings = model.encode(data.sentences)
40
- return {
41
- 'data': {
42
- 'embeddings': embeddings.tolist(),
43
- 'type': data.type
44
- }
45
- }
46
 
47
- # Hàm Gradio để gọi API FastAPI
48
- def call_api(sentences: List[str]):
49
- response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
50
- return response.json()["data"]
51
 
52
- # Tạo giao diện Gradio
53
- demo = gr.Interface(
54
- fn=call_api,
55
- inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
56
- outputs=gr.JSON(label="Kết quả mã hóa"),
57
- title="Mô hình GTE Multilingual",
58
- description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
59
- )
60
 
61
- # Khởi động server
62
  if __name__ == "__main__":
63
- import threading
64
-
65
- # Khởi động FastAPI trong một thread riêng
66
- threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
67
-
68
- # Khởi động Gradio
69
- demo.launch()
 
1
  from typing import List, Literal
2
  from pydantic import BaseModel, Field
 
3
  from fastapi import FastAPI, APIRouter, Request
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from sentence_transformers import SentenceTransformer
6
  import uvicorn
 
7
 
8
+ # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
+ # CORS middleware
12
  app.add_middleware(
13
  CORSMiddleware,
14
+ allow_origins=["*"],
15
  allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
+ # Load model
21
+ model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
22
 
23
+ # Define data model
24
  class PostEmbeddings(BaseModel):
25
  type: Literal['default', 'disease', 'gte'] = Field(default='default')
26
  sentences: List[str]
27
 
28
+ # Router for embeddings
29
+ router = APIRouter(prefix="/retrieval", tags=["retrieval"])
 
 
 
 
30
 
31
+ @router.post('/embeddings')
32
+ def post_embeddings(request: Request, data: PostEmbeddings):
33
  embeddings = model.encode(data.sentences)
34
+ return {"embeddings": embeddings.tolist()}
 
 
 
 
 
35
 
36
+ # Include router
37
+ app.include_router(router)
 
 
38
 
39
+ # Define main function to run the app
40
+ def main():
41
+ uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
 
 
 
 
 
42
 
43
+ # Run the app if this script is the main module
44
  if __name__ == "__main__":
45
+ main()