acecalisto3 committed
Commit 40667c5 · verified · 1 Parent(s): 9ff74fe

Update app.py

Files changed (1)
  1. app.py +196 -107
app.py CHANGED
@@ -1,138 +1,227 @@
- import gradio as gr
- from transformers import pipeline, AutoModelForCausalLM
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import os
  import json
- import time
  import logging
  from threading import Lock
+ import torch

- CONFIG_FILE = "config.json"
- MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
- CACHE_DIR = "model_cache"  # Directory for storing model cache
-
- # Create cache directory if it doesn't exist
- os.makedirs(CACHE_DIR, exist_ok=True)
-
- # Setup logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
- messages = [
-     {"role": "user", "content": "Who are you?"},
- ]
+ # Constants with optimized values for Mixtral
+ DEFAULT_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ MAX_INPUT_TOKENS = 24576  # 24K tokens for input (leaving room for output)
+ MAX_NEW_TOKENS = 8192  # 8K tokens for generation
+ DEFAULT_CONTEXT_LENGTH = 16384  # 16K default context
+ CONFIG_FILE = "chatbot_config.json"
+ CACHE_DIR = "model_cache"

  class EnhancedChatbot:
      def __init__(self):
          self.model = None
-         self.config = self.load_config()
+         self.tokenizer = None
          self.model_lock = Lock()
-         self.load_model()
+
+         # Ensure cache directory exists
+         os.makedirs(CACHE_DIR, exist_ok=True)
+
+         # Initialize configuration with higher limits
+         self.config = self.load_config()
+
+         # Initialize model and tokenizer
+         try:
+             self.load_model()
+         except Exception as e:
+             st.error(f"Error loading model: {str(e)}")
+             logging.error(f"Error loading model: {str(e)}")

      def load_config(self):
-         if os.path.exists(CONFIG_FILE):
-             with open(CONFIG_FILE, 'r') as f:
-                 return json.load(f)
-         return {
-             "model_name": MODEL_NAME,
-             "max_tokens": 512,
+         """Load or create configuration file with optimized settings"""
+         default_config = {
+             "model_name": DEFAULT_MODEL_NAME,
+             "max_new_tokens": MAX_NEW_TOKENS,
+             "context_length": DEFAULT_CONTEXT_LENGTH,
              "temperature": 0.7,
              "top_p": 0.95,
-             "system_message": "You are a friendly and helpful AI assistant.",
-             "gpu_layers": 0
+             "top_k": 50,
+             "repetition_penalty": 1.1,
+             "system_message": "You are a helpful AI assistant with high context understanding.",
+             "gpu_layers": "auto"
          }
-
-     def save_config(self):
-         with open(CONFIG_FILE, 'w') as f:
-             json.dump(self.config, f, indent=2)
+
+         try:
+             if os.path.exists(CONFIG_FILE):
+                 with open(CONFIG_FILE, 'r') as f:
+                     config = json.load(f)
+                 # Update with any missing keys from default_config
+                 for key, value in default_config.items():
+                     if key not in config:
+                         config[key] = value
+             else:
+                 config = default_config
+                 self.save_config(config)
+
+             return config
+
+         except Exception as e:
+             logging.error(f"Error loading config: {str(e)}")
+             return default_config

      def load_model(self):
+         """Load the model and tokenizer with optimized settings"""
          try:
+             # Clear CUDA cache if using GPU
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Load tokenizer first
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.config["model_name"],
+                 cache_dir=CACHE_DIR,
+                 model_max_length=self.config["context_length"],
+                 padding_side="left"
+             )
+
+             # Load model with optimized settings
              self.model = AutoModelForCausalLM.from_pretrained(
                  self.config["model_name"],
-                 model_type="llama",
-                 gpu_layers=self.config["gpu_layers"],
-                 cache_dir=CACHE_DIR
+                 torch_dtype=torch.bfloat16,  # Use bfloat16 for better performance
+                 low_cpu_mem_usage=True,
+                 cache_dir=CACHE_DIR,
+                 device_map="auto",
+                 max_memory={0: "24GiB"},  # Adjust based on your GPU
+                 trust_remote_code=True
              )
-             logging.info(f"Model loaded successfully: {self.config['model_name']}")
+
+             logging.info(f"Model {self.config['model_name']} loaded successfully")
+
          except Exception as e:
              logging.error(f"Error loading model: {str(e)}")
              raise

-     def generate_response(self, message, history, system_message, max_tokens, temperature, top_p):
-         prompt = f"{system_message}\n\n"
-         for user_msg, assistant_msg in history:
-             prompt += f"Human: {user_msg}\nAssistant: {assistant_msg}\n"
-         prompt += f"Human: {message}\nAssistant: "
-
-         start_time = time.time()
-         with self.model_lock:
-             generated_text = self.model(
-                 prompt,
-                 max_new_tokens=max_tokens,
-                 temperature=temperature,
-                 top_p=top_p,
-             )
-         end_time = time.time()
-
-         response_time = end_time - start_time
-         logging.info(f"Response generated in {response_time:.2f} seconds")
-
-         return generated_text.strip()
+     def generate_response(self, message, history):
+         """Generate response with high token limit"""
+         try:
+             with self.model_lock:
+                 # Prepare conversation history
+                 full_prompt = self.prepare_prompt(message, history)
+
+                 # Tokenize with proper handling of long sequences
+                 inputs = self.tokenizer(full_prompt,
+                                         return_tensors="pt",
+                                         truncation=True,
+                                         max_length=MAX_INPUT_TOKENS)
+
+                 # Move to GPU if available
+                 inputs = inputs.to(self.model.device)
+
+                 # Generate with optimized parameters
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=self.config["max_new_tokens"],
+                     temperature=self.config["temperature"],
+                     top_p=self.config["top_p"],
+                     top_k=self.config["top_k"],
+                     repetition_penalty=self.config["repetition_penalty"],
+                     do_sample=True,
+                     pad_token_id=self.tokenizer.eos_token_id
+                 )
+
+                 # Decode response
+                 response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                 return response.strip()

- chatbot = EnhancedChatbot()
+         except Exception as e:
+             logging.error(f"Error generating response: {str(e)}")
+             return "I apologize, but I encountered an error. Please try again."

- def respond(message, history, system_message, max_tokens, temperature, top_p):
-     try:
-         response = chatbot.generate_response(message, history, system_message, max_tokens, temperature, top_p)
-         yield response
-     except Exception as e:
-         logging.error(f"Error generating response: {str(e)}")
-         yield "I apologize, but I encountered an error while processing your request. Please try again."
-
- def update_model_config(model_name, gpu_layers):
-     chatbot.config["model_name"] = model_name
-     chatbot.config["gpu_layers"] = gpu_layers
-     chatbot.save_config()
-     chatbot.load_model()
-     return f"Model updated to {model_name} with {gpu_layers} GPU layers."
-
- def update_system_message(system_message):
-     chatbot.config["system_message"] = system_message
-     chatbot.save_config()
-     return f"System message updated: {system_message}"
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Enhanced AI Chatbot")
-
-     with gr.Tab("Chat"):
-         chatbot_interface = gr.ChatInterface(
-             respond,
-             additional_inputs=[
-                 gr.Textbox(value=chatbot.config["system_message"], label="System message"),
-                 gr.Slider(minimum=1, maximum=2048, value=chatbot.config["max_tokens"], step=1, label="Max new tokens"),
-                 gr.Slider(minimum=0.1, maximum=4.0, value=chatbot.config["temperature"], step=0.1, label="Temperature"),
-                 gr.Slider(
-                     minimum=0.1,
-                     maximum=1.0,
-                     value=chatbot.config["top_p"],
-                     step=0.05,
-                     label="Top-p (nucleus sampling)",
-                 ),
-             ],
-         )
+     def prepare_prompt(self, message, history):
+         """Prepare prompt with history management"""
+         system_msg = self.config["system_message"]
+         prompt = f"{system_msg}\n\n"
+
+         # Add history with token counting
+         total_tokens = 0
+         for msg in history:
+             tokens = len(self.tokenizer.encode(msg["content"]))
+             if total_tokens + tokens < MAX_INPUT_TOKENS:
+                 prompt += f"{msg['role']}: {msg['content']}\n"
+                 total_tokens += tokens
+             else:
+                 break
+
+         prompt += f"user: {message}\nassistant:"
+         return prompt
+
+ # Streamlit UI with advanced settings
+ def main():
+     st.title("Enhanced AI Chatbot (High Context)")

-     with gr.Tab("Settings"):
-         with gr.Group():
-             gr.Markdown("### Model Settings")
-             model_name_input = gr.Textbox(value=chatbot.config["model_name"], label="Model name")
-             gpu_layers_input = gr.Slider(minimum=0, maximum=8, value=chatbot.config["gpu_layers"], step=1, label="GPU layers")
-             update_model_button = gr.Button("Update model")
-             update_model_button.click(update_model_config, inputs=[model_name_input, gpu_layers_input], outputs="text")
+     try:
+         chatbot = EnhancedChatbot()

-         with gr.Group():
-             gr.Markdown("### System Message Settings")
-             system_message_input = gr.Textbox(value=chatbot.config["system_message"], label="System message")
-             update_system_message_button = gr.Button("Update system message")
-             update_system_message_button.click(update_system_message, inputs=[system_message_input], outputs="text")
+         # Advanced settings in sidebar
+         with st.sidebar:
+             st.subheader("Model Settings")
+
+             # Context length slider
+             new_context = st.slider(
+                 "Context Length (tokens)",
+                 min_value=1024,
+                 max_value=32768,
+                 value=chatbot.config["context_length"],
+                 step=1024
+             )
+
+             # Generation settings
+             new_max_tokens = st.slider(
+                 "Max New Tokens",
+                 min_value=1024,
+                 max_value=MAX_NEW_TOKENS,
+                 value=chatbot.config["max_new_tokens"],
+                 step=1024
+             )
+
+             temperature = st.slider(
+                 "Temperature",
+                 min_value=0.1,
+                 max_value=2.0,
+                 value=chatbot.config["temperature"]
+             )
+
+             # Update settings button
+             if st.button("Update Settings"):
+                 chatbot.config.update({
+                     "context_length": new_context,
+                     "max_new_tokens": new_max_tokens,
+                     "temperature": temperature
+                 })
+                 chatbot.save_config(chatbot.config)
+                 st.experimental_rerun()
+
+         # Chat interface
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+         # Display chat messages
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+         # Chat input
+         if prompt := st.chat_input("What would you like to know?"):
+             st.session_state.messages.append({"role": "user", "content": prompt})
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+
+             with st.chat_message("assistant"):
+                 with st.spinner("Generating response..."):
+                     response = chatbot.generate_response(prompt, st.session_state.messages)
+                 st.markdown(response)
+             st.session_state.messages.append({"role": "assistant", "content": response})
+
+     except Exception as e:
+         st.error(f"Application Error: {str(e)}")
+         logging.error(f"Application Error: {str(e)}")

  if __name__ == "__main__":
-     demo.launch()
+     main()
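
Note on the new code: load_config() calls self.save_config(config) and main() calls chatbot.save_config(chatbot.config), but the hunk above deletes the old save_config() without adding a replacement, so the committed file will raise AttributeError on those paths. A minimal sketch of the method the call sites imply, assuming the same JSON persistence as the removed version (this method is not part of the commit):

    def save_config(self, config):
        """Persist a configuration dict to CONFIG_FILE as JSON."""
        # Hypothetical sketch, not in the commit above; relies on the json
        # and logging modules already imported at the top of app.py.
        try:
            with open(CONFIG_FILE, 'w') as f:
                json.dump(config, f, indent=2)
        except Exception as e:
            logging.error(f"Error saving config: {str(e)}")

One further caveat: st.experimental_rerun() has since been deprecated in Streamlit in favor of st.rerun(), so the Update Settings button may need that rename on current releases.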