File size: 7,211 Bytes
f887b2e
 
a562c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f887b2e
 
 
 
a562c0d
f887b2e
 
 
 
a562c0d
 
 
 
f887b2e
a562c0d
 
 
f887b2e
a562c0d
f887b2e
 
a562c0d
f887b2e
a562c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f887b2e
 
 
a562c0d
f887b2e
 
 
 
a562c0d
f887b2e
 
a562c0d
f887b2e
a562c0d
 
f887b2e
a562c0d
 
f887b2e
 
 
a562c0d
 
f887b2e
 
 
 
 
 
 
 
 
 
 
 
a562c0d
f887b2e
a562c0d
 
 
 
 
 
 
 
 
f887b2e
a562c0d
 
f887b2e
 
a562c0d
 
 
 
f887b2e
 
 
a562c0d
f887b2e
a562c0d
f887b2e
 
 
 
 
 
 
 
a562c0d
 
f887b2e
 
a562c0d
 
 
 
 
f887b2e
 
 
 
 
 
a562c0d
 
f887b2e
a562c0d
 
 
 
 
 
f887b2e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import psutil
import os
import time
from typing import Dict, Any
import numpy as np

class MemoryTracker:
    """Utility helpers for sampling and pretty-printing process memory usage."""

    @staticmethod
    def get_memory_usage() -> Dict[str, float]:
        """Sample this process's RSS, VMS, and allocated GPU memory, in MB."""
        mem = psutil.Process(os.getpid()).memory_info()
        mb = 1024 * 1024
        # GPU figure is 0 when CUDA is unavailable so the dict shape is stable.
        gpu_mb = torch.cuda.memory_allocated() / mb if torch.cuda.is_available() else 0
        return {'rss': mem.rss / mb, 'vms': mem.vms / mb, 'gpu': gpu_mb}

    @staticmethod
    def format_memory_stats(stats: Dict[str, float]) -> str:
        """Render a stats dict (as produced by get_memory_usage) as readable text."""
        lines = [
            f"RSS Memory: {stats['rss']:.2f} MB",
            f"Virtual Memory: {stats['vms']:.2f} MB",
            f"GPU Memory: {stats['gpu']:.2f} MB",
        ]
        return "\n".join(lines)

