import logging

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

from .config import Config

logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Log in to the Hugging Face Hub if a token is configured
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            login(token=Config.HUGGING_FACE_TOKEN)

        # Initialize tokenizer and model
        self._init_tokenizer()
        self._init_model()

    def _init_tokenizer(self):
        """Initialize the tokenizer."""
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN
            )

            # Ensure the necessary special tokens are present
            special_tokens = {
                'pad_token': '[PAD]',
                'eos_token': '</s>',
                'bos_token': '<s>'
            }
            self.tokenizer.add_special_tokens(special_tokens)

            logger.info("Tokenizer loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise

    def _init_model(self):
        """Initialize the model."""
        try:
            logger.info(f"Loading model: {self.model_name}")

            # Use float32 on CPU; half precision only helps on GPU
            dtype = torch.float16 if self.device == "cuda" else torch.float32

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=dtype,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True
            )

            # Resize embeddings to account for any added special tokens
            self.model.resize_token_embeddings(len(self.tokenizer))

            logger.info(f"Using device: {self.device}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def generate_text(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """Generate text from a prompt."""
        try:
            # Encode the prompt and move it to the target device
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (everything after the prompt)
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            return response
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            return """- Issues:
  - Error generating code review
  - Model inference failed
- Improvements:
  - Please try again
  - Check model configuration
- Best Practices:
  - Ensure proper model setup
  - Verify token permissions
- Security:
  - No immediate concerns"""
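
# Example usage (a minimal sketch, not part of the original module): the checkpoint
# name and import path below are only illustrations -- substitute any causal LM your
# Hugging Face token can access, and make sure Config supplies HUGGING_FACE_TOKEN,
# TEMPERATURE, and TOP_P.
#
#   from app.model_manager import ModelManager   # hypothetical package path
#
#   manager = ModelManager("meta-llama/Llama-2-7b-hf")   # illustrative checkpoint
#   review = manager.generate_text("Review the following code:\n...", max_new_tokens=512)
#   print(review)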