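"""Model management utilities: load a Hugging Face causal language model and its
tokenizer (authenticating with Config.HUGGING_FACE_TOKEN when it is set) and
expose generate_text() for prompt completion."""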
import logging

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from huggingface_hub import login

from .config import Config

logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Login to Hugging Face Hub
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
                raise

        # Initialize tokenizer and model
        self._init_tokenizer()
        self._init_model()
    def _init_tokenizer(self):
        """Initialize the tokenizer."""
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN
            )
            # Ensure we have the necessary special tokens
            special_tokens = {
                'pad_token': '[PAD]',
                'eos_token': '</s>',
                'bos_token': '<s>'
            }
            self.tokenizer.add_special_tokens(special_tokens)
            logger.info("Tokenizer loaded successfully")
            logger.debug(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise
    def _init_model(self):
        """Initialize the model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")
            # Load the model onto the detected device
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=torch.float32,  # float32 keeps CPU inference safe; a half-precision dtype would save memory on GPU
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True
            )
            # Resize embeddings to match the tokenizer, which grew when special tokens were added
            self.model.resize_token_embeddings(len(self.tokenizer))
            logger.info("Model loaded successfully")
            logger.debug(f"Model parameters: {sum(p.numel() for p in self.model.parameters())}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
    def generate_text(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """Generate text from prompt."""
        try:
            logger.info("Starting text generation")
            logger.debug(f"Prompt length: {len(prompt)}")

            # Encode the prompt
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logger.debug(f"Input tensor shape: {inputs['input_ids'].shape}")

            # Generate response
            logger.info("Generating response")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens so the prompt (which may have
            # been truncated above) is not echoed back or mis-sliced by character count
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info("Text generation completed")
            logger.debug(f"Response length: {len(response)}")
            return response
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            logger.error(f"Error details: {type(e).__name__}")
            raise
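
# Minimal usage sketch, assuming Config resolves when this file is run as a module
# (e.g. `python -m <package>.<this_module>` so the relative import works);
# "gpt2" is only an illustrative placeholder model name.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    manager = ModelManager("gpt2")  # placeholder; substitute the model this app actually serves
    print(manager.generate_text("Write a short note about logging.", max_new_tokens=64))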