AurelioAguirre committed
Commit 73ca5b8 · 1 Parent(s): 30d3c1f

changed to uvicorn setup for HF v11

Files changed (2)
  1. main/api.py +135 -101
  2. main/config.yaml +1 -1
main/api.py CHANGED
@@ -2,6 +2,7 @@ import httpx
 from typing import Optional, AsyncIterator, Dict, Any, Iterator, List
 import logging
 import asyncio
+import os
 from litserve import LitAPI
 from pydantic import BaseModel
 
@@ -40,6 +41,52 @@ class InferenceApi(LitAPI):
         endpoint = endpoints.get(endpoint_name, '')
         return f"{api_prefix}{endpoint}"
 
+    async def _make_request(
+        self,
+        method: str,
+        endpoint: str,
+        *,
+        params: Optional[Dict[str, Any]] = None,
+        json: Optional[Dict[str, Any]] = None,
+        stream: bool = False
+    ) -> Any:
+        """Make an authenticated request to the LLM Server.
+
+        Args:
+            method: HTTP method ('GET' or 'POST')
+            endpoint: Endpoint name to get from config
+            params: Query parameters
+            json: JSON body for POST requests
+            stream: Whether to return a streaming response
+        """
+        access_token = os.environ.get("InferenceAPI")
+        headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+
+        try:
+            async with await self._get_client() as client:
+                if stream:
+                    return await client.stream(
+                        method,
+                        self._get_endpoint(endpoint),
+                        params=params,
+                        json=json,
+                        headers=headers
+                    )
+                else:
+                    response = await client.request(
+                        method,
+                        self._get_endpoint(endpoint),
+                        params=params,
+                        json=json,
+                        headers=headers
+                    )
+                    response.raise_for_status()
+                    return response
+
+        except Exception as e:
+            self.logger.error(f"Error in request to {endpoint}: {str(e)}")
+            raise
+
     def predict(self, x: str, **kwargs) -> Iterator[str]:
         """Non-async prediction method that yields results."""
         loop = asyncio.get_event_loop()
@@ -63,19 +110,71 @@ class InferenceApi(LitAPI):
         response = await self.generate_response(x, **kwargs)
         yield response
 
+    async def generate_response(
+        self,
+        prompt: str,
+        system_message: Optional[str] = None,
+        max_new_tokens: Optional[int] = None
+    ) -> str:
+        """Generate a complete response by forwarding the request to the LLM Server."""
+        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
+
+        try:
+            response = await self._make_request(
+                "POST",
+                "generate",
+                json={
+                    "prompt": prompt,
+                    "system_message": system_message,
+                    "max_new_tokens": max_new_tokens
+                }
+            )
+            data = response.json()
+            return data["generated_text"]
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_response: {str(e)}")
+            raise
+
+    async def generate_stream(
+        self,
+        prompt: str,
+        system_message: Optional[str] = None,
+        max_new_tokens: Optional[int] = None
+    ) -> AsyncIterator[str]:
+        """Generate a streaming response by forwarding the request to the LLM Server."""
+        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
+
+        try:
+            async with await self._make_request(
+                "POST",
+                "generate_stream",
+                json={
+                    "prompt": prompt,
+                    "system_message": system_message,
+                    "max_new_tokens": max_new_tokens
+                },
+                stream=True
+            ) as response:
+                async for chunk in response.aiter_text():
+                    yield chunk
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_stream: {str(e)}")
+            raise
+
     async def generate_embedding(self, text: str) -> List[float]:
         """Generate embedding vector from input text."""
         self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
 
         try:
-            async with await self._get_client() as client:
-                response = await client.post(
-                    self._get_endpoint('embedding'),
-                    json={"text": text}
-                )
-                response.raise_for_status()
-                data = response.json()
-                return data["embedding"]
+            response = await self._make_request(
+                "POST",
+                "embedding",
+                json={"text": text}
+            )
+            data = response.json()
+            return data["embedding"]
 
         except Exception as e:
             self.logger.error(f"Error in generate_embedding: {str(e)}")
@@ -86,12 +185,11 @@ class InferenceApi(LitAPI):
         self.logger.debug("Checking system status...")
 
         try:
-            async with await self._get_client() as client:
-                response = await client.get(
-                    self._get_endpoint('system_status')
-                )
-                response.raise_for_status()
-                return response.json()
+            response = await self._make_request(
+                "GET",
+                "system_status"
+            )
+            return response.json()
 
         except Exception as e:
             self.logger.error(f"Error in check_system_status: {str(e)}")
@@ -102,33 +200,27 @@ class InferenceApi(LitAPI):
         self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
 
         try:
-            async with await self._get_client() as client:
-                response = await client.post(
-                    self._get_endpoint('model_download'),
-                    params={"model_name": model_name} if model_name else None
-                )
-                response.raise_for_status()
-                return response.json()
+            response = await self._make_request(
+                "POST",
+                "model_download",
+                params={"model_name": model_name} if model_name else None
+            )
+            return response.json()
 
         except Exception as e:
             self.logger.error(f"Error in download_model: {str(e)}")
             raise
 
-        except Exception as e:
-            self.logger.error(f"Error initiating model download: {str(e)}")
-            raise
-
     async def validate_system(self) -> Dict[str, Any]:
         """Validate system configuration and setup."""
         self.logger.debug("Validating system configuration...")
 
         try:
-            async with await self._get_client() as client:
-                response = await client.get(
-                    self._get_endpoint('system_validate')
-                )
-                response.raise_for_status()
-                return response.json()
+            response = await self._make_request(
+                "GET",
+                "system_validate"
+            )
+            return response.json()
 
         except Exception as e:
             self.logger.error(f"Error in validate_system: {str(e)}")
@@ -139,13 +231,12 @@ class InferenceApi(LitAPI):
         self.logger.debug(f"Initializing model: {model_name or 'default'}")
 
        try:
-            async with await self._get_client() as client:
-                response = await client.post(
-                    self._get_endpoint('model_initialize'),
-                    params={"model_name": model_name} if model_name else None
-                )
-                response.raise_for_status()
-                return response.json()
+            response = await self._make_request(
+                "POST",
+                "model_initialize",
+                params={"model_name": model_name} if model_name else None
+            )
+            return response.json()
 
         except Exception as e:
             self.logger.error(f"Error in initialize_model: {str(e)}")
@@ -156,13 +247,12 @@ class InferenceApi(LitAPI):
         self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
 
         try:
-            async with await self._get_client() as client:
-                response = await client.post(
-                    self._get_endpoint('model_initialize_embedding'),
-                    json={"model_name": model_name} if model_name else {}
-                )
-                response.raise_for_status()
-                return response.json()
+            response = await self._make_request(
+                "POST",
+                "model_initialize_embedding",
+                json={"model_name": model_name} if model_name else {}
+            )
+            return response.json()
 
         except Exception as e:
             self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
@@ -184,62 +274,6 @@ class InferenceApi(LitAPI):
         except StopIteration:
             return {"generated_text": ""}
 
-    async def generate_response(
-        self,
-        prompt: str,
-        system_message: Optional[str] = None,
-        max_new_tokens: Optional[int] = None
-    ) -> str:
-        """Generate a complete response by forwarding the request to the LLM Server."""
-        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
-
-        try:
-            async with await self._get_client() as client:
-                response = await client.post(
-                    self._get_endpoint('generate'),
-                    json={
-                        "prompt": prompt,
-                        "system_message": system_message,
-                        "max_new_tokens": max_new_tokens
-                    }
-                )
-                response.raise_for_status()
-                data = response.json()
-                return data["generated_text"]
-
-        except Exception as e:
-            self.logger.error(f"Error in generate_response: {str(e)}")
-            raise
-
-    async def generate_stream(
-        self,
-        prompt: str,
-        system_message: Optional[str] = None,
-        max_new_tokens: Optional[int] = None
-    ) -> AsyncIterator[str]:
-        """Generate a streaming response by forwarding the request to the LLM Server."""
-        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
-
-        try:
-            client = await self._get_client()
-            async with client.stream(
-                "POST",
-                self._get_endpoint('generate_stream'),
-                json={
-                    "prompt": prompt,
-                    "system_message": system_message,
-                    "max_new_tokens": max_new_tokens
-                }
-            ) as response:
-                response.raise_for_status()
-                async for chunk in response.aiter_text():
-                    yield chunk
-            await client.aclose()
-
-        except Exception as e:
-            self.logger.error(f"Error in generate_stream: {str(e)}")
-            raise
-
     async def cleanup(self):
         """Cleanup method - no longer needed as clients are created per-request"""
         pass
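With the shared helper in place, every forwarded call follows the same pattern: build the request through _make_request and return the parsed JSON. As a hedged sketch only (the endpoint name "model_list" and the method below are illustrative and not part of this commit), a new forwarding method would look like:

    async def list_models(self) -> Dict[str, Any]:
        """Hypothetical example: forward a GET to a 'model_list' endpoint via _make_request."""
        try:
            # _make_request adds the bearer token from the InferenceAPI env var and raises on HTTP errors
            response = await self._make_request("GET", "model_list")
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in list_models: {str(e)}")
            raise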
main/config.yaml CHANGED
@@ -5,7 +5,7 @@ server:
   max_batch_size: 1
 
 llm_server:
-  base_url: "https://teamgenki-llmserver.hf.space:7860"
+  base_url: "https://teamgenki-llmserver.hf.space:7860" # The base URL of the LLM server
   timeout: 60.0
   api_prefix: "/api/v1" # This will be used for route prefixing
   endpoints:
  endpoints: