import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from .config import Config

logger = logging.getLogger(__name__)

class ModelManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Ensure offline mode is disabled
        os.environ['HF_HUB_OFFLINE'] = '0'
        os.environ['TRANSFORMERS_OFFLINE'] = '0'
        
        # Login to Hugging Face Hub
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN, add_to_git_credential=False)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
                raise
        
        # Initialize tokenizer and model
        self._init_tokenizer()
        self._init_model()
        
    def _init_tokenizer(self):
        """Initialize the tokenizer."""
        try:
            logger.info(f"Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN,
                model_max_length=1024,  # Limit max length to save memory
                trust_remote_code=True
            )
            # Add any special tokens the tokenizer is missing, without
            # overriding ones the checkpoint already defines
            special_tokens = {}
            if self.tokenizer.pad_token is None:
                special_tokens['pad_token'] = '[PAD]'
            if self.tokenizer.eos_token is None:
                special_tokens['eos_token'] = '</s>'
            if self.tokenizer.bos_token is None:
                special_tokens['bos_token'] = '<s>'
            if special_tokens:
                self.tokenizer.add_special_tokens(special_tokens)
            logger.info("Tokenizer loaded successfully")
            logger.debug(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {str(e)}")
            raise
            
    def _init_model(self):
        """Initialize the model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")
            
            # Configure 4-bit quantization (bitsandbytes requires a CUDA device;
            # fall back to unquantized fp16 loading on CPU)
            quantization_config = None
            if self.device == "cuda":
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
            
            # Load model with memory optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                quantization_config=quantization_config,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,  # Use fp16 for additional memory savings
                trust_remote_code=True
            )
            # Resize embeddings to match tokenizer
            self.model.resize_token_embeddings(len(self.tokenizer))
            logger.info("Model loaded successfully")
            logger.debug(f"Model parameters: {sum(p.numel() for p in self.model.parameters())}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def generate_text(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate text from prompt."""
        try:
            logger.info("Starting text generation")
            logger.debug(f"Prompt length: {len(prompt)}")
            
            # Encode the prompt with reduced max length
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,  # Reduced max length
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logger.debug(f"Input tensor shape: {inputs['input_ids'].shape}")

            # Generate response with memory optimizations
            logger.info("Generating response")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1,  # Disable beam search to save memory
                    use_cache=False  # Disable KV cache to reduce peak memory (slower decoding)
                )
            
            # Clear CUDA cache after generation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Decode only the newly generated tokens, skipping the prompt tokens
            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
            
            logger.info("Text generation completed")
            logger.debug(f"Response length: {len(response)}")
            return response
            
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            logger.error(f"Error details: {type(e).__name__}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
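

# Illustrative usage sketch (not part of the original module): shows how
# ModelManager might be driven end to end. The checkpoint name below is an
# assumption; any causal LM accessible with Config.HUGGING_FACE_TOKEN works.
# Because of the relative `from .config import Config` import, run this as a
# module within its package (e.g. `python -m <package>.<module>`), not as a
# standalone script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager("meta-llama/Llama-2-7b-chat-hf")
    print(manager.generate_text("Summarize what 4-bit NF4 quantization does.", max_new_tokens=128))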