Chris4K committed on
Commit aee91aa · verified · 1 Parent(s): 919c57c

Update app.py

Files changed (1):
  1. app.py +0 -546
app.py CHANGED
@@ -19,552 +19,6 @@ except Exception as e:
- # model_manager.py
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from llama_cpp import Llama
- from typing import Any, Dict, Optional
- import logging
- from functools import lru_cache
- from config.config import GenerationConfig, ModelConfig
-
-
- # NOTE: the @observe() decorator used below is assumed to be imported or defined
- # earlier in app.py (e.g. a tracing decorator such as Langfuse's observe).
- class ModelManager:
-     def __init__(self, device: Optional[str] = None):
-         self.logger = logging.getLogger(__name__)
-         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-         self.models: Dict[str, Any] = {}
-         self.tokenizers: Dict[str, Any] = {}
-
-     @observe()
-     def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
-         """Load a model with the specified configuration."""
-         try:
-             # Different model types could be handled here with a factory pattern:
-             # textgen, llama, gguf, text2video, text2image, etc.
-             if model_type == "llama":
-                 self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
-                     model_path,
-                     padding_side='left',
-                     trust_remote_code=True,
-                     **config.tokenizer_kwargs
-                 )
-                 if self.tokenizers[model_id].pad_token is None:
-                     self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token
-
-                 self.models[model_id] = AutoModelForCausalLM.from_pretrained(
-                     model_path,
-                     device_map="auto",
-                     trust_remote_code=True,
-                     **config.model_kwargs
-                 )
-             elif model_type == "gguf":
-                 # TODO: try the cache first; if the model is not found, download it and store it in the cache
-                 # from huggingface_hub import hf_hub_download
-                 # prm_model_path = hf_hub_download(
-                 #     repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
-                 #     filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
-                 # )
-
-                 self.models[model_id] = self._load_quantized_model(
-                     model_path,
-                     **config.quantization_kwargs
-                 )
-         except Exception as e:
-             self.logger.error(f"Failed to load model {model_id}: {str(e)}")
-             raise
-
-     @observe()
-     def unload_model(self, model_id: str) -> None:
-         """Unload a model and free its resources."""
-         if model_id in self.models:
-             del self.models[model_id]
-         if model_id in self.tokenizers:
-             del self.tokenizers[model_id]
-         torch.cuda.empty_cache()
-
-     def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
-         """Load a quantized GGUF model via llama-cpp-python."""
-         try:
-             n_gpu_layers = -1 if torch.cuda.is_available() else 0
-             model = Llama(
-                 model_path=model_path,
-                 n_ctx=kwargs.get('n_ctx', 2048),
-                 n_batch=kwargs.get('n_batch', 512),
-                 n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers),
-                 verbose=kwargs.get('verbose', False)
-             )
-             return model
-         except Exception as e:
-             self.logger.error(f"Failed to load GGUF model: {str(e)}")
-             raise
-
-
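A minimal usage sketch of ModelManager, not part of the original diff; the model paths are placeholders and ModelConfig is assumed to default to empty tokenizer_kwargs / model_kwargs / quantization_kwargs dicts:

manager = ModelManager()                               # "cuda" if available, else "cpu"
cfg = ModelConfig()                                    # assumed defaults
manager.load_model(
    model_id="llama",
    model_path="meta-llama/Llama-3.1-8B-Instruct",     # hypothetical HF repo id
    model_type="llama",
    config=cfg,
)
manager.load_model(
    model_id="prm",
    model_path="/models/prm.Q4_K_M.gguf",              # hypothetical local GGUF file
    model_type="gguf",
    config=cfg,
)
print(manager.models.keys())                           # dict_keys(['llama', 'prm'])
manager.unload_model("llama")                          # releases the model and clears the CUDA cache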
- # cache.py
- from collections import OrderedDict
- from typing import Optional, Tuple
-
- class ResponseCache:
-     """LRU cache for (response, score) pairs, keyed by prompt and generation config."""
-
-     def __init__(self, cache_size: int = 1000):
-         self.cache_size = cache_size
-         self._cache: "OrderedDict[Tuple[str, str], Tuple[str, float]]" = OrderedDict()
-
-     @staticmethod
-     def _config_hash(config: GenerationConfig) -> str:
-         return str(hash(str(config.__dict__)))
-
-     def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
-         key = (prompt, self._config_hash(config))
-         self._cache[key] = (response, score)
-         self._cache.move_to_end(key)
-         if len(self._cache) > self.cache_size:
-             self._cache.popitem(last=False)  # evict the least recently used entry
-
-     def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
-         key = (prompt, self._config_hash(config))
-         if key in self._cache:
-             self._cache.move_to_end(key)
-             return self._cache[key]
-         return None
-
-
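A short usage sketch for ResponseCache, not part of the original diff; it assumes GenerationConfig() is default-constructible (BaseGenerator constructs it the same way):

cache = ResponseCache(cache_size=2)
config = GenerationConfig()

cache.cache_response("What is 2 + 2?", config, response="4", score=0.93)
print(cache.get_response("What is 2 + 2?", config))    # ("4", 0.93)
print(cache.get_response("unseen prompt", config))     # None

# With cache_size=2, adding a third distinct entry evicts the least recently used one.
cache.cache_response("prompt b", config, "answer b", 0.5)
cache.cache_response("prompt c", config, "answer c", 0.7)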
- # batch_processor.py
- from typing import Any, Dict, List
- import asyncio
-
- class BatchProcessor:
-     """Collect incoming requests and flush them in batches, either when the batch
-     is full or after max_wait_time seconds have passed."""
-
-     def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1):
-         self.max_batch_size = max_batch_size
-         self.max_wait_time = max_wait_time
-         self.pending_requests: List[Dict] = []
-         self.lock = asyncio.Lock()
-
-     async def add_request(self, request: Dict) -> Any:
-         # Add the request and flush immediately if the batch is full.
-         async with self.lock:
-             self.pending_requests.append(request)
-             if len(self.pending_requests) >= self.max_batch_size:
-                 return await self._process_batch()
-         # Otherwise wait briefly (outside the lock) for more requests to arrive.
-         await asyncio.sleep(self.max_wait_time)
-         async with self.lock:
-             if self.pending_requests:
-                 return await self._process_batch()
-         return None
-
-     async def _process_batch(self) -> List[Any]:
-         batch = self.pending_requests[:self.max_batch_size]
-         self.pending_requests = self.pending_requests[self.max_batch_size:]
-         # TODO: implement the actual batch processing logic; for now the raw requests are returned.
-         return batch
-
-
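A minimal asyncio sketch of driving BatchProcessor, not part of the original diff; since _process_batch is still a TODO, the flushed batch is just the raw request dicts:

import asyncio

async def demo():
    processor = BatchProcessor(max_batch_size=2, max_wait_time=0.05)
    # Two concurrent requests: the second one fills the batch and triggers the flush.
    results = await asyncio.gather(
        processor.add_request({"prompt": "hello"}),
        processor.add_request({"prompt": "world"}),
    )
    print(results)  # one call returns the flushed batch, the other returns None

asyncio.run(demo())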
- # base_generator.py
- from abc import ABC, abstractmethod
- from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
- from dataclasses import dataclass
- from logging import getLogger
-
- from config.config import GenerationConfig, ModelConfig
-
- class BaseGenerator(ABC):
-     """Base class for all generator implementations."""
-
-     def __init__(
-         self,
-         model_name: str,
-         device: Optional[str] = None,
-         default_generation_config: Optional[GenerationConfig] = None,
-         model_config: Optional[ModelConfig] = None,
-         cache_size: int = 1000,
-         max_batch_size: int = 32
-     ):
-         self.logger = getLogger(__name__)
-         self.model_manager = ModelManager(device)
-         self.cache = ResponseCache(cache_size)
-         self.batch_processor = BatchProcessor(max_batch_size)
-         self.health_check = HealthCheck()
-         self.default_config = default_generation_config or GenerationConfig()
-         self.model_config = model_config or ModelConfig()
-
-     @abstractmethod
-     async def generate_stream(
-         self,
-         prompt: str,
-         config: Optional[GenerationConfig] = None
-     ) -> AsyncGenerator[str, None]:
-         pass
-
-     @abstractmethod
-     def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
-         pass
-
-     @abstractmethod
-     def generate(
-         self,
-         prompt: str,
-         model_kwargs: Dict[str, Any],
-         strategy: str = "default",
-         **kwargs
-     ) -> str:
-         pass
-
-
- # strategy.py
- # TODO: update paths
- from abc import ABC, abstractmethod
- from typing import Any, Dict, List, Tuple
-
- @observe()
- class GenerationStrategy(ABC):
-     """Base class for generation strategies."""
-
-     @abstractmethod
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
-         pass
-
-
- class DefaultStrategy(GenerationStrategy):
-
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
-         input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
-         output = generator.model.generate(input_ids, **model_kwargs)
-         return generator.tokenizer.decode(output[0], skip_special_tokens=True)
-
-
- @observe()
- class MajorityVotingStrategy(GenerationStrategy):
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
-         outputs = []
-         for _ in range(num_samples):
-             input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
-             output = generator.model.generate(input_ids, **model_kwargs)
-             outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
-         return max(set(outputs), key=outputs.count)
-
-
- @observe()
- class BestOfN(GenerationStrategy):
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
-         scored_outputs = []
-         for _ in range(num_samples):
-             input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
-             output = generator.model.generate(input_ids, **model_kwargs)
-             response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
-             # Score each candidate with the process reward model and keep the best one.
-             score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
-             scored_outputs.append((response, score))
-         return max(scored_outputs, key=lambda x: x[1])[0]
-
-
- @observe()
- class BeamSearch(GenerationStrategy):
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
-         input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
-         outputs = generator.model.generate(
-             input_ids,
-             num_beams=num_samples,
-             num_return_sequences=num_samples,
-             **model_kwargs
-         )
-         return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
-
-
- @observe()
- class DVT(GenerationStrategy):
-     # Breadth-first sampling followed by depth-wise extension of the best-scored candidates.
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, depth: int = 3, breadth: int = 2, **kwargs) -> str:
-         results = []
-         for _ in range(breadth):
-             input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
-             output = generator.model.generate(input_ids, **model_kwargs)
-             response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
-             score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
-             results.append((response, score))
-
-         for _ in range(depth - 1):
-             best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
-             for response, _ in best_responses:
-                 input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
-                 output = generator.model.generate(input_ids, **model_kwargs)
-                 extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
-                 score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
-                 results.append((extended_response, score))
-         return max(results, key=lambda x: x[1])[0]
-
-
- @observe()
- class COT(GenerationStrategy):
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
-         # TODO: implement the chain-of-thought strategy
-         return "Not implemented yet"
-
-
- @observe()
- class ReAct(GenerationStrategy):
-     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
-         # TODO: implement the ReAct framework
-         return "Not implemented yet"
-
- # Add other strategy implementations...
-
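A small invocation sketch, not part of the original diff; it assumes generator is an already constructed LlamaGenerator (see the llama_generator.py section below) and uses illustrative sampling settings:

strategy = MajorityVotingStrategy()
answer = strategy.generate(
    generator=generator,
    prompt="Explain beam search in one sentence.",
    model_kwargs={"max_new_tokens": 64, "do_sample": True, "temperature": 0.8},
    num_samples=3,   # majority vote over three sampled completions
)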
- # prompt_builder.py
- from typing import Protocol, List, Tuple
- from transformers import AutoTokenizer
-
- @observe()
- class PromptTemplate(Protocol):
-     """Protocol for prompt templates."""
-     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
-         pass
-
- @observe()
- class LlamaPromptTemplate:
-     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
-         system_message = f"Please assist based on the following context: {context}"
-         prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
-
-         for user_msg, assistant_msg in chat_history[-max_history_turns:]:
-             prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
-             prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
-
-         prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
-         prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
-         return prompt
-
- @observe()
- class TransformersPromptTemplate:
-     def __init__(self, model_path: str):
-         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
-         messages = [
-             {
-                 "role": "system",
-                 "content": f"Please assist based on the following context: {context}",
-             }
-         ]
-
-         for user_msg, assistant_msg in chat_history:
-             messages.extend([
-                 {"role": "user", "content": user_msg},
-                 {"role": "assistant", "content": assistant_msg}
-             ])
-
-         messages.append({"role": "user", "content": user_input})
-
-         tokenized_chat = self.tokenizer.apply_chat_template(
-             messages,
-             tokenize=False,
-             add_generation_prompt=True
-         )
-         return tokenized_chat
-
-
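A brief sketch of LlamaPromptTemplate.format, not part of the original diff; the context, chat history and user input are made up:

template = LlamaPromptTemplate()
prompt = template.format(
    context="Store hours: Mon-Fri 9-18.",
    user_input="Are you open on Friday evening?",
    chat_history=[("Hi", "Hello! How can I help?")],
    max_history_turns=1,
)
# The result starts with the <|begin_of_text|> system header and ends with an empty
# assistant header, ready for the model to continue.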
- # health_check.py
- import psutil
- from dataclasses import dataclass
- from typing import Dict, Any
-
- @dataclass
- class HealthStatus:
-     status: str
-     gpu_memory: Dict[str, float]
-     cpu_usage: float
-     ram_usage: float
-     model_status: Dict[str, str]
-
- class HealthCheck:
-     @staticmethod
-     def check_gpu_memory() -> Dict[str, float]:
-         if torch.cuda.is_available():
-             return {
-                 f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3
-                 for i in range(torch.cuda.device_count())
-             }
-         return {}
-
-     @staticmethod
-     def check_system_resources() -> HealthStatus:
-         return HealthStatus(
-             status="healthy",
-             gpu_memory=HealthCheck.check_gpu_memory(),
-             cpu_usage=psutil.cpu_percent(),
-             ram_usage=psutil.virtual_memory().percent,
-             # TODO: add more system resources (disk, network, etc.)
-             model_status={}  # To be filled by the model manager
-         )
-
-
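A quick sketch of reading the health status, not part of the original diff:

status = HealthCheck.check_system_resources()
print(status.status)       # "healthy"
print(status.cpu_usage)    # CPU utilisation in percent, e.g. 12.5
print(status.gpu_memory)   # {} on CPU-only machines, {"gpu_0": 1.2, ...} in GiB otherwise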
- # llama_generator.py
- from config.config import GenerationConfig, ModelConfig
-
- @observe()
- class LlamaGenerator(BaseGenerator):
-     def __init__(
-         self,
-         llama_model_name: str,
-         prm_model_path: str,
-         device: Optional[str] = None,
-         default_generation_config: Optional[GenerationConfig] = None,
-         model_config: Optional[ModelConfig] = None,
-         cache_size: int = 1000,
-         max_batch_size: int = 32,
-     ):
-         super().__init__(
-             llama_model_name,
-             device,
-             default_generation_config,
-             model_config,
-             cache_size,
-             max_batch_size
-         )
-
-         # Initialize models
-         self.model_manager.load_model(
-             "llama",
-             llama_model_name,
-             "llama",
-             self.model_config
-         )
-         self.model_manager.load_model(
-             "prm",
-             prm_model_path,
-             "gguf",
-             self.model_config
-         )
-
-         # Expose the loaded components under the names the strategies expect.
-         self.tokenizer = self.model_manager.tokenizers["llama"]
-         self.model = self.model_manager.models["llama"]
-         self.prm_model = self.model_manager.models["prm"]
-         self.device = self.model_manager.device
-
-         self.prompt_builder = LlamaPromptTemplate()
-         self._init_strategies()
-
-     @observe()
-     def load_model(self, model_name: str):
-         # Load a causal LM directly via Hugging Face transformers.
-         from transformers import AutoModelForCausalLM
-         return AutoModelForCausalLM.from_pretrained(model_name)
-
-     @observe()
-     def load_tokenizer(self, model_name: str):
-         # Load the tokenizer associated with the model.
-         from transformers import AutoTokenizer
-         return AutoTokenizer.from_pretrained(model_name)
-
-     def _init_strategies(self):
-         self.strategies = {
-             "default": DefaultStrategy(),
-             "majority_voting": MajorityVotingStrategy(),
-             "best_of_n": BestOfN(),
-             "beam_search": BeamSearch(),
-             "dvts": DVT(),
-         }
-
-     def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
-         """Get generation kwargs based on config."""
-         return {
-             key: getattr(config, key)
-             for key in [
-                 "max_new_tokens",
-                 "temperature",
-                 "top_p",
-                 "top_k",
-                 "repetition_penalty",
-                 "length_penalty",
-                 "do_sample"
-             ]
-             if hasattr(config, key)
-         }
-
-     @observe()
-     def generate_stream(self):
-         return "Not implemented yet"
-
-     @observe()
-     def generate(
-         self,
-         prompt: str,
-         model_kwargs: Dict[str, Any],
-         strategy: str = "default",
-         **kwargs
-     ) -> str:
-         """
-         Generate text based on a given strategy.
-
-         Args:
-             prompt (str): Input prompt for text generation.
-             model_kwargs (Dict[str, Any]): Additional arguments for model generation.
-             strategy (str): The generation strategy to use (default: "default").
-             **kwargs: Additional arguments passed to the strategy.
-
-         Returns:
-             str: Generated text response.
-
-         Raises:
-             ValueError: If the specified strategy is not available.
-         """
-         # Validate that the strategy exists
-         if strategy not in self.strategies:
-             raise ValueError(f"Unknown strategy: {strategy}. Available strategies are: {list(self.strategies.keys())}")
-
-         # Extract `generator` from kwargs if it exists to prevent duplication
-         kwargs.pop("generator", None)
-
-         # Call the selected strategy with the provided arguments
-         return self.strategies[strategy].generate(
-             generator=self,              # The generator instance
-             prompt=prompt,               # The input prompt
-             model_kwargs=model_kwargs,   # Arguments for the model
-             **kwargs                     # Any additional strategy-specific arguments
-         )
-
-     @observe()
-     def generate_with_context(
-         self,
-         context: str,
-         user_input: str,
-         chat_history: List[Tuple[str, str]],
-         model_kwargs: Dict[str, Any],
-         max_history_turns: int = 3,
-         strategy: str = "default",
-         num_samples: int = 5,
-         depth: int = 3,
-         breadth: int = 2,
-     ) -> str:
-         """Generate a response using context and chat history.
-
-         Args:
-             context (str): Context for the conversation
-             user_input (str): Current user input
-             chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
-             model_kwargs (dict): Additional arguments for model.generate()
-             max_history_turns (int): Maximum number of history turns to include
-             strategy (str): Generation strategy
-             num_samples (int): Number of samples for applicable strategies
-             depth (int): Depth for DVTS strategy
-             breadth (int): Breadth for DVTS strategy
-
-         Returns:
-             str: Generated response
-         """
-         prompt = self.prompt_builder.format(
-             context,
-             user_input,
-             chat_history,
-             max_history_turns
-         )
-         return self.generate(
-             prompt=prompt,
-             model_kwargs=model_kwargs,
-             strategy=strategy,
-             num_samples=num_samples,
-             depth=depth,
-             breadth=breadth
-         )
-
-     def check_health(self) -> HealthStatus:
-         """Check the health status of the generator."""
-         return self.health_check.check_system_resources()  # TODO: add model status
-
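An end-to-end sketch of the intended flow, not part of the original diff; the model name, PRM path and generation settings are placeholders:

generator = LlamaGenerator(
    llama_model_name="meta-llama/Llama-3.1-8B-Instruct",   # hypothetical HF repo id
    prm_model_path="/models/prm.Q4_K_M.gguf",              # hypothetical local GGUF reward model
)

reply = generator.generate_with_context(
    context="You are a helpful assistant for a small bookshop.",
    user_input="Can you recommend a poetry collection?",
    chat_history=[("Hi!", "Hello, welcome to the shop!")],
    model_kwargs={"max_new_tokens": 128, "do_sample": True},
    strategy="default",
)
print(reply)
print(generator.check_health())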
  ###################
  #################