ykarout committed on
Commit 87cfb14 · verified · 1 Parent(s): bd4adf7

Initial commit for initial version


Simple Chat-UI with a Transformers library back-end for inference
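For reference, a minimal sketch of running the app outside a Space is shown below; the dependency list and the explicit launch call are assumptions and not part of this commit (on a Hugging Face Space the Gradio SDK imports app.py and serves the `demo` Blocks object automatically):

    # run_local.py: hypothetical entry point for a local run
    # assumed dependencies: gradio, torch, transformers, accelerate, bitsandbytes, huggingface_hub
    from app import demo  # importing app.py builds the Blocks UI and enables its queue

    if __name__ == "__main__":
        demo.launch()  # app.py itself does not call launch(); added here only for local use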

Files changed (1)
  1. app.py +1515 -0
app.py ADDED
@@ -0,0 +1,1515 @@
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig
4
+ from threading import Thread
5
+ import time
6
+ import logging
7
+ import gc
8
+ from pathlib import Path
9
+ import re
10
+ from huggingface_hub import HfApi, list_models
11
+ import os
12
+ import queue
13
+ import threading
14
+ from collections import deque
15
+
16
+ # Set PyTorch memory management environment variables
17
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(levelname)s - %(message)s',
23
+ handlers=[
24
+ logging.FileHandler('gradio-chat-ui.log'),
25
+ logging.StreamHandler()
26
+ ]
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Log memory management settings
31
+ logger.info(f"PyTorch CUDA allocation config: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
32
+ logger.info(f"CUDA device count: {torch.cuda.device_count() if torch.cuda.is_available() else 'N/A'}")
33
+
34
+ # Model parameters
35
+ MODEL_NAME = "No Model Loaded"
36
+ MAX_LENGTH = 16384
37
+ DEFAULT_TEMPERATURE = 0.15
38
+ DEFAULT_TOP_P = 0.93
39
+ DEFAULT_TOP_K = 50
40
+ DEFAULT_REP_PENALTY = 1.15
41
+
42
+ # Base location for local models
43
+ LOCAL_MODELS_BASE = "/home/llm-models/"
44
+
45
+ # Global variables
46
+ model = None
47
+ tokenizer = None
48
+ hf_api = HfApi()
49
+
50
+ # Generation metadata storage with automatic cleanup
51
+ generation_metadata = deque(maxlen=100) # Fixed size deque to prevent unlimited growth
52
+
53
+ class RAMSavingIteratorStreamer:
54
+ """
55
+ Custom streamer that saves VRAM by moving tokens to CPU and provides iteration interface for Gradio.
56
+ Combines the benefits of TextStreamer (RAM saving) with TextIteratorStreamer (iteration).
57
+ """
58
+ def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=None):
59
+ self.tokenizer = tokenizer
60
+ self.skip_special_tokens = skip_special_tokens
61
+ self.skip_prompt = skip_prompt
62
+ self.timeout = timeout
63
+
64
+ # Token and text storage (CPU-based)
65
+ self.generated_tokens = []
66
+ self.generated_text = ""
67
+ self.token_cache = ""
68
+
69
+ # Queue for streaming interface
70
+ self.text_queue = queue.Queue()
71
+ self.stop_signal = threading.Event()
72
+
73
+ # Track prompt tokens to skip them
74
+ self.prompt_length = 0
75
+ self.tokens_processed = 0
76
+
77
+ # Decoding state
78
+ self.print_len = 0
79
+
80
+ def put(self, value):
81
+ """
82
+ Receive new token(s) and process them for streaming.
83
+ This method is called by the model during generation.
84
+ """
85
+ try:
86
+ # Handle different input types
87
+ if isinstance(value, torch.Tensor):
88
+ if value.dim() > 1:
89
+ value = value[0] # Remove batch dimension if present
90
+ token_ids = value.tolist()
91
+
92
+ # Store CPU version to save VRAM
93
+ self.generated_tokens.append(value.detach().cpu())
94
+ else:
95
+ token_ids = value if isinstance(value, list) else [value]
96
+ self.generated_tokens.append(torch.tensor(token_ids, dtype=torch.long))
97
+
98
+ # Track tokens processed
99
+ if isinstance(token_ids, list):
100
+ self.tokens_processed += len(token_ids)
101
+ else:
102
+ self.tokens_processed += 1
103
+
104
+ # Skip prompt tokens if requested (and drop them from storage so they are not re-decoded into the stream)
105
+ if self.skip_prompt and self.tokens_processed <= self.prompt_length:
106
+ self.generated_tokens.pop() # discard the prompt chunk appended above
+ return
107
+
108
+ # Decode incrementally for real-time streaming
109
+ try:
110
+ # Get all generated tokens so far
111
+ if self.generated_tokens:
112
+ all_tokens = []
113
+ for tokens in self.generated_tokens:
114
+ if isinstance(tokens, torch.Tensor):
115
+ if tokens.dim() == 0:
116
+ all_tokens.append(tokens.item())
117
+ else:
118
+ all_tokens.extend(tokens.tolist())
119
+ elif isinstance(tokens, list):
120
+ all_tokens.extend(tokens)
121
+ else:
122
+ all_tokens.append(tokens)
123
+
124
+ # Decode the full sequence
125
+ full_text = self.tokenizer.decode(
126
+ all_tokens,
127
+ skip_special_tokens=self.skip_special_tokens
128
+ )
129
+
130
+ # Get new text since last update
131
+ if len(full_text) > self.print_len:
132
+ new_text = full_text[self.print_len:]
133
+ self.print_len = len(full_text)
134
+ self.generated_text = full_text
135
+
136
+ # Put new text in queue for iteration
137
+ if new_text:
138
+ self.text_queue.put(new_text)
139
+
140
+ except Exception as decode_error:
141
+ logger.warning(f"Decoding error in streamer: {decode_error}")
142
+
143
+ except Exception as e:
144
+ logger.error(f"Error in RAMSavingIteratorStreamer.put: {e}")
145
+
146
+ def end(self):
147
+ """Signal end of generation."""
148
+ self.text_queue.put(None) # Sentinel value
149
+
150
+ def __iter__(self):
151
+ """Make this streamer iterable for Gradio compatibility."""
152
+ return self
153
+
154
+ def __next__(self):
155
+ """Get next chunk of text for streaming."""
156
+ try:
157
+ value = self.text_queue.get(timeout=self.timeout)
158
+ if value is None: # End signal
159
+ raise StopIteration
160
+ return value
161
+ except queue.Empty:
162
+ raise StopIteration
163
+
164
+ def set_prompt_length(self, prompt_length):
165
+ """Set the length of prompt tokens to skip."""
166
+ self.prompt_length = prompt_length
167
+
168
+ def get_generated_text(self):
169
+ """Get the complete generated text."""
170
+ return self.generated_text
171
+
172
+ def get_generated_tokens(self):
173
+ """Get all generated tokens as a single tensor."""
174
+ if not self.generated_tokens:
175
+ return torch.tensor([])
176
+
177
+ # Combine all tokens
178
+ all_tokens = []
179
+ for tokens in self.generated_tokens:
180
+ if isinstance(tokens, torch.Tensor):
181
+ if tokens.dim() == 0:
182
+ all_tokens.append(tokens.item())
183
+ else:
184
+ all_tokens.extend(tokens.tolist())
185
+ elif isinstance(tokens, list):
186
+ all_tokens.extend(tokens)
187
+ else:
188
+ all_tokens.append(tokens)
189
+
190
+ return torch.tensor(all_tokens, dtype=torch.long)
191
+
192
+ def cleanup(self):
193
+ """Clean up resources."""
194
+ self.generated_tokens.clear()
195
+ self.generated_text = ""
196
+ self.token_cache = ""
197
+
198
+ # Clear queue
199
+ while not self.text_queue.empty():
200
+ try:
201
+ self.text_queue.get_nowait()
202
+ except queue.Empty:
203
+ break
204
+
205
+ self.stop_signal.set()
206
+
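+ # Illustrative usage sketch for this streamer (mirrors what chat_with_model does further below; comments only):
+ #   streamer = RAMSavingIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=1.0)
+ #   streamer.set_prompt_length(input_ids.shape[1])
+ #   Thread(target=model.generate, kwargs={"input_ids": input_ids, "streamer": streamer}, daemon=True).start()
+ #   for chunk in streamer:  # decoded text chunks arrive on the main thread as they are generated
+ #       ...                 # append chunk to the partial response
+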
207
+ def scan_local_models(base_path=LOCAL_MODELS_BASE):
208
+ """Scan for valid models in the local models directory"""
209
+ try:
210
+ base_path = Path(base_path)
211
+ if not base_path.exists():
212
+ logger.warning(f"Base path does not exist: {base_path}")
213
+ return []
214
+
215
+ valid_models = []
216
+
217
+ # Scan subdirectories (depth 1 only)
218
+ for item in base_path.iterdir():
219
+ if item.is_dir():
220
+ # Check if directory contains required model files
221
+ config_file = item / "config.json"
222
+
223
+ # Look for model weight files (safetensors or bin)
224
+ safetensors_files = list(item.glob("*.safetensors"))
225
+ bin_files = list(item.glob("*.bin"))
226
+
227
+ # Check if it's a valid model directory
228
+ if config_file.exists() and (safetensors_files or bin_files):
229
+ valid_models.append(str(item))
230
+ logger.info(f"Found valid model: {item}")
231
+
232
+ # Sort models for consistent ordering
233
+ valid_models.sort()
234
+ logger.info(f"Found {len(valid_models)} valid models in {base_path}")
235
+
236
+ return valid_models
237
+
238
+ except Exception as e:
239
+ logger.error(f"Error scanning local models: {e}")
240
+ return []
241
+
242
+ def update_local_models_dropdown(base_path):
243
+ """Update the local models dropdown based on base path"""
244
+ if not base_path or not base_path.strip():
245
+ return gr.Dropdown(choices=[], value=None, interactive=True)
246
+
247
+ models = scan_local_models(base_path)
248
+ model_choices = [Path(model).name for model in models] # Show just the model name
249
+ model_paths = models # Keep full paths for internal use
250
+
251
+ # Create a mapping for display name to full path
252
+ if model_choices:
253
+ return gr.Dropdown(
254
+ choices=list(zip(model_choices, model_paths)),
255
+ value=model_paths[0] if model_paths else None,
256
+ label="๐Ÿ“‹ Available Local Models",
257
+ interactive=True,
258
+ allow_custom_value=False, # Don't allow custom for local models
259
+ filterable=True
260
+ )
261
+ else:
262
+ return gr.Dropdown(
263
+ choices=[],
264
+ value=None,
265
+ label="๐Ÿ“‹ Available Local Models (None found)",
266
+ interactive=True,
267
+ allow_custom_value=False,
268
+ filterable=True
269
+ )
270
+
271
+ def search_hf_models(query, limit=20):
272
+ """Enhanced search for models on Hugging Face Hub with better coverage"""
273
+ if not query or len(query.strip()) < 2:
274
+ return []
275
+
276
+ try:
277
+ query = query.strip()
278
+ model_choices = []
279
+
280
+ # Strategy 1: Direct model ID search (if query looks like a model ID)
281
+ if '/' in query:
282
+ try:
283
+ # Try to get the specific model
284
+ model_info = hf_api.model_info(query)
285
+ if model_info and hasattr(model_info, 'id'):
286
+ model_choices.append(model_info.id)
287
+ logger.info(f"Found direct model: {model_info.id}")
288
+ except Exception as direct_error:
289
+ logger.debug(f"Direct model search failed: {direct_error}")
290
+
291
+ # Strategy 2: Search with different parameters
292
+ search_strategies = [
293
+ # Exact search
294
+ {"search": query, "sort": "downloads", "direction": -1, "limit": limit//2},
295
+ # Author search (if query contains /)
296
+ {"author": query.split('/')[0] if '/' in query else query, "sort": "downloads", "direction": -1, "limit": limit//4} if '/' in query else None,
297
+ # Broader search
298
+ {"search": query, "sort": "trending", "direction": -1, "limit": limit//4},
299
+ ]
300
+
301
+ for strategy in search_strategies:
302
+ if strategy is None:
303
+ continue
304
+
305
+ try:
306
+ models = list_models(
307
+ task="text-generation",
308
+ **strategy
309
+ )
310
+
311
+ for model in models:
312
+ if model.id not in model_choices:
313
+ model_choices.append(model.id)
314
+
315
+ except Exception as strategy_error:
316
+ logger.debug(f"Search strategy failed: {strategy_error}")
317
+
318
+ # Remove duplicates while preserving order
319
+ seen = set()
320
+ unique_choices = []
321
+ for choice in model_choices:
322
+ if choice not in seen:
323
+ seen.add(choice)
324
+ unique_choices.append(choice)
325
+
326
+ # Limit results
327
+ final_choices = unique_choices[:limit]
328
+ logger.info(f"HF search for '{query}' returned {len(final_choices)} models")
329
+
330
+ return final_choices
331
+
332
+ except Exception as e:
333
+ logger.error(f"Error searching models: {str(e)}")
334
+ return []
335
+
336
+ def update_model_dropdown(query):
337
+ """Update dropdown with enhanced search results"""
338
+ if not query or len(query.strip()) < 2:
339
+ return gr.Dropdown(choices=[], value=None, interactive=True)
340
+
341
+ choices = search_hf_models(query, limit=20)
342
+ return gr.Dropdown(
343
+ choices=choices,
344
+ value=choices[0] if choices else None,
345
+ interactive=True,
346
+ allow_custom_value=True, # Allow manual typing
347
+ filterable=True
348
+ )
349
+
350
+ def load_model_with_progress(model_source, hf_model, local_path, local_model_selection, quantization, memory_optimization):
351
+ """Load model with progress tracking and memory optimization"""
352
+ global model, tokenizer, MODEL_NAME
353
+
354
+ # Determine model path based on source
355
+ if model_source == "Hugging Face Model":
356
+ if not hf_model:
357
+ return "โŒ Error: Please select a model from the dropdown"
358
+ model_path = hf_model
359
+ else:
360
+ # Use selected local model if available, otherwise use manual path
361
+ if local_model_selection:
362
+ model_path = local_model_selection
363
+ else:
364
+ model_path = local_path
365
+ if not Path(model_path).exists():
366
+ logger.error(f"Local path does not exist: {model_path}")
367
+ return f"โŒ Error: Local path does not exist: {model_path}"
368
+
369
+ MODEL_NAME = model_path.split("/")[-1] if "/" in model_path else model_path
370
+ logger.info(f"Loading model from {model_path} with memory optimization: {memory_optimization}")
371
+
372
+ try:
373
+ # Yield progress updates
374
+ yield "๐Ÿ”„ Initializing model loading..."
375
+
376
+ # Setup memory configuration (GPU-only, generous allocation)
377
+ if torch.cuda.is_available():
378
+ device_properties = torch.cuda.get_device_properties(0)
379
+ total_memory_gb = device_properties.total_memory / (1024**3)
380
+
381
+ # Set max GPU memory for the model to 11.5GB (GPU-bound)
382
+ max_memory_val = 11.5 # Fixed 11.5GB allocation
383
+ max_memory = f"{max_memory_val}GB"
384
+ logger.info(f"Setting max GPU memory to {max_memory} (Total available: {total_memory_gb:.2f}GB)")
385
+ else:
386
+ max_memory = "11GB"
387
+ logger.info("CUDA not available. Using CPU fallback.")
388
+
389
+ yield "๐Ÿ”„ Configuring quantization settings..."
390
+
391
+ # Configure quantization (removed CPU offloading)
392
+ bnb_config = BitsAndBytesConfig(
393
+ load_in_4bit=quantization == "4bit",
394
+ load_in_8bit=quantization == "8bit",
395
+ bnb_4bit_use_double_quant=True,
396
+ bnb_4bit_compute_dtype=torch.bfloat16,
397
+ bnb_4bit_quant_type="nf4",
398
+ )
399
+
400
+ yield "๐Ÿ”„ Loading tokenizer..."
401
+
402
+ # Load tokenizer
403
+ if model_source == "Local Path":
404
+ tokenizer = AutoTokenizer.from_pretrained(
405
+ model_path,
406
+ trust_remote_code=True,
407
+ local_files_only=True
408
+ )
409
+ else:
410
+ tokenizer = AutoTokenizer.from_pretrained(
411
+ model_path,
412
+ trust_remote_code=True
413
+ )
414
+
415
+ yield "๐Ÿ”„ Cleaning memory cache..."
416
+
417
+ # Clean memory
418
+ gc.collect()
419
+ if torch.cuda.is_available():
420
+ torch.cuda.empty_cache()
421
+
422
+ # Determine torch dtype
423
+ if quantization in ["4bit", "8bit"]:
424
+ torch_dtype = torch.bfloat16
425
+ elif quantization == "f16":
426
+ torch_dtype = torch.float16
427
+ else: # bf16
428
+ torch_dtype = torch.bfloat16
429
+
430
+ yield "๐Ÿ”„ Loading model weights (this may take a while)..."
431
+
432
+ # Simple GPU-only model loading parameters
433
+ model_kwargs = {
434
+ "device_map": "auto",
435
+ "max_memory": {0: max_memory} if torch.cuda.is_available() else None,
436
+ "torch_dtype": torch_dtype,
437
+ "quantization_config": bnb_config if quantization in ["4bit", "8bit"] else None,
438
+ "trust_remote_code": True,
439
+ }
440
+
441
+ # Memory optimization specific settings (GPU-only)
442
+ if memory_optimization:
443
+ model_kwargs.update({
444
+ "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa",
445
+ "use_cache": False, # Disable cache by default for memory optimization
446
+ })
447
+ else:
448
+ model_kwargs.update({
449
+ "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa",
450
+ #"use_cache": True, # Enable cache for performance
451
+ })
452
+
453
+ # Add local files only for local models
454
+ if model_source == "Local Path":
455
+ model_kwargs["local_files_only"] = True
456
+
457
+ # Load model
458
+ model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
459
+
460
+ # Post-loading memory optimization
461
+ if memory_optimization:
462
+ yield "๐Ÿ”„ Applying memory optimizations..."
463
+
464
+ # Additional memory cleanup after loading
465
+ gc.collect()
466
+ if torch.cuda.is_available():
467
+ torch.cuda.empty_cache()
468
+ torch.cuda.synchronize()
469
+
470
+ logger.info("Model loaded successfully with memory optimization")
471
+ yield "โœ… Model loaded successfully with memory optimization!" if memory_optimization else "โœ… Model loaded successfully!"
472
+
473
+ except Exception as e:
474
+ logger.error(f"Error loading model: {str(e)}", exc_info=True)
475
+ yield f"โŒ Error loading model: {str(e)}"
476
+
477
+ def unload_model():
478
+ """Unload the model and free memory with aggressive cleanup"""
479
+ global model, tokenizer, MODEL_NAME
480
+
481
+ if model is None:
482
+ return "No model loaded"
483
+
484
+ try:
485
+ logger.info("Unloading model with aggressive memory cleanup...")
486
+
487
+ # Step 1: Move model to CPU first (if it was on GPU)
488
+ if torch.cuda.is_available() and hasattr(model, 'device'):
489
+ try:
490
+ model.cpu()
491
+ logger.info("Model moved to CPU")
492
+ except Exception as cpu_error:
493
+ logger.warning(f"Could not move model to CPU: {cpu_error}")
494
+
495
+ # Step 2: Clear model cache if available
496
+ if hasattr(model, 'clear_cache'):
497
+ model.clear_cache()
498
+
499
+ # Step 3: Delete model and tokenizer references
500
+ del model
501
+ del tokenizer
502
+ model = None
503
+ tokenizer = None
504
+
505
+ # Step 4: Reset model name
506
+ MODEL_NAME = "No Model Loaded"
507
+
508
+ # Step 5: Clear metadata deque
509
+ generation_metadata.clear()
510
+
511
+ # Step 6: Aggressive garbage collection (multiple rounds)
512
+ for i in range(5): # More aggressive - 5 rounds
513
+ gc.collect()
514
+ time.sleep(0.1) # Small delay between rounds
515
+
516
+ # Step 7: Aggressive CUDA cleanup
517
+ if torch.cuda.is_available():
518
+ logger.info("Performing aggressive CUDA cleanup...")
519
+
520
+ # Multiple rounds of cache clearing
521
+ for i in range(5):
522
+ torch.cuda.empty_cache()
523
+ torch.cuda.synchronize()
524
+
525
+ # Additional PyTorch CUDA cleanup
526
+ if hasattr(torch.cuda, 'ipc_collect'):
527
+ torch.cuda.ipc_collect()
528
+
529
+ # Reset memory stats
530
+ if hasattr(torch.cuda, 'reset_peak_memory_stats'):
531
+ torch.cuda.reset_peak_memory_stats()
532
+ if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
533
+ torch.cuda.reset_accumulated_memory_stats()
534
+
535
+ time.sleep(0.1)
536
+
537
+ # Step 8: Force PyTorch to release all unused memory
538
+ if torch.cuda.is_available():
539
+ try:
540
+ # Try to trigger the memory pool cleanup
541
+ torch.cuda.empty_cache()
542
+
543
+ # Force a small allocation and deallocation to trigger cleanup
544
+ dummy_tensor = torch.zeros(1, device='cuda')
545
+ del dummy_tensor
546
+ torch.cuda.empty_cache()
547
+
548
+ logger.info("Forced memory pool cleanup")
549
+ except Exception as cleanup_error:
550
+ logger.warning(f"Advanced cleanup failed: {cleanup_error}")
551
+
552
+ # Step 9: Final garbage collection
553
+ gc.collect()
554
+
555
+ logger.info("Model unloaded successfully with aggressive cleanup")
556
+ return "โœ… Model unloaded with aggressive memory cleanup"
557
+
558
+ except Exception as e:
559
+ logger.error(f"Error unloading model: {str(e)}", exc_info=True)
560
+ # Emergency cleanup even if unload fails
561
+ model = None
562
+ tokenizer = None
563
+ MODEL_NAME = "No Model Loaded"
564
+ generation_metadata.clear()
565
+
566
+ # Emergency memory cleanup
567
+ for _ in range(3):
568
+ gc.collect()
569
+ if torch.cuda.is_available():
570
+ torch.cuda.empty_cache()
571
+
572
+ return f"โŒ Error unloading model: {str(e)} (Emergency cleanup performed)"
573
+
574
+ def cleanup_memory():
575
+ """Enhanced memory cleanup function with PyTorch optimizations"""
576
+ try:
577
+ # Clear Python garbage
578
+ gc.collect()
579
+
580
+ # Clear CUDA cache if available
581
+ if torch.cuda.is_available():
582
+ # Multiple aggressive cleanup rounds
583
+ for i in range(3):
584
+ torch.cuda.empty_cache()
585
+ torch.cuda.synchronize()
586
+ if hasattr(torch.cuda, 'ipc_collect'):
587
+ torch.cuda.ipc_collect()
588
+
589
+ # PyTorch specific memory management
590
+ if hasattr(torch.cuda, 'reset_peak_memory_stats'):
591
+ torch.cuda.reset_peak_memory_stats()
592
+ if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
593
+ torch.cuda.reset_accumulated_memory_stats()
594
+
595
+ # Brief pause between cleanup rounds
596
+ time.sleep(0.1)
597
+
598
+ # Clear metadata deque
599
+ generation_metadata.clear()
600
+
601
+ # Force garbage collection again
602
+ gc.collect()
603
+
604
+ logger.info("Enhanced memory cleanup completed")
605
+ return "๐Ÿงน Enhanced memory cleanup completed"
606
+ except Exception as e:
607
+ logger.error(f"Memory cleanup error: {e}")
608
+ return f"Memory cleanup error: {e}"
609
+
610
+ def nuclear_memory_cleanup():
611
+ """Nuclear option: Complete VRAM reset (use if normal unload doesn't work)"""
612
+ global model, tokenizer, MODEL_NAME
613
+
614
+ try:
615
+ logger.info("Performing nuclear memory cleanup...")
616
+
617
+ # Force unload everything
618
+ model = None
619
+ tokenizer = None
620
+ MODEL_NAME = "No Model Loaded"
621
+ generation_metadata.clear()
622
+
623
+ # Import PyTorch again to reset some internal states
624
+ import torch
625
+
626
+ # Multiple aggressive cleanup rounds
627
+ for round_num in range(10): # Very aggressive - 10 rounds
628
+ gc.collect()
629
+
630
+ if torch.cuda.is_available():
631
+ # Multiple types of CUDA cleanup
632
+ torch.cuda.empty_cache()
633
+ torch.cuda.synchronize()
634
+
635
+ # Try to reset CUDA context
636
+ try:
637
+ if hasattr(torch.cuda, 'ipc_collect'):
638
+ torch.cuda.ipc_collect()
639
+ if hasattr(torch.cuda, 'memory_summary'):
640
+ logger.info(f"Round {round_num + 1}: {torch.cuda.memory_summary()}")
641
+ except Exception:
642
+ pass
643
+
644
+ # Reset memory stats
645
+ try:
646
+ if hasattr(torch.cuda, 'reset_peak_memory_stats'):
647
+ torch.cuda.reset_peak_memory_stats()
648
+ if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
649
+ torch.cuda.reset_accumulated_memory_stats()
650
+ except Exception:
651
+ pass
652
+
653
+ time.sleep(0.1)
654
+
655
+ # Final attempt: allocate and free a small tensor to trigger cleanup
656
+ if torch.cuda.is_available():
657
+ try:
658
+ for _ in range(5):
659
+ dummy = torch.zeros(1024, 1024, device='cuda') # 4MB tensor
660
+ del dummy
661
+ torch.cuda.empty_cache()
662
+ torch.cuda.synchronize()
663
+ except Exception as nuclear_error:
664
+ logger.warning(f"Nuclear tensor cleanup failed: {nuclear_error}")
665
+
666
+ logger.info("Nuclear memory cleanup completed")
667
+ return "โ˜ข๏ธ Nuclear memory cleanup completed! VRAM should be minimal now."
668
+
669
+ except Exception as e:
670
+ logger.error(f"Nuclear cleanup error: {e}")
671
+ return f"โ˜ข๏ธ Nuclear cleanup error: {e}"
672
+
673
+ def get_memory_stats():
674
+ """Get comprehensive VRAM usage information"""
675
+ if not torch.cuda.is_available():
676
+ return """
677
+ <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">
678
+ <h3 style="margin: 0; font-size: 16px;">💻 CPU Mode</h3>
679
+ <p style="margin: 5px 0; opacity: 0.9;">GPU not available</p>
680
+ </div>
681
+ """
682
+
683
+ try:
684
+ torch.cuda.synchronize()
685
+ total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
686
+ allocated = torch.cuda.memory_allocated(0) / (1024**3)
687
+ reserved = torch.cuda.memory_reserved(0) / (1024**3)
688
+ free = total - reserved
689
+ usage_percent = (reserved/total)*100
690
+
691
+ # Get peak memory if available
692
+ peak_allocated = 0
693
+ if hasattr(torch.cuda, 'max_memory_allocated'):
694
+ peak_allocated = torch.cuda.max_memory_allocated(0) / (1024**3)
695
+
696
+ # Dynamic color based on usage
697
+ if usage_percent < 50:
698
+ color = "#10b981" # Green
699
+ elif usage_percent < 80:
700
+ color = "#f59e0b" # Orange
701
+ else:
702
+ color = "#ef4444" # Red
703
+
704
+ return f"""
705
+ <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, {color}22 0%, {color}44 100%); border: 2px solid {color}; border-radius: 10px;">
706
+ <h3 style="margin: 0; font-size: 16px; color: {color};">🎮 VRAM Usage</h3>
707
+ <div style="margin: 10px 0;">
708
+ <div style="background: #f3f4f6; border-radius: 8px; height: 8px; overflow: hidden;">
709
+ <div style="width: {usage_percent}%; height: 100%; background: {color}; transition: width 0.3s ease;"></div>
710
+ </div>
711
+ </div>
712
+ <p style="margin: 5px 0; font-weight: 600;">Total: {total:.2f} GB</p>
713
+ <p style="margin: 5px 0;">Allocated: {allocated:.2f} GB ({usage_percent:.1f}%)</p>
714
+ <p style="margin: 5px 0;">Reserved: {reserved:.2f} GB</p>
715
+ <p style="margin: 5px 0;">Free: {free:.2f} GB</p>
716
+ <p style="margin: 5px 0; font-size: 12px; opacity: 0.8;">Peak: {peak_allocated:.2f} GB</p>
717
+ <p style="margin: 5px 0; font-size: 10px; opacity: 0.6;">RAM-Saving Streamer Active</p>
718
+ </div>
719
+ """
720
+ except Exception as e:
721
+ logger.error(f"Error getting memory stats: {str(e)}")
722
+ return f"""
723
+ <div style="text-align: center; padding: 15px; background: #fee2e2; border: 2px solid #ef4444; border-radius: 10px;">
724
+ <h3 style="margin: 0; color: #ef4444;">โŒ Error</h3>
725
+ <p style="margin: 5px 0;">{str(e)}</p>
726
+ </div>
727
+ """
728
+
729
+ def process_latex_content(text):
730
+ """Enhanced LaTeX processing for streaming without UI glitches"""
731
+ # Don't process LaTeX here - let Gradio handle it natively
732
+ # Just return the text as-is for now
733
+ return text
734
+
735
+ def process_think_tags(text):
736
+ """Process thinking tags with progressive streaming support"""
737
+ # Check if we're in the middle of generating a think section
738
+ if '<think>' in text and '</think>' not in text:
739
+ # We're currently generating inside a think section
740
+ parts = text.split('<think>')
741
+ if len(parts) == 2:
742
+ before_think = parts[0]
743
+ thinking_content = parts[1]
744
+
745
+ # Create a progressive thinking display
746
+ formatted_thinking = f"""
747
+ <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;">
748
+ <div style="display: flex; align-items: center; margin-bottom: 8px;">
749
+ <span style="font-size: 16px; margin-right: 8px;">🤔</span>
750
+ <strong style="color: #4338ca;">Thinking...</strong>
751
+ </div>
752
+ <div style="color: #475569; font-style: italic;">{thinking_content}</div>
753
+ </div>
754
+
755
+ """
756
+ return before_think + formatted_thinking
757
+
758
+ # Handle completed think sections
759
+ think_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL)
760
+
761
+ def replace_think(match):
762
+ think_content = match.group(1).strip()
763
+ return f"""
764
+ <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;">
765
+ <div style="display: flex; align-items: center; margin-bottom: 8px;">
766
+ <span style="font-size: 16px; margin-right: 8px;">🤔</span>
767
+ <strong style="color: #4338ca;">Thinking...</strong>
768
+ </div>
769
+ <div style="color: #475569; font-style: italic;">{think_content}</div>
770
+ </div>
771
+
772
+ """
773
+
774
+ # Replace completed <think> tags with formatted version
775
+ processed_text = think_pattern.sub(replace_think, text)
776
+
777
+ return processed_text
778
+
779
+ def calculate_generation_metrics(start_time, total_tokens):
780
+ """Calculate generation metrics"""
781
+ end_time = time.time()
782
+ generation_time = end_time - start_time
783
+ tokens_per_second = total_tokens / generation_time if generation_time > 0 else 0
784
+
785
+ return {
786
+ "generation_time": generation_time,
787
+ "total_tokens": total_tokens,
788
+ "tokens_per_second": tokens_per_second,
789
+ "model_name": MODEL_NAME
790
+ }
791
+
792
+ def format_metadata_tooltip(metadata):
793
+ """Format metadata for tooltip display"""
794
+ return f"""Model: {metadata['model_name']}
795
+ Tokens: {metadata['total_tokens']}
796
+ Speed: {metadata['tokens_per_second']:.2f} tok/s
797
+ Time: {metadata['generation_time']:.2f}s"""
798
+
799
+ def add_metadata_to_response(response_text, metadata):
800
+ """Add metadata icon with tooltip to the response"""
801
+ tooltip_content = format_metadata_tooltip(metadata)
802
+
803
+ # Create a metadata icon with tooltip using HTML
804
+ metadata_html = f"""
805
+ <div style="position: relative; display: inline-block; margin-left: 8px;">
806
+ <span class="metadata-icon" style="cursor: help; opacity: 0.6; font-size: 14px;" title="{tooltip_content}">ℹ️</span>
807
+ </div>
808
+ """
809
+
810
+ # Add metadata icon at the end of the response
811
+ return response_text + "\n\n" + metadata_html
812
+
813
+ def chat_with_model(message, history, system_prompt, temp, top_p_val, top_k_val, rep_penalty_val, memory_opt):
814
+ """
815
+ Enhanced chat function with RAM-saving streamer and improved memory management.
816
+ Uses direct generation approach for better memory control and VRAM efficiency.
817
+ """
818
+ global model, tokenizer, generation_metadata
819
+
820
+ # Check if model is loaded
821
+ if model is None or tokenizer is None:
822
+ return "โŒ Model not loaded. Please load the model first."
823
+
824
+ # Initialize variables for cleanup
825
+ input_ids = None
826
+ streamer = None
827
+
828
+ try:
829
+ # Record start time for metrics
830
+ start_time = time.time()
831
+ token_count = 0
832
+
833
+ # Format conversation for model
834
+ messages = [{"role": "system", "content": system_prompt}]
835
+
836
+ # Add chat history - HANDLE BOTH FORMATS (tuples from original and dicts from new)
837
+ for h in history:
838
+ if isinstance(h, dict):
839
+ # New dict format
840
+ if h.get("role") == "user":
841
+ messages.append({"role": "user", "content": h["content"]})
842
+ elif h.get("role") == "assistant":
843
+ messages.append({"role": "assistant", "content": h["content"]})
844
+ else:
845
+ # Original tuple format (user_msg, bot_msg)
846
+ if len(h) >= 2:
847
+ messages.append({"role": "user", "content": h[0]})
848
+ if h[1] is not None:
849
+ messages.append({"role": "assistant", "content": h[1]})
850
+
851
+ # Add the current message
852
+ messages.append({"role": "user", "content": message})
853
+
854
+ # Wrap generation in torch.no_grad() to prevent gradient accumulation
855
+ with torch.no_grad():
856
+ # Create model input with memory-efficient approach
857
+ input_ids = tokenizer.apply_chat_template(
858
+ messages,
859
+ tokenize=True,
860
+ add_generation_prompt=True,
861
+ return_tensors="pt"
862
+ )
863
+
864
+ # Handle edge case
865
+ if input_ids.ndim == 1:
866
+ input_ids = input_ids.unsqueeze(0)
867
+
868
+ # Move to device
869
+ input_ids = input_ids.to(model.device)
870
+
871
+ # Setup RAM-saving streamer
872
+ streamer = RAMSavingIteratorStreamer(
873
+ tokenizer,
874
+ skip_special_tokens=True,
875
+ skip_prompt=True,
876
+ timeout=1.0
877
+ )
878
+
879
+ # Set prompt length for the streamer
880
+ streamer.set_prompt_length(input_ids.shape[1])
881
+
882
+ # Pre-generation memory cleanup (only if memory optimization is on)
883
+ if memory_opt:
884
+ gc.collect()
885
+ if torch.cuda.is_available():
886
+ torch.cuda.empty_cache()
887
+
888
+ # Conditional generation parameters based on memory optimization
889
+ gen_kwargs = {
890
+ "input_ids": input_ids,
891
+ "max_new_tokens": MAX_LENGTH,
892
+ "temperature": temp,
893
+ "top_p": top_p_val,
894
+ "top_k": top_k_val,
895
+ "repetition_penalty": rep_penalty_val,
896
+ "do_sample": temp > 0,
897
+ "streamer": streamer,
898
+ "use_cache": not memory_opt, # Disable cache only if memory optimization is on
899
+ }
900
+
901
+ # Generate in a thread for real-time streaming
902
+ thread = Thread(
903
+ target=model.generate,
904
+ kwargs=gen_kwargs,
905
+ daemon=True
906
+ )
907
+ thread.start()
908
+
909
+ # Stream the response with conditional memory management
910
+ partial_text = ""
911
+ try:
912
+ for new_text in streamer:
913
+ partial_text += new_text
914
+ token_count += 1
915
+
916
+ # Process the text to handle think tags while preserving LaTeX
917
+ processed_text = process_think_tags(partial_text)
918
+
919
+ yield processed_text
920
+
921
+ # Conditional cleanup based on memory optimization setting (less frequent)
922
+ if memory_opt and token_count % 150 == 0: # Reduced frequency for performance
923
+ gc.collect() # Only light cleanup if memory optimization is on
924
+
925
+ except StopIteration:
926
+ # Normal end of generation
927
+ pass
928
+ except Exception as stream_error:
929
+ logger.error(f"Streaming error: {stream_error}")
930
+ yield f"โŒ Streaming error: {stream_error}"
931
+ return
932
+
933
+ finally:
934
+ # Add metadata to final response
935
+ try:
936
+ metrics = calculate_generation_metrics(start_time, token_count)
937
+ partial_text = add_metadata_to_response(partial_text, metrics)
938
+ except Exception as e:
939
+ logger.warning(f"Couldn't add metadata: {str(e)}")
940
+
941
+ yield partial_text
942
+
943
+ # Ensure thread completion
944
+ if thread.is_alive():
945
+ thread.join(timeout=5.0)
946
+ if thread.is_alive():
947
+ logger.warning("Generation thread did not complete in time")
948
+
949
+ # Calculate generation metrics
950
+ try:
951
+ metrics = calculate_generation_metrics(start_time, token_count)
952
+
953
+ # Store metadata (using deque with max size to prevent memory leaks)
954
+ generation_metadata.append(metrics)
955
+
956
+ # Log the metrics
957
+ logger.info(f"Generation metrics - Tokens: {metrics['total_tokens']}, Speed: {metrics['tokens_per_second']:.2f} tok/s, Time: {metrics['generation_time']:.2f}s")
958
+ except Exception as metrics_error:
959
+ logger.warning(f"Error calculating metrics: {metrics_error}")
960
+
961
+ # Final cleanup
962
+ try:
963
+ # Clean up streamer
964
+ if streamer:
965
+ streamer.cleanup()
966
+ del streamer
967
+ streamer = None
968
+
969
+ # Clean up input tensors
970
+ if input_ids is not None:
971
+ del input_ids
972
+ input_ids = None
973
+
974
+ # Conditional cleanup based on memory optimization setting
975
+ if memory_opt:
976
+ # Aggressive cleanup only if memory optimization is enabled
977
+ if torch.cuda.is_available():
978
+ for _ in range(2): # Reduced rounds for performance
979
+ torch.cuda.empty_cache()
980
+ torch.cuda.synchronize()
981
+ # Force garbage collection
982
+ for _ in range(2):
983
+ gc.collect()
984
+ else:
985
+ # Light cleanup for performance mode
986
+ gc.collect()
987
+ if torch.cuda.is_available():
988
+ torch.cuda.empty_cache()
989
+
990
+ logger.info(f"Generation completed, {token_count} tokens, memory_opt: {memory_opt}, VRAM saved with RAM-saving streamer")
991
+
992
+ except Exception as cleanup_error:
993
+ logger.warning(f"Final cleanup warning: {cleanup_error}")
994
+
995
+ except Exception as e:
996
+ logger.error(f"Error in chat_with_model: {str(e)}", exc_info=True)
997
+
998
+ # Emergency cleanup
999
+ try:
1000
+ if streamer:
1001
+ streamer.cleanup()
1002
+ del streamer
1003
+ if input_ids is not None:
1004
+ del input_ids
1005
+ gc.collect()
1006
+ if torch.cuda.is_available():
1007
+ torch.cuda.empty_cache()
1008
+ except Exception as emergency_cleanup_error:
1009
+ logger.error(f"Emergency cleanup failed: {emergency_cleanup_error}")
1010
+
1011
+ yield f"โŒ Error: {str(e)}"
1012
+
1013
+ def update_model_name():
1014
+ """Update the displayed model name"""
1015
+ return f"๐Ÿ”ฎ AI Chat Assistant ({MODEL_NAME})"
1016
+
1017
+ def add_page_refresh_warning():
1018
+ """Add JavaScript to warn about page refresh when model is loaded"""
1019
+ return """
1020
+ <script>
1021
+ window.addEventListener('beforeunload', function (e) {
1022
+ // Check if model is loaded by looking for specific text in the page
1023
+ const statusElements = document.querySelectorAll('input[type="text"], textarea');
1024
+ let modelLoaded = false;
1025
+
1026
+ statusElements.forEach(element => {
1027
+ if (element.value && element.value.includes('Model loaded successfully')) {
1028
+ modelLoaded = true;
1029
+ }
1030
+ });
1031
+
1032
+ if (modelLoaded) {
1033
+ e.preventDefault();
1034
+ e.returnValue = 'A model is currently loaded. Are you sure you want to leave?';
1035
+ return 'A model is currently loaded. Are you sure you want to leave?';
1036
+ }
1037
+ });
1038
+ </script>
1039
+ """
1040
+
1041
+ # Custom CSS for elegant styling with fixed dropdown behavior
1042
+ custom_css = """
1043
+ /* Main container styling */
1044
+ .gradio-container {
1045
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
1046
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
1047
+ min-height: 100vh;
1048
+ }
1049
+
1050
+ /* Header styling */
1051
+ .header-text {
1052
+ background: rgba(255, 255, 255, 0.95);
1053
+ backdrop-filter: blur(10px);
1054
+ border-radius: 15px;
1055
+ padding: 20px;
1056
+ margin: 20px 0;
1057
+ text-align: center;
1058
+ box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
1059
+ border: 1px solid rgba(255, 255, 255, 0.2);
1060
+ }
1061
+
1062
+ /* Chat interface styling */
1063
+ .chat-container {
1064
+ background: rgba(255, 255, 255, 0.95) !important;
1065
+ border-radius: 20px !important;
1066
+ box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1) !important;
1067
+ border: 1px solid rgba(255, 255, 255, 0.2) !important;
1068
+ backdrop-filter: blur(10px) !important;
1069
+ }
1070
+
1071
+ /* Control panel styling */
1072
+ .control-panel {
1073
+ background: rgba(255, 255, 255, 0.9) !important;
1074
+ border-radius: 15px !important;
1075
+ padding: 20px !important;
1076
+ box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1) !important;
1077
+ border: 1px solid rgba(255, 255, 255, 0.3) !important;
1078
+ backdrop-filter: blur(10px) !important;
1079
+ overflow: visible !important; /* Allow dropdowns to overflow */
1080
+ }
1081
+
1082
+ /* Button styling */
1083
+ .btn-primary {
1084
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
1085
+ border: none !important;
1086
+ border-radius: 10px !important;
1087
+ color: white !important;
1088
+ font-weight: 600 !important;
1089
+ transition: all 0.3s ease !important;
1090
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
1091
+ }
1092
+
1093
+ .btn-primary:hover {
1094
+ transform: translateY(-2px) !important;
1095
+ box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
1096
+ }
1097
+
1098
+ .btn-secondary {
1099
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;
1100
+ border: none !important;
1101
+ border-radius: 10px !important;
1102
+ color: white !important;
1103
+ font-weight: 600 !important;
1104
+ transition: all 0.3s ease !important;
1105
+ }
1106
+
1107
+ /* Input field styling */
1108
+ .input-field {
1109
+ border-radius: 10px !important;
1110
+ border: 2px solid rgba(102, 126, 234, 0.2) !important;
1111
+ transition: all 0.3s ease !important;
1112
+ }
1113
+
1114
+ .input-field:focus {
1115
+ border-color: #667eea !important;
1116
+ box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
1117
+ }
1118
+
1119
+ /* Dropdown fixes */
1120
+ .dropdown-container {
1121
+ position: relative !important;
1122
+ z-index: 1000 !important;
1123
+ overflow: visible !important;
1124
+ }
1125
+
1126
+ /* Fix dropdown menu positioning and styling */
1127
+ .dropdown select,
1128
+ .dropdown-menu,
1129
+ .svelte-select,
1130
+ .svelte-select-list {
1131
+ position: relative !important;
1132
+ z-index: 1001 !important;
1133
+ background: white !important;
1134
+ border: 2px solid rgba(102, 126, 234, 0.2) !important;
1135
+ border-radius: 10px !important;
1136
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15) !important;
1137
+ max-height: 200px !important;
1138
+ overflow-y: auto !important;
1139
+ }
1140
+
1141
+ /* Fix dropdown option styling */
1142
+ .dropdown option,
1143
+ .svelte-select-option {
1144
+ padding: 8px 12px !important;
1145
+ background: white !important;
1146
+ color: #333 !important;
1147
+ border: none !important;
1148
+ }
1149
+
1150
+ .dropdown option:hover,
1151
+ .svelte-select-option:hover {
1152
+ background: #f0f0f0 !important;
1153
+ color: #667eea !important;
1154
+ }
1155
+
1156
+ /* Ensure dropdown arrow is clickable */
1157
+ .dropdown::after,
1158
+ .dropdown-arrow {
1159
+ pointer-events: none !important;
1160
+ z-index: 1002 !important;
1161
+ }
1162
+
1163
+ /* Fix any overflow issues in parent containers */
1164
+ .gradio-group,
1165
+ .gradio-column {
1166
+ overflow: visible !important;
1167
+ }
1168
+
1169
+ /* Accordion styling */
1170
+ .accordion {
1171
+ border-radius: 10px !important;
1172
+ border: 1px solid rgba(102, 126, 234, 0.2) !important;
1173
+ overflow: visible !important; /* Allow dropdowns to overflow accordion */
1174
+ }
1175
+
1176
+ /* Status indicators */
1177
+ .status-success {
1178
+ color: #10b981 !important;
1179
+ font-weight: 600 !important;
1180
+ }
1181
+
1182
+ .status-error {
1183
+ color: #ef4444 !important;
1184
+ font-weight: 600 !important;
1185
+ }
1186
+
1187
+ /* Reduced transition frequency to avoid conflicts */
1188
+ .gradio-container * {
1189
+ transition: background-color 0.3s ease, border-color 0.3s ease !important;
1190
+ }
1191
+
1192
+ /* Chat bubble styling */
1193
+ .message {
1194
+ border-radius: 18px !important;
1195
+ padding: 12px 16px !important;
1196
+ margin: 8px 0 !important;
1197
+ max-width: 80% !important;
1198
+ }
1199
+
1200
+ .user-message {
1201
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
1202
+ color: white !important;
1203
+ margin-left: auto !important;
1204
+ }
1205
+
1206
+ .bot-message {
1207
+ background: #f8fafc !important;
1208
+ border: 1px solid #e2e8f0 !important;
1209
+ }
1210
+
1211
+ /* Metadata tooltip styling - Enhanced */
1212
+ .metadata-icon {
1213
+ display: inline-block;
1214
+ margin-left: 8px;
1215
+ cursor: help;
1216
+ opacity: 0.6;
1217
+ transition: opacity 0.3s ease, transform 0.2s ease;
1218
+ font-size: 14px;
1219
+ user-select: none;
1220
+ vertical-align: middle;
1221
+ }
1222
+
1223
+ .metadata-icon:hover {
1224
+ opacity: 1;
1225
+ transform: scale(1.1);
1226
+ }
1227
+
1228
+ /* Enhanced tooltip styling */
1229
+ .metadata-icon[title]:hover::after {
1230
+ content: attr(title);
1231
+ position: absolute;
1232
+ bottom: 100%;
1233
+ left: 50%;
1234
+ transform: translateX(-50%);
1235
+ background: rgba(0, 0, 0, 0.9);
1236
+ color: white;
1237
+ padding: 8px 12px;
1238
+ border-radius: 6px;
1239
+ font-size: 12px;
1240
+ white-space: pre-line;
1241
+ z-index: 1000;
1242
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
1243
+ margin-bottom: 5px;
1244
+ min-width: 200px;
1245
+ text-align: left;
1246
+ }
1247
+
1248
+ .metadata-icon[title]:hover::before {
1249
+ content: '';
1250
+ position: absolute;
1251
+ bottom: 100%;
1252
+ left: 50%;
1253
+ transform: translateX(-50%);
1254
+ border: 5px solid transparent;
1255
+ border-top-color: rgba(0, 0, 0, 0.9);
1256
+ z-index: 1001;
1257
+ }
1258
+
1259
+ /* Compact system prompt */
1260
+ .compact-prompt {
1261
+ min-height: 40px !important;
1262
+ transition: min-height 0.3s ease !important;
1263
+ }
1264
+
1265
+ .compact-prompt:focus {
1266
+ min-height: 80px !important;
1267
+ }
1268
+ """
1269
+
1270
+ # Main application
1271
+ with gr.Blocks(css=custom_css, title="🔮 AI Chat Assistant") as demo:
1272
+ # Add page refresh warning script
1273
+ gr.HTML(add_page_refresh_warning())
1274
+
1275
+ # Header
1276
+ with gr.Row():
1277
+ title = gr.Markdown("# 🔮 AI Chat Assistant (No Model Loaded)", elem_classes="header-text")
1278
+
1279
+ with gr.Row(equal_height=True):
1280
+ # Main chat area (left side - 70% width)
1281
+ with gr.Column(scale=7, elem_classes="chat-container"):
1282
+ # Compact system prompt (changed from 4 lines to 1)
1283
+ system_prompt = gr.Textbox(
1284
+ label="๐ŸŽฏ System Prompt",
1285
+ value="You are a helpful AI assistant.",
1286
+ lines=1, # Changed from 4 to 1
1287
+ elem_classes="input-field compact-prompt"
1288
+ )
1289
+
1290
+ # Generation settings in accordion
1291
+ with gr.Accordion("⚙️ Generation Settings", open=False, elem_classes="accordion"):
1292
+ with gr.Row():
1293
+ temperature = gr.Slider(0.0, 2.0, DEFAULT_TEMPERATURE, step=0.05, label="🌡️ Temperature")
1294
+ top_p = gr.Slider(0.0, 1.0, DEFAULT_TOP_P, step=0.01, label="🎯 Top-p")
1295
+ with gr.Row():
1296
+ top_k = gr.Slider(1, 200, DEFAULT_TOP_K, step=1, label="🔍 Top-k")
1297
+ rep_penalty = gr.Slider(1.0, 2.0, DEFAULT_REP_PENALTY, step=0.01, label="🔄 Repetition Penalty")
1298
+
1299
+ # Memory optimization for chat (moved here to be defined before use)
1300
+ memory_opt_chat = gr.Checkbox(
1301
+ label="๐Ÿง  Memory Optimization for Chat",
1302
+ value=True,
1303
+ info="Use memory optimization during chat generation (disables KV cache)"
1304
+ )
1305
+
1306
+ # Chat interface using original gr.ChatInterface for fast streaming and stop button
1307
+ chatbot = gr.Chatbot(
1308
+ height=500,
1309
+ latex_delimiters=[
1310
+ {"left": "$", "right": "$", "display": True},
1311
+ {"left": "$", "right": "$", "display": False},
1312
+ {"left": "\\(", "right": "\\)", "display": False},
1313
+ {"left": "\\[", "right": "\\]", "display": True}
1314
+ ],
1315
+ show_copy_button=True,
1316
+ avatar_images=("👤", "🤖"),
1317
+ type="messages",
1318
+ render_markdown=True
1319
+ )
1320
+
1321
+ chat_interface = gr.ChatInterface(
1322
+ fn=chat_with_model,
1323
+ chatbot=chatbot,
1324
+ additional_inputs=[system_prompt, temperature, top_p, top_k, rep_penalty, memory_opt_chat],
1325
+ type="messages",
1326
+ submit_btn="Send 📤",
1327
+ stop_btn="⏹️ Stop"
1328
+ )
1329
+
1330
+ # Control panel (right side - 30% width)
1331
+ with gr.Column(scale=3, elem_classes="control-panel"):
1332
+ # Model status and controls
1333
+ with gr.Group():
1334
+ gr.Markdown("### ๐Ÿš€ Model Controls")
1335
+
1336
+ with gr.Row():
1337
+ load_btn = gr.Button("🚀 Load Model", variant="primary", elem_classes="btn-primary")
1338
+ unload_btn = gr.Button("🗑️ Unload", variant="secondary", elem_classes="btn-secondary")
1339
+
1340
+ model_status = gr.Textbox(
1341
+ label="๐Ÿ“Š Status",
1342
+ value="Model not loaded",
1343
+ interactive=False,
1344
+ elem_classes="input-field"
1345
+ )
1346
+
1347
+ progress_display = gr.Textbox(
1348
+ label="๐Ÿ“ˆ Progress",
1349
+ value="Ready to load model",
1350
+ interactive=False,
1351
+ elem_classes="input-field"
1352
+ )
1353
+
1354
+ # Model selection
1355
+ with gr.Group():
1356
+ gr.Markdown("### ๐ŸŽ›๏ธ Model Selection")
1357
+
1358
+ model_source = gr.Radio(
1359
+ choices=["Hugging Face Model", "Local Path"],
1360
+ value="Local Path", # Changed default to Local Path
1361
+ label="๐Ÿ“ Model Source"
1362
+ )
1363
+
1364
+ # HF Model search and selection (initially hidden)
1365
+ with gr.Group(visible=False) as hf_group:
1366
+ model_search = gr.Textbox(
1367
+ label="๐Ÿ” Search Models",
1368
+ placeholder="e.g., microsoft/Phi-3, meta-llama/Llama-3, ykarout/your-model",
1369
+ elem_classes="input-field"
1370
+ )
1371
+
1372
+ hf_model = gr.Dropdown(
1373
+ label="๐Ÿ“‹ Select Model",
1374
+ choices=[],
1375
+ interactive=True,
1376
+ elem_classes="input-field dropdown-container",
1377
+ allow_custom_value=True, # Allow typing custom model names
1378
+ filterable=True # Enable filtering
1379
+ )
1380
+
1381
+ # Local path group (visible by default)
1382
+ with gr.Group(visible=True) as local_group:
1383
+ local_path = gr.Textbox(
1384
+ value=LOCAL_MODELS_BASE, # Changed default to new base location
1385
+ label="๐Ÿ“ Local Models Base Path",
1386
+ elem_classes="input-field"
1387
+ )
1388
+
1389
+ # Button to refresh local models
1390
+ refresh_local_btn = gr.Button("🔄 Scan Local Models", elem_classes="btn-secondary")
1391
+
1392
+ # Dropdown for local models with better configuration
1393
+ local_models_dropdown = gr.Dropdown(
1394
+ label="๐Ÿ“‹ Available Local Models",
1395
+ choices=[],
1396
+ interactive=True,
1397
+ elem_classes="input-field dropdown-container",
1398
+ allow_custom_value=False, # Don't allow custom for local models
1399
+ filterable=True # Enable filtering
1400
+ )
1401
+
1402
+ quantization = gr.Radio(
1403
+ choices=["4bit", "8bit", "bf16", "f16"],
1404
+ value="4bit",
1405
+ label="โšก Quantization"
1406
+ )
1407
+
1408
+ # Advanced memory optimization toggle
1409
+ memory_optimization = gr.Checkbox(
1410
+ label="๐Ÿง  Advanced Memory Optimization",
1411
+ value=True,
1412
+ info="Reduces VRAM usage but may slightly impact speed"
1413
+ )
1414
+
1415
+ # Note: Memory optimization for chat is now in Generation Settings
1416
+
1417
+ # Memory stats with cleanup buttons
1418
+ with gr.Group():
1419
+ gr.Markdown("### ๐Ÿ’พ System Status")
1420
+ memory_info = gr.HTML()
1421
+ with gr.Row():
1422
+ refresh_btn = gr.Button("↻ Refresh Stats", elem_classes="btn-secondary")
1423
+ cleanup_btn = gr.Button("🧹 Clean Memory", elem_classes="btn-secondary")
1424
+ with gr.Row():
1425
+ nuclear_btn = gr.Button("☢️ Nuclear Cleanup", elem_classes="btn-secondary", variant="stop")
1426
+
1427
+ # Event handlers
1428
+
1429
+ # Model search functionality for HF
1430
+ model_search.change(
1431
+ update_model_dropdown,
1432
+ inputs=[model_search],
1433
+ outputs=[hf_model]
1434
+ )
1435
+
1436
+ # Show/hide model selection based on source
1437
+ def toggle_model_source(choice):
1438
+ return (
1439
+ gr.Group(visible=choice == "Hugging Face Model"),
1440
+ gr.Group(visible=choice == "Local Path")
1441
+ )
1442
+
1443
+ model_source.change(
1444
+ toggle_model_source,
1445
+ inputs=[model_source],
1446
+ outputs=[hf_group, local_group]
1447
+ )
1448
+
1449
+ # Local model scanning
1450
+ refresh_local_btn.click(
1451
+ update_local_models_dropdown,
1452
+ inputs=[local_path],
1453
+ outputs=[local_models_dropdown]
1454
+ )
1455
+
1456
+ # Auto-scan on path change
1457
+ local_path.change(
1458
+ update_local_models_dropdown,
1459
+ inputs=[local_path],
1460
+ outputs=[local_models_dropdown]
1461
+ )
1462
+
1463
+ # Model loading with progress
1464
+ load_btn.click(
1465
+ load_model_with_progress,
1466
+ inputs=[model_source, hf_model, local_path, local_models_dropdown, quantization, memory_optimization],
1467
+ outputs=[progress_display]
1468
+ ).then(
1469
+ lambda: "โœ… Model loaded successfully!" if model is not None else "โŒ Model loading failed",
1470
+ outputs=[model_status]
1471
+ ).then(
1472
+ get_memory_stats,
1473
+ outputs=[memory_info]
1474
+ ).then(
1475
+ update_model_name,
1476
+ outputs=[title]
1477
+ )
1478
+
1479
+ # Model unloading
1480
+ unload_btn.click(
1481
+ unload_model,
1482
+ outputs=[model_status]
1483
+ ).then(
1484
+ lambda: "Ready to load model",
1485
+ outputs=[progress_display]
1486
+ ).then(
1487
+ get_memory_stats,
1488
+ outputs=[memory_info]
1489
+ ).then(
1490
+ lambda: "# ๐Ÿ”ฎ AI Chat Assistant (No Model Loaded)",
1491
+ outputs=[title]
1492
+ )
1493
+
1494
+ # Refresh memory stats
1495
+ refresh_btn.click(get_memory_stats, outputs=[memory_info])
1496
+
1497
+ # Manual memory cleanup
1498
+ cleanup_btn.click(cleanup_memory, outputs=[]).then(
1499
+ get_memory_stats, outputs=[memory_info]
1500
+ )
1501
+
1502
+ # Nuclear memory cleanup
1503
+ nuclear_btn.click(nuclear_memory_cleanup, outputs=[]).then(
1504
+ get_memory_stats, outputs=[memory_info]
1505
+ )
1506
+
1507
+ # Initialize on load
1508
+ demo.load(get_memory_stats, outputs=[memory_info])
1509
+ demo.load(
1510
+ lambda: update_local_models_dropdown(LOCAL_MODELS_BASE),
1511
+ outputs=[local_models_dropdown]
1512
+ )
1513
+
1514
+ # Enable queue for streaming
1515
+ demo.queue()
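+
+ # Note: there is no demo.launch() call here; on a Hugging Face Space the Gradio SDK serves the
+ # `demo` Blocks object automatically. For a standalone run one would typically append something
+ # like demo.launch(server_name="0.0.0.0", server_port=7860) (an assumption, not part of this commit).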