thecuong commited on
Commit
c87a1ad
·
1 Parent(s): 1cfb5d9

feat: update

Browse files
Files changed (1) hide show
  1. app.py +38 -18
app.py CHANGED
@@ -1,45 +1,65 @@
1
  from typing import List, Literal
2
  from pydantic import BaseModel, Field
 
3
  from fastapi import FastAPI, APIRouter, Request
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from sentence_transformers import SentenceTransformer
6
  import uvicorn
 
 
 
7
 
8
- # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
- # CORS middleware
12
  app.add_middleware(
13
  CORSMiddleware,
14
- allow_origins=["*"],
15
  allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
- # Load model
21
- model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
22
 
23
- # Define data model
24
  class PostEmbeddings(BaseModel):
25
  type: Literal['default', 'disease', 'gte'] = Field(default='default')
26
  sentences: List[str]
27
 
28
- # Router for embeddings
29
- router = APIRouter(prefix="/retrieval", tags=["retrieval"])
 
 
 
 
30
 
31
- @router.post('/embeddings')
32
- def post_embeddings(request: Request, data: PostEmbeddings):
33
  embeddings = model.encode(data.sentences)
34
- return {"embeddings": embeddings.tolist()}
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Include router
37
- app.include_router(router)
 
 
38
 
39
- # Define main function to run the app
40
- def main():
41
- uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
42
 
43
- # Run the app if this script is the main module
44
  if __name__ == "__main__":
45
- main()
 
1
  from typing import List, Literal
2
  from pydantic import BaseModel, Field
3
+ import gradio as gr
4
  from fastapi import FastAPI, APIRouter, Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from sentence_transformers import SentenceTransformer
7
  import uvicorn
8
+ import requests
9
+ import asyncio
10
+ import threading
11
 
12
+ # Khởi tạo FastAPI
13
  app = FastAPI()
14
 
15
+ # Thêm middleware CORS để cho phép yêu cầu từ Gradio
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"], # Cho phép tất cả các nguồn
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
+ # Tải mô hình
25
+ model = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
26
 
27
+ # Định nghĩa mô hình dữ liệu cho yêu cầu
28
  class PostEmbeddings(BaseModel):
29
  type: Literal['default', 'disease', 'gte'] = Field(default='default')
30
  sentences: List[str]
31
 
32
+ # Tạo router cho API
33
+ router = APIRouter(
34
+ prefix="/retrieval",
35
+ tags=["retrieval"],
36
+ responses={404: {"description": "Not found"}},
37
+ )
38
 
39
+ @app.post("/retrieval/embeddings")
40
+ def post_embeddings(data: PostEmbeddings):
41
  embeddings = model.encode(data.sentences)
42
+ return {
43
+ 'data': {
44
+ 'embeddings': embeddings.tolist(),
45
+ 'type': data.type
46
+ }
47
+ }
48
+
49
+ # Hàm Gradio để gọi API FastAPI
50
+
51
+
52
+ # async def run_gradio():
53
+ # demo.launch(share=True)
54
 
55
+ async def run_uvicorn():
56
+ config = uvicorn.Config("app:app", host="0.0.0.0", port=8000, reload=True)
57
+ server = uvicorn.Server(config)
58
+ await server.serve()
59
 
60
+ # async def main():
61
+ # await asyncio.gather(run_uvicorn(), run_gradio())
 
62
 
63
+ # Khởi động server
64
  if __name__ == "__main__":
65
+ asyncio.run(run_uvicorn())