jbilcke-hf committed
Commit 0d34ea8 · 1 Parent(s): c5911ab

add gpu tracking

requirements.txt CHANGED
@@ -2,6 +2,7 @@ numpy>=1.26.4
 
 # to quote a-r-r-o-w/finetrainers:
 # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
+
 # on some system (Python 3.13+) those do not work:
 torch==2.5.1
 torchvision==0.20.1
@@ -20,6 +21,9 @@ accelerate
 bitsandbytes
 peft>=0.12.0
 
+# For GPU monitoring of NVIDIA chipsets
+pynvml
+
 # eva-decord is missing get_batch it seems
 #eva-decord==0.6.1
 decord
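
pynvml is the only new dependency; a quick way to check it is usable on the target machine is the minimal sketch below (not part of this commit, and it assumes an NVIDIA driver with at least one visible GPU):

# Minimal pynvml sanity check (illustrative only)
import pynvml

pynvml.nvmlInit()
count = pynvml.nvmlDeviceGetCount()
for i in range(count):
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
    name = pynvml.nvmlDeviceGetName(handle)
    if isinstance(name, bytes):
        name = name.decode("utf-8")
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU {i}: {name}, util {util.gpu}%, "
          f"memory {mem.used / 1024**3:.2f}/{mem.total / 1024**3:.2f} GB")
pynvml.nvmlShutdown()
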
requirements_without_flash_attention.txt CHANGED
@@ -21,6 +21,10 @@ accelerate
 bitsandbytes
 peft>=0.12.0
 
+# For GPU monitoring of NVIDIA chipsets
+# you probably won't be able to install that on macOS
+# pynvml
+
 # eva-decord is missing get_batch it seems
 eva-decord==0.6.1
 # decord
vms/config.py CHANGED
@@ -150,6 +150,9 @@ DEFAULT_NB_TRAINING_STEPS = 1000
 # For this value, it is recommended to use about 20 to 40% of the number of training steps
 DEFAULT_NB_LR_WARMUP_STEPS = math.ceil(0.20 * DEFAULT_NB_TRAINING_STEPS) # 20% of training steps
 
+# Whether to automatically restart a training job after a server reboot or not
+DEFAULT_AUTO_RESUME = False
+
 # For validation
 DEFAULT_VALIDATION_NB_STEPS = 50
 DEFAULT_VALIDATION_HEIGHT = 512
vms/ui/app_ui.py CHANGED
@@ -19,7 +19,8 @@ from vms.config import (
     DEFAULT_MAX_GPUS,
     DEFAULT_PRECOMPUTATION_ITEMS,
     DEFAULT_NB_TRAINING_STEPS,
-    DEFAULT_NB_LR_WARMUP_STEPS
+    DEFAULT_NB_LR_WARMUP_STEPS,
+    DEFAULT_AUTO_RESUME
 )
 from vms.utils import (
     get_recommended_precomputation_items,
@@ -40,7 +41,7 @@ from vms.ui.monitoring.services import (
 )
 
 from vms.ui.monitoring.tabs import (
-    GeneralTab
+    GeneralTab, GPUTab
 )
 
 logger = logging.getLogger(__name__)
@@ -183,6 +184,8 @@ class AppUI:
                 # Initialize monitoring tab objects
                 self.monitor_tabs["general_tab"] = GeneralTab(self)
 
+                self.monitor_tabs["gpu_tab"] = GPUTab(self)
+
                 # Create tab UI components for monitoring
                 for tab_id, tab_obj in self.monitor_tabs.items():
                     tab_obj.create(monitoring_tabs)
@@ -230,7 +233,8 @@ class AppUI:
                 self.project_tabs["train_tab"].components["current_task_box"],
                 self.project_tabs["train_tab"].components["num_gpus"],
                 self.project_tabs["train_tab"].components["precomputation_items"],
-                self.project_tabs["train_tab"].components["lr_warmup_steps"]
+                self.project_tabs["train_tab"].components["lr_warmup_steps"],
+                self.project_tabs["train_tab"].components["auto_resume_checkbox"]
             ]
         )
 
@@ -376,6 +380,8 @@ class AppUI:
         # Get model_version value
         model_version_val = ""
 
+        auto_resume_val = ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
+
         # First get the internal model type for the currently selected model
         model_internal_type = MODEL_TYPES.get(model_type_val)
         logger.info(f"Initializing model version for model_type: {model_type_val} (internal: {model_internal_type})")
@@ -480,7 +486,8 @@ class AppUI:
             current_task_val,
             num_gpus_val,
             precomputation_items_val,
-            lr_warmup_steps_val
+            lr_warmup_steps_val,
+            auto_resume_val
         )
 
     def initialize_ui_from_state(self):
vms/ui/monitoring/services/gpu.py ADDED
@@ -0,0 +1,485 @@

"""
GPU monitoring service for Video Model Studio.
Tracks NVIDIA GPU resources like utilization, memory, and temperature.
"""

import os
import time
import logging
from typing import Dict, List, Any, Optional, Tuple
from collections import deque
from datetime import datetime

# Force the use of the Agg backend which is thread-safe
import matplotlib
matplotlib.use('Agg')  # Must be before importing pyplot
import matplotlib.pyplot as plt
import numpy as np

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Optional import of pynvml
try:
    import pynvml
    PYNVML_AVAILABLE = True
except ImportError:
    PYNVML_AVAILABLE = False
    logger.info("pynvml not available, GPU monitoring will be limited")

class GPUMonitoringService:
    """Service for monitoring NVIDIA GPU resources"""

    def __init__(self, history_minutes: int = 10, sample_interval: int = 5):
        """Initialize the GPU monitoring service

        Args:
            history_minutes: How many minutes of history to keep
            sample_interval: How many seconds between samples
        """
        self.history_minutes = history_minutes
        self.sample_interval = sample_interval
        self.max_samples = (history_minutes * 60) // sample_interval

        # Track if the monitoring thread is running
        self.is_running = False
        self.thread = None

        # Check if NVIDIA GPUs are available
        self.has_nvidia_gpus = False
        self.gpu_count = 0
        self.device_info = []
        self.history = {}

        # Try to initialize NVML
        self._initialize_nvml()

        # Initialize history data structures if GPUs are available
        if self.has_nvidia_gpus:
            self._initialize_history()

    def _initialize_nvml(self):
        """Initialize NVIDIA Management Library"""
        if not PYNVML_AVAILABLE:
            logger.info("pynvml module not installed, GPU monitoring disabled")
            return

        try:
            pynvml.nvmlInit()
            self.gpu_count = pynvml.nvmlDeviceGetCount()
            self.has_nvidia_gpus = self.gpu_count > 0

            if self.has_nvidia_gpus:
                logger.info(f"Successfully initialized NVML, found {self.gpu_count} GPU(s)")
                # Get static information about each GPU
                for i in range(self.gpu_count):
                    self.device_info.append(self._get_device_info(i))
            else:
                logger.info("No NVIDIA GPUs found")

        except Exception as e:
            logger.warning(f"Failed to initialize NVML: {str(e)}")
            self.has_nvidia_gpus = False

    def _initialize_history(self):
        """Initialize data structures for storing metric history"""
        for i in range(self.gpu_count):
            self.history[i] = {
                'timestamps': deque(maxlen=self.max_samples),
                'utilization': deque(maxlen=self.max_samples),
                'memory_used': deque(maxlen=self.max_samples),
                'memory_total': deque(maxlen=self.max_samples),
                'memory_percent': deque(maxlen=self.max_samples),
                'temperature': deque(maxlen=self.max_samples),
                'power_usage': deque(maxlen=self.max_samples),
                'power_limit': deque(maxlen=self.max_samples),
            }

    def _get_device_info(self, device_index: int) -> Dict[str, Any]:
        """Get static information about a GPU device

        Args:
            device_index: Index of the GPU device

        Returns:
            Dictionary with device information
        """
        if not PYNVML_AVAILABLE or not self.has_nvidia_gpus:
            return {"error": "NVIDIA GPUs not available"}

        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)

            # Get device name (decode if it's bytes)
            name = pynvml.nvmlDeviceGetName(handle)
            if isinstance(name, bytes):
                name = name.decode('utf-8')

            # Get device UUID
            uuid = pynvml.nvmlDeviceGetUUID(handle)
            if isinstance(uuid, bytes):
                uuid = uuid.decode('utf-8')

            # Get memory info, compute capability
            memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            compute_capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle)

            # Get power limits if available
            try:
                power_limit = pynvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0  # Convert to watts
            except pynvml.NVMLError:
                power_limit = None

            return {
                'index': device_index,
                'name': name,
                'uuid': uuid,
                'memory_total': memory_info.total,
                'memory_total_gb': memory_info.total / (1024**3),  # Convert to GB
                'compute_capability': f"{compute_capability[0]}.{compute_capability[1]}",
                'power_limit': power_limit
            }

        except Exception as e:
            logger.error(f"Error getting device info for GPU {device_index}: {str(e)}")
            return {"error": str(e), "index": device_index}

    def collect_gpu_metrics(self) -> List[Dict[str, Any]]:
        """Collect current GPU metrics for all available GPUs

        Returns:
            List of dictionaries with current metrics for each GPU
        """
        if not PYNVML_AVAILABLE or not self.has_nvidia_gpus:
            return []

        metrics = []
        timestamp = datetime.now()

        for i in range(self.gpu_count):
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)

                # Get utilization rates
                utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)

                # Get memory information
                memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)

                # Get temperature
                temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)

                # Get power usage if available
                try:
                    power_usage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # Convert to watts
                except pynvml.NVMLError:
                    power_usage = None

                # Get process information
                processes = []
                try:
                    for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
                        try:
                            process_name = pynvml.nvmlSystemGetProcessName(proc.pid)
                            if isinstance(process_name, bytes):
                                process_name = process_name.decode('utf-8')
                        except pynvml.NVMLError:
                            process_name = f"Unknown (PID: {proc.pid})"

                        processes.append({
                            'pid': proc.pid,
                            'name': process_name,
                            'memory_used': proc.usedGpuMemory,
                            'memory_used_mb': proc.usedGpuMemory / (1024**2)  # Convert to MB
                        })
                except pynvml.NVMLError:
                    # Unable to get process information, continue with empty list
                    pass

                gpu_metrics = {
                    'index': i,
                    'timestamp': timestamp,
                    'utilization_gpu': utilization.gpu,
                    'utilization_memory': utilization.memory,
                    'memory_total': memory_info.total,
                    'memory_used': memory_info.used,
                    'memory_free': memory_info.free,
                    'memory_percent': (memory_info.used / memory_info.total) * 100,
                    'temperature': temperature,
                    'power_usage': power_usage,
                    'processes': processes
                }

                metrics.append(gpu_metrics)

            except Exception as e:
                logger.error(f"Error collecting metrics for GPU {i}: {str(e)}")
                metrics.append({
                    'index': i,
                    'error': str(e)
                })

        return metrics

    def update_history(self):
        """Update GPU metrics history"""
        if not self.has_nvidia_gpus:
            return

        current_metrics = self.collect_gpu_metrics()
        timestamp = datetime.now()

        for gpu_metrics in current_metrics:
            if 'error' in gpu_metrics:
                continue

            idx = gpu_metrics['index']

            self.history[idx]['timestamps'].append(timestamp)
            self.history[idx]['utilization'].append(gpu_metrics['utilization_gpu'])
            self.history[idx]['memory_used'].append(gpu_metrics['memory_used'])
            self.history[idx]['memory_total'].append(gpu_metrics['memory_total'])
            self.history[idx]['memory_percent'].append(gpu_metrics['memory_percent'])
            self.history[idx]['temperature'].append(gpu_metrics['temperature'])

            if gpu_metrics['power_usage'] is not None:
                self.history[idx]['power_usage'].append(gpu_metrics['power_usage'])
            else:
                self.history[idx]['power_usage'].append(0)

            # Store power limit in history (static but kept for consistency)
            info = self.device_info[idx]
            if 'power_limit' in info and info['power_limit'] is not None:
                self.history[idx]['power_limit'].append(info['power_limit'])
            else:
                self.history[idx]['power_limit'].append(0)

    def start_monitoring(self):
        """Start background thread for collecting GPU metrics"""
        if self.is_running:
            logger.warning("GPU monitoring thread already running")
            return

        if not self.has_nvidia_gpus:
            logger.info("No NVIDIA GPUs found, not starting monitoring thread")
            return

        import threading

        self.is_running = True

        def _monitor_loop():
            while self.is_running:
                try:
                    self.update_history()
                    time.sleep(self.sample_interval)
                except Exception as e:
                    logger.error(f"Error in GPU monitoring thread: {str(e)}", exc_info=True)
                    time.sleep(self.sample_interval)

        self.thread = threading.Thread(target=_monitor_loop, daemon=True)
        self.thread.start()
        logger.info("GPU monitoring thread started")

    def stop_monitoring(self):
        """Stop the GPU monitoring thread"""
        if not self.is_running:
            return

        self.is_running = False
        if self.thread:
            self.thread.join(timeout=1.0)
        logger.info("GPU monitoring thread stopped")

    def get_gpu_info(self) -> List[Dict[str, Any]]:
        """Get information about all available GPUs

        Returns:
            List of dictionaries with GPU information
        """
        return self.device_info

    def get_current_metrics(self) -> List[Dict[str, Any]]:
        """Get current metrics for all GPUs

        Returns:
            List of dictionaries with current GPU metrics
        """
        return self.collect_gpu_metrics()

    def generate_utilization_plot(self, gpu_index: int) -> plt.Figure:
        """Generate a plot of GPU utilization over time

        Args:
            gpu_index: Index of the GPU to plot

        Returns:
            Matplotlib figure with utilization plot
        """
        plt.close('all')  # Close all existing figures
        fig, ax = plt.subplots(figsize=(10, 5))

        if not self.has_nvidia_gpus or gpu_index not in self.history:
            ax.set_title(f"No data available for GPU {gpu_index}")
            return fig

        history = self.history[gpu_index]
        if not history['timestamps']:
            ax.set_title(f"No history data for GPU {gpu_index}")
            return fig

        # Convert timestamps to strings
        x = [t.strftime('%H:%M:%S') for t in history['timestamps']]

        # If we have many points, show fewer labels for readability
        if len(x) > 10:
            step = len(x) // 10
            ax.set_xticks(range(0, len(x), step))
            ax.set_xticklabels([x[i] for i in range(0, len(x), step)], rotation=45)

        # Plot utilization
        ax.plot(x, list(history['utilization']), 'b-', label='GPU Utilization %')
        ax.set_ylim(0, 100)

        # Add temperature on secondary y-axis
        ax2 = ax.twinx()
        ax2.plot(x, list(history['temperature']), 'r-', label='Temperature °C')
        ax2.set_ylabel('Temperature (°C)', color='r')
        ax2.tick_params(axis='y', colors='r')

        # Set labels and title
        ax.set_title(f'GPU {gpu_index} Utilization Over Time')
        ax.set_xlabel('Time')
        ax.set_ylabel('Utilization %')
        ax.grid(True, alpha=0.3)

        # Add legend
        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines + lines2, labels + labels2, loc='upper left')

        plt.tight_layout()
        return fig

    def generate_memory_plot(self, gpu_index: int) -> plt.Figure:
        """Generate a plot of GPU memory usage over time

        Args:
            gpu_index: Index of the GPU to plot

        Returns:
            Matplotlib figure with memory usage plot
        """
        plt.close('all')  # Close all existing figures
        fig, ax = plt.subplots(figsize=(10, 5))

        if not self.has_nvidia_gpus or gpu_index not in self.history:
            ax.set_title(f"No data available for GPU {gpu_index}")
            return fig

        history = self.history[gpu_index]
        if not history['timestamps']:
            ax.set_title(f"No history data for GPU {gpu_index}")
            return fig

        # Convert timestamps to strings
        x = [t.strftime('%H:%M:%S') for t in history['timestamps']]

        # If we have many points, show fewer labels for readability
        if len(x) > 10:
            step = len(x) // 10
            ax.set_xticks(range(0, len(x), step))
            ax.set_xticklabels([x[i] for i in range(0, len(x), step)], rotation=45)

        # Plot memory percentage
        ax.plot(x, list(history['memory_percent']), 'g-', label='Memory Usage %')
        ax.set_ylim(0, 100)

        # Add absolute memory values on secondary y-axis (convert to GB)
        ax2 = ax.twinx()
        memory_used_gb = [m / (1024**3) for m in history['memory_used']]
        memory_total_gb = [m / (1024**3) for m in history['memory_total']]

        ax2.plot(x, memory_used_gb, 'm--', label='Used (GB)')
        ax2.set_ylabel('Memory (GB)')

        # Set labels and title
        ax.set_title(f'GPU {gpu_index} Memory Usage Over Time')
        ax.set_xlabel('Time')
        ax.set_ylabel('Usage %')
        ax.grid(True, alpha=0.3)

        # Add legend
        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines + lines2, labels + labels2, loc='upper left')

        plt.tight_layout()
        return fig

    def generate_power_plot(self, gpu_index: int) -> plt.Figure:
        """Generate a plot of GPU power usage over time

        Args:
            gpu_index: Index of the GPU to plot

        Returns:
            Matplotlib figure with power usage plot
        """
        plt.close('all')  # Close all existing figures
        fig, ax = plt.subplots(figsize=(10, 5))

        if not self.has_nvidia_gpus or gpu_index not in self.history:
            ax.set_title(f"No data available for GPU {gpu_index}")
            return fig

        history = self.history[gpu_index]
        if not history['timestamps'] or not any(history['power_usage']):
            ax.set_title(f"No power data for GPU {gpu_index}")
            return fig

        # Convert timestamps to strings
        x = [t.strftime('%H:%M:%S') for t in history['timestamps']]

        # If we have many points, show fewer labels for readability
        if len(x) > 10:
            step = len(x) // 10
            ax.set_xticks(range(0, len(x), step))
            ax.set_xticklabels([x[i] for i in range(0, len(x), step)], rotation=45)

        # Plot power usage
        power_usage = list(history['power_usage'])
        if any(power_usage):  # Only plot if we have actual power data
            ax.plot(x, power_usage, 'b-', label='Power Usage (W)')

            # Get power limit if available
            power_limit = list(history['power_limit'])
            if any(power_limit):  # Only plot if we have power limit data
                # Show power limit as horizontal line
                limit = max(power_limit)  # Should be constant, but take max just in case
                if limit > 0:
                    ax.axhline(y=limit, color='r', linestyle='--', label=f'Power Limit ({limit}W)')

            # Set labels and title
            ax.set_title(f'GPU {gpu_index} Power Usage Over Time')
            ax.set_xlabel('Time')
            ax.set_ylabel('Power (Watts)')
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper left')
        else:
            ax.set_title(f"Power data not available for GPU {gpu_index}")

        plt.tight_layout()
        return fig

    def shutdown(self):
        """Clean up resources when shutting down"""
        self.stop_monitoring()

        # Shutdown NVML if it was initialized
        if PYNVML_AVAILABLE and self.has_nvidia_gpus:
            try:
                pynvml.nvmlShutdown()
                logger.info("NVML shutdown complete")
            except Exception as e:
                logger.error(f"Error during NVML shutdown: {str(e)}")
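
For reference, the service above can also be exercised on its own, outside the Gradio UI; a minimal usage sketch (the import path and method names come from the file above, the 15-second sleep is arbitrary):

import time
from vms.ui.monitoring.services.gpu import GPUMonitoringService

svc = GPUMonitoringService(history_minutes=10, sample_interval=5)
svc.start_monitoring()   # no-op (with a log message) if no NVIDIA GPU is detected
time.sleep(15)           # let the background thread collect a few samples

for gpu in svc.get_current_metrics():
    if 'error' in gpu:
        continue
    print(f"GPU {gpu['index']}: {gpu['utilization_gpu']}% util, "
          f"{gpu['memory_percent']:.1f}% memory, {gpu['temperature']}°C")

fig = svc.generate_utilization_plot(0)  # Matplotlib figure, same one the GPU tab shows
fig.savefig("gpu0_utilization.png")
svc.shutdown()
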
vms/ui/monitoring/services/monitoring.py CHANGED
@@ -21,6 +21,8 @@ import matplotlib.pyplot as plt
 
 import numpy as np
 
+from vms.ui.monitoring.services.gpu import GPUMonitoringService
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -51,6 +53,9 @@ class MonitoringService:
         # Per-core CPU history
         self.cpu_cores_percent = {}
 
+        # Initialize GPU monitoring service
+        self.gpu = GPUMonitoringService(history_minutes=history_minutes, sample_interval=sample_interval)
+
         # Track if the monitoring thread is running
         self.is_running = False
         self.thread = None
@@ -124,6 +129,9 @@ class MonitoringService:
             return
 
         self.is_running = True
+
+        # Start GPU monitoring if available
+        self.gpu.start_monitoring()
 
         def _monitor_loop():
             while self.is_running:
@@ -143,8 +151,12 @@ class MonitoringService:
         """Stop the monitoring thread"""
         if not self.is_running:
             return
-
+
         self.is_running = False
+
+        # Stop GPU monitoring
+        self.gpu.stop_monitoring()
+
         if self.thread:
             self.thread.join(timeout=1.0)
         logger.info("System monitoring thread stopped")
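
With this change the GPU service shares the lifecycle of the existing system monitor, so callers only deal with MonitoringService; a hedged sketch of the expected flow (the constructor arguments are assumed from the pass-through above and are not shown in this diff):

# Sketch only: MonitoringService's exact constructor signature is not visible here;
# history_minutes/sample_interval are assumed from the pass-through above.
from vms.ui.monitoring.services.monitoring import MonitoringService

monitor = MonitoringService(history_minutes=10, sample_interval=5)
monitor.start_monitoring()      # now also starts monitor.gpu's background thread
gpu_metrics = monitor.gpu.get_current_metrics()
monitor.stop_monitoring()       # now also stops the GPU thread
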
vms/ui/monitoring/tabs/gpu_tab.py ADDED
@@ -0,0 +1,370 @@

"""
GPU monitoring tab for Video Model Studio UI.
Displays detailed GPU metrics and visualizations.
"""

import gradio as gr
import time
import logging
from pathlib import Path
import os
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timedelta

from vms.utils.base_tab import BaseTab
from vms.ui.monitoring.utils import human_readable_size

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class GPUTab(BaseTab):
    """Tab for GPU-specific monitoring and statistics"""

    def __init__(self, app_state):
        super().__init__(app_state)
        self.id = "GPU_tab"
        self.title = "GPU Stats"
        self.refresh_interval = 5
        self.selected_gpu = 0

    def create(self, parent=None) -> gr.TabItem:
        """Create the GPU tab UI components"""
        with gr.TabItem(self.title, id=self.id) as tab:
            with gr.Row():
                gr.Markdown("## 🖥️ GPU Monitoring")

            # No GPUs available message (hidden by default)
            with gr.Row(visible=not self.app.monitoring.gpu.has_nvidia_gpus):
                with gr.Column():
                    gr.Markdown("### No NVIDIA GPUs detected")
                    gr.Markdown("GPU monitoring is only available for NVIDIA GPUs. If you have NVIDIA GPUs installed, ensure the drivers are properly configured.")

            # GPU content (only visible if GPUs are available)
            with gr.Row(visible=self.app.monitoring.gpu.has_nvidia_gpus):
                # GPU selector if multiple GPUs
                if self.app.monitoring.gpu.gpu_count > 1:
                    with gr.Column(scale=1):
                        gpu_options = [f"GPU {i}" for i in range(self.app.monitoring.gpu.gpu_count)]
                        self.components["gpu_selector"] = gr.Dropdown(
                            choices=gpu_options,
                            value=gpu_options[0] if gpu_options else None,
                            label="Select GPU",
                            interactive=True
                        )

                # Current metrics
                with gr.Column(scale=3):
                    self.components["current_metrics"] = gr.Markdown("Loading GPU metrics...")

            # Display GPU metrics in tabs
            with gr.Tabs(visible=self.app.monitoring.gpu.has_nvidia_gpus) as metrics_tabs:
                with gr.Tab(label="Utilization") as util_tab:
                    self.components["utilization_plot"] = gr.Plot()

                with gr.Tab(label="Memory") as memory_tab:
                    self.components["memory_plot"] = gr.Plot()

                with gr.Tab(label="Power") as power_tab:
                    self.components["power_plot"] = gr.Plot()

            # Process information
            with gr.Row(visible=self.app.monitoring.gpu.has_nvidia_gpus):
                with gr.Column():
                    gr.Markdown("### Active Processes")
                    self.components["process_info"] = gr.Markdown("Loading process information...")

            # GPU information summary
            with gr.Row(visible=self.app.monitoring.gpu.has_nvidia_gpus):
                with gr.Column():
                    gr.Markdown("### GPU Information")
                    self.components["gpu_info"] = gr.Markdown("Loading GPU information...")

            # Toggle for enabling/disabling auto-refresh
            with gr.Row():
                self.components["auto_refresh"] = gr.Checkbox(
                    label=f"Auto refresh (every {self.refresh_interval} seconds)",
                    value=True,
                    info="Automatically refresh GPU metrics"
                )
                self.components["refresh_btn"] = gr.Button("Refresh Now")

            # Timer for auto-refresh
            self.components["refresh_timer"] = gr.Timer(
                value=self.refresh_interval
            )

        return tab

    def connect_events(self) -> None:
        """Connect event handlers to UI components"""
        # GPU selector (if multiple GPUs)
        if self.app.monitoring.gpu.gpu_count > 1 and "gpu_selector" in self.components:
            self.components["gpu_selector"].change(
                fn=self.update_selected_gpu,
                inputs=[self.components["gpu_selector"]],
                outputs=[
                    self.components["current_metrics"],
                    self.components["utilization_plot"],
                    self.components["memory_plot"],
                    self.components["power_plot"],
                    self.components["process_info"],
                    self.components["gpu_info"]
                ]
            )

        # Manual refresh button
        self.components["refresh_btn"].click(
            fn=self.refresh_all,
            outputs=[
                self.components["current_metrics"],
                self.components["utilization_plot"],
                self.components["memory_plot"],
                self.components["power_plot"],
                self.components["process_info"],
                self.components["gpu_info"]
            ]
        )

        # Auto-refresh timer
        self.components["refresh_timer"].tick(
            fn=self.conditional_refresh,
            inputs=[self.components["auto_refresh"]],
            outputs=[
                self.components["current_metrics"],
                self.components["utilization_plot"],
                self.components["memory_plot"],
                self.components["power_plot"],
                self.components["process_info"],
                self.components["gpu_info"]
            ]
        )

    def on_enter(self):
        """Called when the tab is selected"""
        # Trigger initial refresh
        return self.refresh_all()

    def update_selected_gpu(self, gpu_selector: str) -> Tuple:
        """Update the selected GPU and refresh data

        Args:
            gpu_selector: Selected GPU string ("GPU X")

        Returns:
            Updated components
        """
        # Extract GPU index from selector string
        try:
            self.selected_gpu = int(gpu_selector.replace("GPU ", ""))
        except (ValueError, AttributeError):
            self.selected_gpu = 0

        # Refresh all components with the new selected GPU
        return self.refresh_all()

    def conditional_refresh(self, auto_refresh: bool) -> Tuple:
        """Only refresh if auto-refresh is enabled

        Args:
            auto_refresh: Whether auto-refresh is enabled

        Returns:
            Updated components or unchanged components
        """
        if auto_refresh:
            return self.refresh_all()

        # Return current values unchanged if auto-refresh is disabled
        return (
            self.components["current_metrics"].value,
            self.components["utilization_plot"].value,
            self.components["memory_plot"].value,
            self.components["power_plot"].value,
            self.components["process_info"].value,
            self.components["gpu_info"].value
        )

    def refresh_all(self) -> Tuple:
        """Refresh all GPU monitoring components

        Returns:
            Updated values for all components
        """
        try:
            if not self.app.monitoring.gpu.has_nvidia_gpus:
                return (
                    "No NVIDIA GPUs detected",
                    None,
                    None,
                    None,
                    "No process information available",
                    "No GPU information available"
                )

            # Get current metrics for the selected GPU
            all_metrics = self.app.monitoring.gpu.get_current_metrics()
            if not all_metrics or self.selected_gpu >= len(all_metrics):
                return (
                    "GPU metrics not available",
                    None,
                    None,
                    None,
                    "No process information available",
                    "No GPU information available"
                )

            # Get selected GPU metrics
            gpu_metrics = all_metrics[self.selected_gpu]

            # Format current metrics as markdown
            metrics_html = self.format_current_metrics(gpu_metrics)

            # Format process information
            process_info_html = self.format_process_info(gpu_metrics)

            # Format GPU information
            gpu_info = self.app.monitoring.gpu.get_gpu_info()
            gpu_info_html = self.format_gpu_info(gpu_info[self.selected_gpu] if self.selected_gpu < len(gpu_info) else {})

            # Generate plots
            utilization_plot = self.app.monitoring.gpu.generate_utilization_plot(self.selected_gpu)
            memory_plot = self.app.monitoring.gpu.generate_memory_plot(self.selected_gpu)
            power_plot = self.app.monitoring.gpu.generate_power_plot(self.selected_gpu)

            return (
                metrics_html,
                utilization_plot,
                memory_plot,
                power_plot,
                process_info_html,
                gpu_info_html
            )

        except Exception as e:
            logger.error(f"Error refreshing GPU data: {str(e)}", exc_info=True)
            error_msg = f"Error retrieving GPU data: {str(e)}"
            return (
                error_msg,
                None,
                None,
                None,
                error_msg,
                error_msg
            )

    def format_current_metrics(self, metrics: Dict[str, Any]) -> str:
        """Format current GPU metrics as HTML/Markdown

        Args:
            metrics: Current metrics dictionary

        Returns:
            Formatted HTML/Markdown string
        """
        if 'error' in metrics:
            return f"Error retrieving GPU metrics: {metrics['error']}"

        # Format timestamp
        if isinstance(metrics.get('timestamp'), datetime):
            timestamp_str = metrics['timestamp'].strftime('%Y-%m-%d %H:%M:%S')
        else:
            timestamp_str = "Unknown"

        # Style for GPU utilization
        util_style = "color: green;"
        if metrics.get('utilization_gpu', 0) > 90:
            util_style = "color: red; font-weight: bold;"
        elif metrics.get('utilization_gpu', 0) > 70:
            util_style = "color: orange;"

        # Style for memory usage
        mem_style = "color: green;"
        if metrics.get('memory_percent', 0) > 90:
            mem_style = "color: red; font-weight: bold;"
        elif metrics.get('memory_percent', 0) > 70:
            mem_style = "color: orange;"

        # Style for temperature
        temp_style = "color: green;"
        temp = metrics.get('temperature', 0)
        if temp > 85:
            temp_style = "color: red; font-weight: bold;"
        elif temp > 75:
            temp_style = "color: orange;"

        # Memory usage in GB
        memory_used_gb = metrics.get('memory_used', 0) / (1024**3)
        memory_total_gb = metrics.get('memory_total', 0) / (1024**3)

        # Power usage and limit
        power_html = ""
        if metrics.get('power_usage') is not None:
            power_html = f"**Power Usage:** {metrics['power_usage']:.1f}W\n"

        html = f"""
        ### Current Status as of {timestamp_str}

        **GPU Utilization:** <span style="{util_style}">{metrics.get('utilization_gpu', 0):.1f}%</span>
        **Memory Usage:** <span style="{mem_style}">{metrics.get('memory_percent', 0):.1f}% ({memory_used_gb:.2f}/{memory_total_gb:.2f} GB)</span>
        **Temperature:** <span style="{temp_style}">{metrics.get('temperature', 0)}°C</span>
        {power_html}
        """
        return html

    def format_process_info(self, metrics: Dict[str, Any]) -> str:
        """Format GPU process information as HTML/Markdown

        Args:
            metrics: Current metrics dictionary with process information

        Returns:
            Formatted HTML/Markdown string
        """
        if 'error' in metrics:
            return "Process information not available"

        processes = metrics.get('processes', [])
        if not processes:
            return "No active processes using this GPU"

        # Sort processes by memory usage (descending)
        sorted_processes = sorted(processes, key=lambda p: p.get('memory_used', 0), reverse=True)

        html = "| PID | Process Name | Memory Usage |\n"
        html += "|-----|-------------|-------------|\n"

        for proc in sorted_processes:
            pid = proc.get('pid', 'Unknown')
            name = proc.get('name', 'Unknown')
            mem_mb = proc.get('memory_used', 0) / (1024**2)  # Convert to MB

            html += f"| {pid} | {name} | {mem_mb:.1f} MB |\n"

        return html

    def format_gpu_info(self, info: Dict[str, Any]) -> str:
        """Format GPU information as HTML/Markdown

        Args:
            info: GPU information dictionary

        Returns:
            Formatted HTML/Markdown string
        """
        if 'error' in info:
            return f"GPU information not available: {info.get('error', 'Unknown error')}"

        # Format memory in GB
        memory_total_gb = info.get('memory_total', 0) / (1024**3)

        html = f"""
        **Name:** {info.get('name', 'Unknown')}
        **Memory:** {memory_total_gb:.2f} GB
        **UUID:** {info.get('uuid', 'N/A')}
        **Compute Capability:** {info.get('compute_capability', 'N/A')}
        """

        # Add power limit if available
        if info.get('power_limit') is not None:
            html += f"**Power Limit:** {info['power_limit']:.1f}W\n"

        return html
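
The auto-refresh wiring above is an instance of a generic Gradio pattern, a Timer tick gated by a Checkbox; a stripped-down sketch of just that pattern, independent of VMS (assumes a Gradio version that ships gr.Timer):

import gradio as gr
from datetime import datetime

def maybe_refresh(enabled: bool):
    # Only produce a new value when auto-refresh is enabled
    if enabled:
        return f"Last refresh: {datetime.now().strftime('%H:%M:%S')}"
    return gr.update()  # leave the output component unchanged

with gr.Blocks() as demo:
    auto = gr.Checkbox(label="Auto refresh (every 5 seconds)", value=True)
    status = gr.Markdown("Waiting for first refresh...")
    timer = gr.Timer(value=5)
    timer.tick(fn=maybe_refresh, inputs=[auto], outputs=[status])

demo.launch()
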
vms/ui/project/services/training.py CHANGED
@@ -38,7 +38,8 @@ from vms.config import (
     DEFAULT_MAX_GPUS,
     DEFAULT_PRECOMPUTATION_ITEMS,
     DEFAULT_NB_TRAINING_STEPS,
-    DEFAULT_NB_LR_WARMUP_STEPS
+    DEFAULT_NB_LR_WARMUP_STEPS,
+    DEFAULT_AUTO_RESUME
 )
 from vms.utils import (
     get_available_gpu_count,
@@ -151,7 +152,8 @@ class TrainingService:
             "training_preset": list(TRAINING_PRESETS.keys())[0],
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
-            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
+            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+            "auto_resume": False
         }
 
         # Copy default values first
@@ -231,7 +233,8 @@ class TrainingService:
             "training_preset": list(TRAINING_PRESETS.keys())[0],
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
-            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
+            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+            "auto_resume": DEFAULT_AUTO_RESUME
         }
 
         # Use lock for reading too to avoid reading during a write
@@ -369,6 +372,7 @@ class TrainingService:
         # Default state with all required values
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
+            "model_version": "",
             "training_type": list(TRAINING_TYPES.keys())[0],
             "lora_rank": DEFAULT_LORA_RANK_STR,
             "lora_alpha": DEFAULT_LORA_ALPHA_STR,
@@ -379,7 +383,8 @@ class TrainingService:
             "training_preset": list(TRAINING_PRESETS.keys())[0],
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
-            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
+            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+            "auto_resume": False
         }
 
         # If file doesn't exist, create it with default values
@@ -1144,12 +1149,15 @@ class TrainingService:
                 "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
                 "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
                 "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
-                "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
+                "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
+                "auto_resume_checkbox": ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
             })
 
             # Check if we should auto-recover (immediate restart)
-            auto_recover = True # Always auto-recover on startup
-
+            ui_state = self.load_ui_state()
+            auto_recover = ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
+            logger.info(f"Auto-resume is {'enabled' if auto_recover else 'disabled'}")
+
             if auto_recover:
                 try:
                     result = self.start_training(
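
The behavioural change here is that auto-recovery is now opt-in through the persisted UI state instead of always on; a simplified sketch of that decision follows (the file name and helper are hypothetical, only the auto_resume key and its default mirror the diff):

import json
from pathlib import Path

DEFAULT_AUTO_RESUME = False  # mirrors the new default in vms/config.py

def should_auto_resume(ui_state_file: Path) -> bool:
    # Hypothetical helper: read the persisted UI state and fall back to the default
    if not ui_state_file.exists():
        return DEFAULT_AUTO_RESUME
    ui_state = json.loads(ui_state_file.read_text())
    return bool(ui_state.get("auto_resume", DEFAULT_AUTO_RESUME))

# Before this commit: auto_recover = True  (training always restarted on startup)
# After this commit:  auto_recover is read from the saved UI state, e.g.
#   auto_recover = should_auto_resume(Path("ui_state.json"))  # file name is illustrative
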
vms/ui/project/tabs/train_tab.py CHANGED
@@ -26,6 +26,7 @@ from vms.config import (
     DEFAULT_PRECOMPUTATION_ITEMS,
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
+    DEFAULT_AUTO_RESUME
 )
 
 logger = logging.getLogger(__name__)
@@ -231,6 +232,13 @@ class TrainTab(BaseTab):
                         interactive=has_checkpoints
                     )
 
+                    with gr.Row():
+                        self.components["auto_resume_checkbox"] = gr.Checkbox(
+                            label="Automatically continue training in case of server reboot.",
+                            value=DEFAULT_AUTO_RESUME,
+                            info="When enabled, training will automatically resume from the latest checkpoint after app restart"
+                        )
+
                     with gr.Row():
                         with gr.Column():
                             self.components["status_box"] = gr.Textbox(
@@ -381,6 +389,12 @@ class TrainTab(BaseTab):
             ]
         )
 
+        self.components["auto_resume_checkbox"].change(
+            fn=lambda v: self.app.update_ui_state(auto_resume=v),
+            inputs=[self.components["auto_resume_checkbox"]],
+            outputs=[]
+        )
+
         # Add in the connect_events() method:
        self.components["num_gpus"].change(
            fn=lambda v: self.app.update_ui_state(num_gpus=v),
  fn=lambda v: self.app.update_ui_state(num_gpus=v),