class CustomerSupportBot:
    """Fine-tuned causal-LM customer support bot with per-phase memory tracking."""

    def __init__(self, model_path="models/customer_support_gpt"):
        """
        Initialize the customer support bot with the fine-tuned model and memory tracking.

        Args:
            model_path (str): Path to the saved model and tokenizer
        """
        # Baseline snapshot before anything heavyweight is loaded.
        self.initial_memory = MemoryTracker.get_memory_usage()

        # Load tokenizer and track memory
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # GPT-style tokenizers often ship without a pad token; fall back to EOS
        # so generate() below does not receive pad_token_id=None.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.post_tokenizer_memory = MemoryTracker.get_memory_usage()

        # Load model and track memory
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.post_model_memory = MemoryTracker.get_memory_usage()

        # NOTE: GPU transfer intentionally pinned to CPU here; restore
        # `"cuda" if torch.cuda.is_available() else "cpu"` to re-enable GPU use.
        self.device = "cpu"
        self.model = self.model.to(self.device)
        self.post_device_memory = MemoryTracker.get_memory_usage()

        # Per-phase memory deltas (MB) across the loading pipeline.
        self.memory_deltas = {
            'tokenizer_load': {k: self.post_tokenizer_memory[k] - self.initial_memory[k]
                               for k in self.initial_memory},
            'model_load': {k: self.post_model_memory[k] - self.post_tokenizer_memory[k]
                           for k in self.initial_memory},
            'device_transfer': {k: self.post_device_memory[k] - self.post_model_memory[k]
                                for k in self.initial_memory}
        }

        # One delta dict is appended per generate_response() call.
        self.inference_memory_stats = []

    def get_memory_report(self) -> str:
        """Generate a comprehensive memory usage report.

        Returns:
            str: Multi-section text covering load-time deltas and, once at least
                 one inference has run, the average per-inference memory delta.
        """
        report = ["Memory Usage Report:"]

        report.append("\nModel Loading Memory Changes:")
        report.append("Tokenizer Loading:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['tokenizer_load']))

        report.append("\nModel Loading:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['model_load']))

        report.append("\nDevice Transfer:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['device_transfer']))

        # Only meaningful after at least one generate_response() call.
        if self.inference_memory_stats:
            avg_inference_memory = {
                k: np.mean([stats[k] for stats in self.inference_memory_stats])
                for k in self.inference_memory_stats[0]
            }
            report.append("\nAverage Inference Memory Usage:")
            report.append(MemoryTracker.format_memory_stats(avg_inference_memory))

        return "\n".join(report)

    def generate_response(self, instruction, max_length=100, temperature=0.7):
        """
        Generate a response for a given customer support instruction/query with memory tracking.

        Args:
            instruction (str): Customer's query or instruction
            max_length (int): Maximum total length (prompt + response) in tokens
            temperature (float): Controls randomness in generation

        Returns:
            tuple: (Generated response, dict with 'memory_delta' (MB deltas) and
                   'inference_time' (seconds))
        """
        # Record pre-inference memory
        pre_inference_memory = MemoryTracker.get_memory_usage()

        # Format and tokenize input; prompt shape must match the fine-tuning format.
        input_text = f"Instruction: {instruction}\nResponse:"
        inputs = self.tokenizer(input_text, return_tensors="pt")
        inputs = inputs.to(self.device)

        # Generate response and track wall-clock time.
        start_time = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.95,
                top_k=50
            )
        inference_time = time.time() - start_time

        # Record post-inference memory
        post_inference_memory = MemoryTracker.get_memory_usage()

        # Calculate memory delta for this inference and accumulate for reporting.
        inference_memory_delta = {
            k: post_inference_memory[k] - pre_inference_memory[k]
            for k in pre_inference_memory
        }
        self.inference_memory_stats.append(inference_memory_delta)

        # Decode and strip the echoed prompt: keep only text after "Response:".
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("Response:")[-1].strip()

        return response, {
            'memory_delta': inference_memory_delta,
            'inference_time': inference_time
        }

def main():
    """Run the demo: load the bot, answer sample queries, then go interactive."""
    # Initialize the bot
    print("Initializing bot and tracking memory usage...")
    bot = CustomerSupportBot()
    print(bot.get_memory_report())

    # Example queries
    example_queries = [
        "How do I reset my password?",
        "What are your shipping policies?",
        "I want to return a product.",
    ]

    # Generate and print responses with memory stats
    print("\nCustomer Support Bot Demo:\n")
    for query in example_queries:
        print(f"Customer: {query}")
        response, stats = bot.generate_response(query)
        print(f"Bot: {response}")
        print(f"Inference Memory Delta: {MemoryTracker.format_memory_stats(stats['memory_delta'])}")
        print(f"Inference Time: {stats['inference_time']:.2f} seconds\n")

    # Interactive mode
    print("Enter your questions (type 'quit' to exit):")
    while True:
        try:
            query = input("\nYour question: ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly when stdin closes (e.g. piped input) or on Ctrl-C,
            # instead of crashing with a traceback before the final report.
            break
        if query.lower() == 'quit':
            break

        response, stats = bot.generate_response(query)
        print(f"Bot: {response}")
        print(f"Inference Memory Delta: {MemoryTracker.format_memory_stats(stats['memory_delta'])}")
        print(f"Inference Time: {stats['inference_time']:.2f} seconds")

    # Print final memory report
    print("\nFinal Memory Report:")
    print(bot.get_memory_report())

if __name__ == "__main__":
    main()