AurelioAguirre commited on
Commit
c6b21e3
·
1 Parent(s): 1bcc710

changed to uvicorn setup for HF

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -3
  2. main/main.py +32 -14
Dockerfile CHANGED
@@ -17,6 +17,4 @@ COPY --chown=user main/ /app/main
17
  EXPOSE 7860
18
 
19
  # Command to run the application
20
- #CMD ["uvicorn", "main.main:app", "--host", "0.0.0.0", "--port", "7860"]
21
-
22
- CMD ["python", "-m", "main.main"]
 
17
  EXPOSE 7860
18
 
19
  # Command to run the application
20
+ CMD ["uvicorn", "main.main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
main/main.py CHANGED
@@ -6,6 +6,7 @@ import yaml
6
  import logging
7
  import asyncio
8
  from pathlib import Path
 
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from .routes import router, init_router
11
  from .api import InferenceApi
@@ -24,8 +25,8 @@ def load_config():
24
  with open(config_path) as f:
25
  return yaml.safe_load(f)
26
 
27
- async def async_main():
28
- """Create and configure the application instance asynchronously."""
29
  logger = setup_logging()
30
 
31
  try:
@@ -36,7 +37,7 @@ async def async_main():
36
  # Initialize API with config
37
  api = InferenceApi(config)
38
 
39
- # Initialize router with the already setup API instance
40
  await init_router(api)
41
 
42
  # Create LitServer instance with config
@@ -47,8 +48,11 @@ async def async_main():
47
  track_requests=True
48
  )
49
 
 
 
 
50
  # Add CORS middleware
51
- server.app.add_middleware(
52
  CORSMiddleware,
53
  allow_origins=["*"],
54
  allow_credentials=True,
@@ -58,22 +62,36 @@ async def async_main():
58
 
59
  # Add routes with configured prefix
60
  api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
61
- server.app.include_router(router, prefix=api_prefix)
62
-
63
- # Get configured port
64
- port = server_config.get('port', 8001)
65
- host = server_config.get('host', '0.0.0.0')
66
 
67
- # Run server
68
- server.run(host=host, port=port)
69
 
70
  except Exception as e:
71
- logger.error(f"Server initialization failed: {str(e)}")
72
  raise
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def main():
75
- """Entry point that runs the async main"""
76
- asyncio.run(async_main())
77
 
78
  if __name__ == "__main__":
79
  main()
 
6
  import logging
7
  import asyncio
8
  from pathlib import Path
9
+ from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from .routes import router, init_router
12
  from .api import InferenceApi
 
25
  with open(config_path) as f:
26
  return yaml.safe_load(f)
27
 
28
+ async def init_app() -> tuple[FastAPI, InferenceApi, dict]:
29
+ """Initialize and configure the FastAPI application."""
30
  logger = setup_logging()
31
 
32
  try:
 
37
  # Initialize API with config
38
  api = InferenceApi(config)
39
 
40
+ # Initialize router with the API instance
41
  await init_router(api)
42
 
43
  # Create LitServer instance with config
 
48
  track_requests=True
49
  )
50
 
51
+ # Get the FastAPI app from the LitServer
52
+ app = server.app
53
+
54
  # Add CORS middleware
55
+ app.add_middleware(
56
  CORSMiddleware,
57
  allow_origins=["*"],
58
  allow_credentials=True,
 
62
 
63
  # Add routes with configured prefix
64
  api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
65
+ app.include_router(router, prefix=api_prefix)
 
 
 
 
66
 
67
+ return app, api, config
 
68
 
69
  except Exception as e:
70
+ logger.error(f"Application initialization failed: {str(e)}")
71
  raise
72
 
73
+ # Create the FastAPI app instance for uvicorn
74
+ app, api_instance, config_dict = asyncio.get_event_loop().run_until_complete(init_app())
75
+
76
+ async def run_server():
77
+ """Run the server directly (not through uvicorn)"""
78
+ server_config = config_dict.get('server', {})
79
+ port = server_config.get('port', 8001)
80
+ host = server_config.get('host', '0.0.0.0')
81
+
82
+ # Create LitServer instance with all required parameters
83
+ server = ls.LitServer(
84
+ api_instance,
85
+ timeout=server_config.get('timeout', 60),
86
+ max_batch_size=server_config.get('max_batch_size', 1),
87
+ track_requests=True
88
+ )
89
+
90
+ server.run(host=host, port=port)
91
+
92
  def main():
93
+ """Entry point that runs the server directly"""
94
+ asyncio.run(run_server())
95
 
96
  if __name__ == "__main__":
97
  main()