AurelioAguirre committed
Commit d0b5a4b · 1 Parent(s): 799409f

Adding more routes
Files changed (3):
  1. main/api.py +103 -1
  2. main/config.yaml +2 -1
  3. main/routes.py +20 -1
main/api.py CHANGED
@@ -1,10 +1,11 @@
 import httpx
-from typing import Optional, AsyncIterator, Dict, Any, Iterator
+from typing import Optional, AsyncIterator, Dict, Any, Iterator, List
 import logging
 import asyncio
 from litserve import LitAPI
 from pydantic import BaseModel
 
+
 class GenerationResponse(BaseModel):
     generated_text: str
 
@@ -62,6 +63,107 @@ class InferenceApi(LitAPI):
         response = await self.generate_response(x, **kwargs)
         yield response
 
+    async def generate_embedding(self, text: str) -> List[float]:
+        """Generate embedding vector from input text."""
+        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.post(
+                    self._get_endpoint('embedding'),
+                    json={"text": text}
+                )
+                response.raise_for_status()
+                data = response.json()
+                return data["embedding"]
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_embedding: {str(e)}")
+            raise
+
+    async def check_system_status(self) -> Dict[str, Any]:
+        """Check system status of the LLM Server."""
+        self.logger.debug("Checking system status...")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.get(
+                    self._get_endpoint('system_status')
+                )
+                response.raise_for_status()
+                return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error in check_system_status: {str(e)}")
+            raise
+
+    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
+        """Download model files from the LLM Server."""
+        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.post(
+                    self._get_endpoint('model_download'),
+                    params={"model_name": model_name} if model_name else None
+                )
+                response.raise_for_status()
+                return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error in download_model: {str(e)}")
+            raise
+
+    async def validate_system(self) -> Dict[str, Any]:
+        """Validate system configuration and setup."""
+        self.logger.debug("Validating system configuration...")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.get(
+                    self._get_endpoint('system_validate')
+                )
+                response.raise_for_status()
+                return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error in validate_system: {str(e)}")
+            raise
+
+    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
+        """Initialize specified model or default model."""
+        self.logger.debug(f"Initializing model: {model_name or 'default'}")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.post(
+                    self._get_endpoint('model_initialize'),
+                    json={"model_name": model_name} if model_name else {}
+                )
+                response.raise_for_status()
+                return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error in initialize_model: {str(e)}")
+            raise
+
+    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
+        """Initialize embedding model."""
+        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
+
+        try:
+            async with await self._get_client() as client:
+                response = await client.post(
+                    self._get_endpoint('model_initialize_embedding'),
+                    json={"model_name": model_name} if model_name else {}
+                )
+                response.raise_for_status()
+                return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
+            raise
+
     def decode_request(self, request: Any, **kwargs) -> str:
         """Convert the request payload to input format."""
         if isinstance(request, dict) and "prompt" in request:
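For orientation, a minimal sketch of how the new forwarding helpers could be called from an async context. The import path, the bare InferenceApi() construction, and the assumption that its HTTP client and config setup have already run are taken from context, not from this diff.

import asyncio
from main.api import InferenceApi  # module path assumed from the file list above

async def main():
    # Construction/setup details (config, base URL, LitAPI lifecycle)
    # are assumed to be handled elsewhere in the repo.
    api = InferenceApi()

    status = await api.check_system_status()  # GET <base>/system/status
    print("server status:", status)

    vector = await api.generate_embedding("hello world")  # POST <base>/embedding
    print("embedding dimensions:", len(vector))

asyncio.run(main())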
main/config.yaml CHANGED
@@ -15,4 +15,5 @@ llm_server:
   system_status: "/system/status"
   system_validate: "/system/validate"
   model_initialize: "/model/initialize"
-  model_initialize_embedding: "/model/initialize/embedding"
+  model_initialize_embedding: "/model/initialize/embedding"
+  model_download: "/model/download"
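The api.py helpers above look these paths up via self._get_endpoint(...), whose implementation is not part of this diff. A hypothetical sketch of such a resolver, assuming the YAML is loaded into a dict and the paths sit under llm_server:

import yaml

def load_config(path: str = "main/config.yaml") -> dict:
    with open(path) as f:
        return yaml.safe_load(f)

def get_endpoint(config: dict, name: str) -> str:
    # Hypothetical resolver; the real _get_endpoint in api.py may differ.
    # Nesting assumed from the hunk header; the real file may have an
    # intermediate key (e.g. "endpoints") between llm_server and the paths.
    return config["llm_server"][name]

# e.g. get_endpoint(load_config(), "model_download") -> "/model/download"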
main/routes.py CHANGED
@@ -17,7 +17,7 @@ from .schemas import (
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
-api = None
+api = InferenceApi()
 
 async def init_router(inference_api: InferenceApi):
     """Initialize router with an already setup API instance"""
@@ -174,6 +174,25 @@ async def initialize_embedding_model(model_name: Optional[str] = None):
     logger.error(f"Error initializing embedding model: {str(e)}")
     raise HTTPException(status_code=500, detail=str(e))
 
+@router.post("/model/download",
+             summary="Download default or specified model",
+             description="Downloads model files. Uses default model from config if none specified.")
+async def download_model(model_name: Optional[str] = None):
+    """Download model files to local storage"""
+    try:
+        # Use model name from config if none provided
+        model_to_download = model_name or config["model"]["defaults"]["model_name"]
+        logger.info(f"Received request to download model: {model_to_download}")
+
+        result = await api.download_model(model_to_download)
+        logger.info(f"Successfully downloaded model: {model_to_download}")
+
+        return result
+
+    except Exception as e:
+        logger.error(f"Error downloading model: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.on_event("shutdown")
 async def shutdown_event():
     """Clean up resources on shutdown"""