import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

from .config import Config

logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Ensure offline mode is disabled so the model can be downloaded
        os.environ['HF_HUB_OFFLINE'] = '0'
        os.environ['TRANSFORMERS_OFFLINE'] = '0'

        # Log in to the Hugging Face Hub if a token is configured
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN, add_to_git_credential=False)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
                raise

        # Initialize tokenizer and model
        self._init_tokenizer()
        self._init_model()
    def _init_tokenizer(self):
        """Initialize the tokenizer."""
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN,
                model_max_length=1024,  # Limit max length to save memory
                trust_remote_code=True
            )

            # Ensure the special tokens we rely on exist; only add the ones
            # that are missing so we don't override the model's own tokens
            special_tokens = {}
            if self.tokenizer.pad_token is None:
                special_tokens['pad_token'] = '[PAD]'
            if self.tokenizer.eos_token is None:
                special_tokens['eos_token'] = '</s>'
            if self.tokenizer.bos_token is None:
                special_tokens['bos_token'] = '<s>'
            if special_tokens:
                self.tokenizer.add_special_tokens(special_tokens)

            logger.info("Tokenizer loaded successfully")
            logger.debug(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise
    def _init_model(self):
        """Initialize the model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")

            # Load model with memory optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=torch.float32,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )

            # Resize embeddings in case special tokens were added to the tokenizer
            self.model.resize_token_embeddings(len(self.tokenizer))

            logger.info("Model loaded successfully")
            logger.debug(f"Model parameters: {sum(p.numel() for p in self.model.parameters())}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
    def generate_text(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate text from a prompt."""
        try:
            logger.info("Starting text generation")
            logger.debug(f"Prompt length: {len(prompt)}")

            # Encode the prompt with a reduced max length to save memory
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logger.debug(f"Input tensor shape: {inputs['input_ids'].shape}")

            # Generate response with memory optimizations
            logger.info("Generating response")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1,  # Disable beam search to save memory
                    use_cache=True  # Enable KV cache for faster generation
                )

            # Clear CUDA cache after generation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Decode only the newly generated tokens; slicing the decoded string
            # by len(prompt) is unreliable because decoding does not reproduce
            # the prompt exactly (truncation, special-token handling)
            prompt_length = inputs['input_ids'].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            logger.info("Text generation completed")
            logger.debug(f"Response length: {len(response)}")
            return response
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            logger.error(f"Error details: {type(e).__name__}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
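

# Usage sketch (kept as comments so the module stays import-only). Assumptions:
# Config exposes HUGGING_FACE_TOKEN, TEMPERATURE and TOP_P as used above, this
# file lives in the package as model_manager.py, and "gpt2" is only a placeholder
# checkpoint, not necessarily the model this Space actually serves.
#
#     from .model_manager import ModelManager
#
#     manager = ModelManager("gpt2")
#     reply = manager.generate_text("Explain what a tokenizer does.", max_new_tokens=128)
#     print(reply)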