Spaces:
Sleeping
Sleeping
File size: 5,242 Bytes
087ce88 9eddb40 087ce88 69455b9 087ce88 69455b9 087ce88 1f37a6a 69455b9 1f37a6a 087ce88 a307172 69455b9 087ce88 1f37a6a 087ce88 1f37a6a 087ce88 a307172 087ce88 9eddb40 087ce88 a307172 69455b9 087ce88 1f37a6a 087ce88 a307172 087ce88 1f37a6a a307172 93aa8dc 1f37a6a 087ce88 a307172 1f37a6a 93aa8dc a307172 9eddb40 a307172 93aa8dc 1f37a6a a307172 93aa8dc 1f37a6a 93aa8dc 087ce88 1f37a6a a307172 1f37a6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
from .config import Config
import os
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelManager:
    """Load a Hugging Face causal language model and generate text with it.

    Constructing an instance logs in to the Hugging Face Hub (when
    ``Config.HUGGING_FACE_TOKEN`` is set), then loads the tokenizer and the
    model for ``model_name`` onto CUDA if available, otherwise onto the CPU.
    """

    # Maximum number of prompt tokens fed to the model; longer prompts are
    # truncated by the tokenizer.
    MAX_INPUT_TOKENS = 512

    # Value passed to the tokenizer as ``model_max_length``.
    TOKENIZER_MAX_LENGTH = 1024

    # Special tokens added only when the loaded tokenizer does not already
    # define them, so a checkpoint's own tokens are never overwritten.
    _FALLBACK_SPECIAL_TOKENS = {
        'pad_token': '[PAD]',
        'eos_token': '</s>',
        'bos_token': '<s>',
    }

    def __init__(self, model_name: str):
        """Log in to the Hub if a token is configured, then load everything.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.

        Raises:
            Exception: Re-raised unchanged when login, tokenizer loading or
                model loading fails (after being logged).
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Ensure offline mode is disabled.
        # NOTE(review): these libraries may read the variables at import
        # time, in which case setting them here is too late -- confirm.
        os.environ['HF_HUB_OFFLINE'] = '0'
        os.environ['TRANSFORMERS_OFFLINE'] = '0'

        # Login to Hugging Face Hub
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN, add_to_git_credential=False)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
                raise

        # Initialize tokenizer and model (the model step needs the tokenizer
        # to decide whether the embedding matrix must grow).
        self._init_tokenizer()
        self._init_model()

    def _init_tokenizer(self):
        """Load the tokenizer and make sure pad/eos/bos tokens exist.

        Raises:
            Exception: Re-raised unchanged after being logged.
        """
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            # NOTE(security): trust_remote_code=True lets the model
            # repository run its own Python code; use trusted models only.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN,
                model_max_length=self.TOKENIZER_MAX_LENGTH,  # Limit max length to save memory
                trust_remote_code=True,
            )

            # Only fill in special tokens the checkpoint does not define.
            # Overwriting existing eos/bos tokens would replace the ids the
            # model was trained with.
            already_defined = self.tokenizer.special_tokens_map
            missing_tokens = {
                name: token
                for name, token in self._FALLBACK_SPECIAL_TOKENS.items()
                if name not in already_defined
            }
            if missing_tokens:
                self.tokenizer.add_special_tokens(missing_tokens)

            logger.info("Tokenizer loaded successfully")
            logger.debug(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise

    def _init_model(self):
        """Load the model onto ``self.device`` and size its embeddings.

        Must run after ``_init_tokenizer`` because it reads
        ``len(self.tokenizer)``.

        Raises:
            Exception: Re-raised unchanged after being logged.
        """
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")

            # Load model with memory optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=torch.float32,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

            # Grow the embedding matrix only when the tokenizer gained
            # tokens; never shrink a matrix the checkpoint shipped with.
            vocab_size = len(self.tokenizer)
            embedding_rows = self.model.get_input_embeddings().weight.shape[0]
            if vocab_size > embedding_rows:
                self.model.resize_token_embeddings(vocab_size)

            logger.info("Model loaded successfully")
            logger.debug(f"Model parameters: {sum(p.numel() for p in self.model.parameters())}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def generate_text(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate a sampled continuation of ``prompt``.

        Args:
            prompt: Text to continue. It is truncated to
                ``MAX_INPUT_TOKENS`` tokens before generation.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The decoded newly generated tokens only (prompt excluded), with
            special tokens skipped and surrounding whitespace stripped.

        Raises:
            Exception: Re-raised unchanged after being logged.
        """
        try:
            logger.info("Starting text generation")
            logger.debug(f"Prompt length: {len(prompt)}")

            # Encode the prompt with reduced max length
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.MAX_INPUT_TOKENS,
                padding=True,
            )
            inputs = {key: tensor.to(self.device) for key, tensor in encoded.items()}
            logger.debug(f"Input tensor shape: {inputs['input_ids'].shape}")
            prompt_token_count = inputs['input_ids'].shape[1]

            # Generate response with memory optimizations
            logger.info("Generating response")
            with torch.no_grad():
                # early_stopping is intentionally not passed: it only
                # applies to beam search, which num_beams=1 disables.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1,  # Disable beam search to save memory
                    use_cache=True,  # Enable KV cache for faster generation
                )

            # The output sequence starts with the prompt tokens. Drop them
            # by token count, not by len(prompt) characters, which is wrong
            # when the prompt was truncated or does not decode exactly.
            new_token_ids = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(
                new_token_ids, skip_special_tokens=True
            ).strip()

            logger.info("Text generation completed")
            logger.debug(f"Response length: {len(response)}")
            return response
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            logger.error(f"Error details: {type(e).__name__}")
            raise
        finally:
            # Clear CUDA cache after generation, on success and on failure.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
|