File size: 5,242 Bytes
087ce88
9eddb40
087ce88
 
 
69455b9
087ce88
 
 
 
 
 
 
 
 
 
69455b9
 
 
 
087ce88
 
 
1f37a6a
69455b9
1f37a6a
 
 
 
087ce88
 
 
 
 
 
 
 
 
 
 
a307172
69455b9
 
087ce88
 
 
 
 
 
 
 
1f37a6a
 
087ce88
 
 
 
 
 
 
 
1f37a6a
087ce88
a307172
087ce88
 
 
9eddb40
087ce88
a307172
69455b9
087ce88
 
 
1f37a6a
 
087ce88
 
 
 
a307172
087ce88
 
1f37a6a
 
 
a307172
 
 
 
 
 
 
 
93aa8dc
1f37a6a
087ce88
a307172
1f37a6a
93aa8dc
 
 
 
 
 
 
 
 
a307172
9eddb40
a307172
93aa8dc
1f37a6a
a307172
 
 
 
93aa8dc
 
 
 
1f37a6a
 
93aa8dc
087ce88
 
 
1f37a6a
a307172
 
1f37a6a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
from .config import Config
import os

logger = logging.getLogger(__name__)

class ModelManager:
    """Wraps a Hugging Face causal LM: Hub login, tokenizer/model loading
    with memory-conscious settings, and prompt-based text generation.

    Attributes:
        model_name: Hub repo id or local path of the model to load.
        tokenizer:  AutoTokenizer instance (set by _init_tokenizer).
        model:      AutoModelForCausalLM instance (set by _init_model).
        device:     "cuda" when available, else "cpu".
    """

    def __init__(self, model_name: str):
        """Authenticate with the Hub (if a token is configured) and load
        the tokenizer and model.

        Args:
            model_name: Hub repo id or local path of the model.

        Raises:
            Exception: propagated from Hub login, tokenizer load, or
                model load failures (each is logged before re-raising).
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Force online mode so the Hub can be reached even if a parent
        # environment set these flags.
        os.environ['HF_HUB_OFFLINE'] = '0'
        os.environ['TRANSFORMERS_OFFLINE'] = '0'

        # Login to Hugging Face Hub when a token is configured.
        if Config.HUGGING_FACE_TOKEN:
            logger.info("Logging in to Hugging Face Hub")
            try:
                login(token=Config.HUGGING_FACE_TOKEN, add_to_git_credential=False)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error("Failed to login to Hugging Face Hub: %s", e)
                raise

        # Initialize tokenizer first: the model resizes its embeddings
        # to the tokenizer's vocabulary size.
        self._init_tokenizer()
        self._init_model()

    def _init_tokenizer(self):
        """Load the tokenizer, adding only *missing* special tokens."""
        try:
            logger.info("Loading tokenizer: %s", self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=Config.HUGGING_FACE_TOKEN,
                model_max_length=1024,  # Limit max length to save memory
                trust_remote_code=True
            )
            # BUGFIX: only register special tokens the tokenizer lacks.
            # The previous code unconditionally overwrote eos/bos with
            # '</s>'/'<s>', which corrupts generation for models whose
            # native tokens differ (e.g. GPT-2's '<|endoftext|>').
            fallback_tokens = {
                'pad_token': '[PAD]',
                'eos_token': '</s>',
                'bos_token': '<s>'
            }
            missing = {
                name: token for name, token in fallback_tokens.items()
                if getattr(self.tokenizer, name, None) is None
            }
            if missing:
                self.tokenizer.add_special_tokens(missing)
            logger.info("Tokenizer loaded successfully")
            logger.debug("Tokenizer vocabulary size: %d", len(self.tokenizer))
        except Exception as e:
            logger.error("Error loading tokenizer: %s", e)
            raise

    def _init_model(self):
        """Load the model onto self.device and sync its embedding size
        with the (possibly extended) tokenizer vocabulary."""
        try:
            logger.info("Loading model: %s", self.model_name)
            logger.info("Using device: %s", self.device)

            # Load model with memory optimizations.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map={"": self.device},
                torch_dtype=torch.float32,
                token=Config.HUGGING_FACE_TOKEN,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            # Resize embeddings in case special tokens were added above
            # (no-op when the sizes already match).
            self.model.resize_token_embeddings(len(self.tokenizer))
            logger.info("Model loaded successfully")
            logger.debug(
                "Model parameters: %d",
                sum(p.numel() for p in self.model.parameters()),
            )
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def generate_text(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate a completion for *prompt* and return only the newly
        generated text (prompt excluded).

        Args:
            prompt: Input text; truncated to 512 tokens.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The generated continuation, stripped of special tokens and
            surrounding whitespace.

        Raises:
            Exception: propagated after logging and freeing CUDA cache.
        """
        try:
            logger.info("Starting text generation")
            logger.debug("Prompt length: %d", len(prompt))

            # Encode the prompt with reduced max length.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,  # Reduced max length
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_token_count = inputs['input_ids'].shape[1]
            logger.debug("Input tensor shape: %s", inputs['input_ids'].shape)

            # Generate response with memory optimizations.
            # NOTE: early_stopping was removed — it only applies to beam
            # search and emits a warning with num_beams=1.
            logger.info("Generating response")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=Config.TEMPERATURE,
                    top_p=Config.TOP_P,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1,  # Disable beam search to save memory
                    use_cache=True  # Enable KV cache for faster generation
                )

            # Clear CUDA cache after generation.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # BUGFIX: strip the prompt by *token* count, not character
            # count — decoding the prompt tokens does not always
            # reproduce the prompt string exactly (tokenizer
            # normalization), so string slicing could truncate or leak
            # prompt text into the response.
            new_tokens = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            logger.info("Text generation completed")
            logger.debug("Response length: %d", len(response))
            return response

        except Exception as e:
            logger.error("Error generating text: %s", e)
            logger.error("Error details: %s", type(e).__name__)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise