AurelioAguirre committed on
Commit 7e3820c · 1 Parent(s): 08548e7

WIP adapter.py

Files changed (1)
  1. main/adapter.py +346 -0
main/adapter.py ADDED
@@ -0,0 +1,346 @@
import httpx
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, AsyncIterator, List


class LLMAdapter(ABC):
    """Abstract base class for LLM adapters."""

    @abstractmethod
    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response from the LLM."""
        pass

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response from the LLM."""
        pass

    @abstractmethod
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text."""
        pass

    @abstractmethod
    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        pass

    @abstractmethod
    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        pass

    @abstractmethod
    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model."""
        pass

    @abstractmethod
    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model."""
        pass

    @abstractmethod
    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files."""
        pass

    @abstractmethod
    async def cleanup(self):
        """Cleanup resources."""
        pass


class HTTPLLMAdapter(LLMAdapter):
    """HTTP adapter for connecting to LLM services over HTTP."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize the HTTP LLM Adapter with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing HTTP LLM Adapter")
        self.config = config
        self.llm_config = config.get('llm_server', {})

    async def _get_client(self):
        """Create a new HTTP client (one client per request)."""
        host = self.llm_config.get('host', 'localhost')
        port = self.llm_config.get('port', 8002)

        # Construct base URL, omitting the port for HF Spaces
        if 'hf.space' in host:
            base_url = f"https://{host}"
        else:
            base_url = f"http://{host}:{port}"

        return httpx.AsyncClient(
            base_url=base_url,
            timeout=float(self.llm_config.get('timeout', 60.0))
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Get the full endpoint path, including the API prefix."""
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"

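    # A minimal sketch of the `llm_server` config section this adapter reads.
    # The keys mirror the .get() calls above, and the endpoint names match the
    # requests made by the methods below; the concrete paths and values are
    # illustrative placeholders, not values defined by this module.
    #
    # config = {
    #     "llm_server": {
    #         "host": "localhost",      # a *.hf.space host switches to https with no port
    #         "port": 8002,
    #         "timeout": 60.0,
    #         "api_prefix": "/api/v1",
    #         "endpoints": {
    #             "generate": "/generate",
    #             "generate_stream": "/generate/stream",
    #             "embedding": "/embedding",
    #             "system_status": "/system/status",
    #             "system_validate": "/system/validate",
    #             "model_initialize": "/model/initialize",
    #             "model_initialize_embedding": "/model/initialize/embedding",
    #             "model_download": "/model/download",
    #         },
    #     },
    # }
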
    async def _make_request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        stream: bool = False
    ) -> Any:
        """Make a request to the LLM Server."""
        path = self._get_endpoint(endpoint)

        try:
            # Create the client outside a with block so it can be handed back
            # to the caller for streaming requests
            client = await self._get_client()
            self.logger.info(f"Making {method} request to: {client.base_url}{path}")

            if stream:
                # For streaming, return both the client and the response context manager
                return client, client.stream(
                    method,
                    path,
                    params=params,
                    json=json
                )
            else:
                # For non-streaming, use the client as a context manager
                async with client as c:
                    response = await c.request(
                        method,
                        path,
                        params=params,
                        json=json
                    )
                    response.raise_for_status()
                    return response

        except Exception as e:
            self.logger.error(f"Error in {method} request to {path}: {str(e)}")
            raise

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")

        try:
            response = await self._make_request(
                "POST",
                "generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            data = response.json()
            return data["generated_text"]

        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")

        try:
            client, stream_cm = await self._make_request(
                "POST",
                "generate_stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                },
                stream=True
            )

            async with client:
                async with stream_cm as response:
                    # Fail fast on HTTP errors instead of streaming the error body
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        yield chunk

        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

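    # Consuming the stream (a sketch; `adapter` stands for any HTTPLLMAdapter
    # instance and is not defined in this module):
    #
    #     async for chunk in adapter.generate_stream("Tell me a story"):
    #         print(chunk, end="", flush=True)
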
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text."""
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")

        try:
            response = await self._make_request(
                "POST",
                "embedding",
                json={"text": text}
            )
            data = response.json()
            return data["embedding"]

        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        self.logger.debug("Checking system status...")

        try:
            response = await self._make_request(
                "GET",
                "system_status"
            )
            return response.json()

        except Exception as e:
            self.logger.error(f"Error in check_system_status: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")

        try:
            response = await self._make_request(
                "GET",
                "system_validate"
            )
            return response.json()

        except Exception as e:
            self.logger.error(f"Error in validate_system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model."""
        self.logger.debug(f"Initializing model: {model_name or 'default'}")

        try:
            response = await self._make_request(
                "POST",
                "model_initialize",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()

        except Exception as e:
            self.logger.error(f"Error in initialize_model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model."""
        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")

        try:
            response = await self._make_request(
                "POST",
                "model_initialize_embedding",
                json={"model_name": model_name} if model_name else {}
            )
            return response.json()

        except Exception as e:
            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
            raise

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM Server."""
        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")

        try:
            response = await self._make_request(
                "POST",
                "model_download",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()

        except Exception as e:
            self.logger.error(f"Error in download_model: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup method - no longer needed, as clients are created per-request."""
        pass


class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI-compatible services (OpenAI, Azure OpenAI, local services with an OpenAI API)."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing OpenAI Adapter")
        self.config = config
        self.openai_config = config.get('openai', {})
        # Additional OpenAI-specific setup would go here

    async def generate_response(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> str:
        """OpenAI implementation - would use the openai Python client"""
        # Implementation would go here
        raise NotImplementedError("OpenAIAdapter.generate_response is not implemented yet")

    async def generate_stream(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> AsyncIterator[str]:
        """OpenAI streaming implementation"""
        # Placeholder async generator until the real implementation is added
        yield "Not implemented yet"

    # ... implementations for other methods


class vLLMAdapter(LLMAdapter):
    """Adapter for vLLM services."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing vLLM Adapter")
        self.config = config
        self.vllm_config = config.get('vllm', {})
        # Additional vLLM-specific setup would go here

    # ... implementations for all methods


# Factory function to create the appropriate adapter
def create_adapter(config: Dict[str, Any]) -> LLMAdapter:
    """Create an adapter instance based on configuration."""
    adapter_type = config.get('adapter', {}).get('type', 'http')

    if adapter_type == 'http':
        return HTTPLLMAdapter(config)
    elif adapter_type == 'openai':
        return OpenAIAdapter(config)
    elif adapter_type == 'vllm':
        return vLLMAdapter(config)
    else:
        raise ValueError(f"Unknown adapter type: {adapter_type}")
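

# Example usage (a minimal sketch; the config is illustrative, assumes a reachable
# LLM Server, and would also need the "endpoints" mapping shown in the config
# sketch above — none of this is defined by this module):
#
#     import asyncio
#
#     async def main():
#         adapter = create_adapter({
#             "adapter": {"type": "http"},
#             "llm_server": {"host": "localhost", "port": 8002},
#         })
#         print(await adapter.check_system_status())
#         print(await adapter.generate_response("Hello!", max_new_tokens=64))
#         await adapter.cleanup()
#
#     asyncio.run(main())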