# model_manager.py
import logging
from functools import lru_cache
from typing import Any, Dict, Optional

import torch
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer

from config.config import GenerationConfig, ModelConfig


class ModelManager:
    def __init__(self, device: Optional[str] = None):
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}

    def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
        """Load a model with the specified configuration."""
        try:
            # Different model types may be handled here (textgen, llama, gguf,
            # text2video, text2image, etc.); a factory pattern could dispatch
            # to the correct loader.
            if model_type == "llama":
                self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
                    model_path,
                    padding_side="left",
                    trust_remote_code=True,
                    **config.tokenizer_kwargs,
                )
                # Many causal LMs ship without a pad token; fall back to EOS.
                if self.tokenizers[model_id].pad_token is None:
                    self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token
                self.models[model_id] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto",
                    trust_remote_code=True,
                    **config.model_kwargs,
                )
            elif model_type == "gguf":
                # TODO: look the model up in the cache first; if not found,
                # download it and save it to the cache, e.g.:
                # from huggingface_hub import hf_hub_download
                # prm_model_path = hf_hub_download(
                #     repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
                #     filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf",
                # )
                self.models[model_id] = self._load_quantized_model(
                    model_path,
                    **config.quantization_kwargs,
                )
        except Exception as e:
            self.logger.error(f"Failed to load model {model_id}: {str(e)}")
            raise

    def unload_model(self, model_id: str) -> None:
        """Unload a model and free its resources."""
        if model_id in self.models:
            del self.models[model_id]
        if model_id in self.tokenizers:
            del self.tokenizers[model_id]
        torch.cuda.empty_cache()

    def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
        """Load a quantized GGUF model via llama-cpp-python."""
        try:
            # -1 offloads all layers to the GPU when CUDA is available.
            n_gpu_layers = -1 if torch.cuda.is_available() else 0
            model = Llama(
                model_path=model_path,
                n_ctx=kwargs.get("n_ctx", 2048),
                n_batch=kwargs.get("n_batch", 512),
                n_gpu_layers=kwargs.get("n_gpu_layers", n_gpu_layers),
                verbose=kwargs.get("verbose", False),
            )
            return model
        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise
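

# --- Example usage (illustrative sketch) -------------------------------------
# A minimal sketch of how ModelManager might be driven. The real ModelConfig
# lives in config/config.py and its constructor is not shown above, so a
# SimpleNamespace stand-in exposing the attributes the loader actually reads
# (tokenizer_kwargs, model_kwargs, quantization_kwargs) is used instead.
# The model ID and file path below are placeholders, not real assets.
if __name__ == "__main__":
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()

    # Stand-in for ModelConfig; only quantization_kwargs is read on the gguf path.
    gguf_config = SimpleNamespace(
        tokenizer_kwargs={},
        model_kwargs={},
        quantization_kwargs={"n_ctx": 4096, "n_gpu_layers": -1},
    )
    manager.load_model(
        model_id="demo-gguf",
        model_path="/path/to/model.gguf",  # placeholder path
        model_type="gguf",
        config=gguf_config,
    )
    manager.unload_model("demo-gguf")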