Tonic commited on
Commit
cd87fec
·
1 Parent(s): 61754ec

move all model and components to cuda

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -310,8 +310,25 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
310
 
311
  # Load pipeline and move to GPU
312
  pipe = load_pipeline()
 
 
313
  pipe.model = pipe.model.cuda()
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  # Get the model's device and dtype
316
  device = next(pipe.model.parameters()).device
317
  dtype = next(pipe.model.parameters()).dtype
@@ -350,14 +367,8 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
350
  print(f"Model device: {next(pipe.model.parameters()).device}")
351
  print(f"Model dtype: {next(pipe.model.parameters()).dtype}")
352
 
353
- # Move model to evaluation mode and ensure it's on the correct device
354
  pipe.model.eval()
355
- pipe.model = pipe.model.to(device)
356
-
357
- # Ensure all model components are on the same device
358
- for module in pipe.model.modules():
359
- if hasattr(module, 'to'):
360
- module.to(device)
361
 
362
  # Use predict_quantiles with proper formatting
363
  quantiles, mean = pipe.predict_quantiles(
 
310
 
311
  # Load pipeline and move to GPU
312
  pipe = load_pipeline()
313
+
314
+ # Move entire model and its components to CUDA
315
  pipe.model = pipe.model.cuda()
316
 
317
+ # Move all model parameters and buffers to CUDA
318
+ for param in pipe.model.parameters():
319
+ param.data = param.data.cuda()
320
+ for buffer in pipe.model.buffers():
321
+ buffer.data = buffer.data.cuda()
322
+
323
+ # Move all submodules to CUDA
324
+ for module in pipe.model.modules():
325
+ module.to('cuda')
326
+ # Move any internal states or buffers
327
+ if hasattr(module, '_buffers'):
328
+ for buffer in module._buffers.values():
329
+ if buffer is not None:
330
+ buffer.data = buffer.data.cuda()
331
+
332
  # Get the model's device and dtype
333
  device = next(pipe.model.parameters()).device
334
  dtype = next(pipe.model.parameters()).dtype
 
367
  print(f"Model device: {next(pipe.model.parameters()).device}")
368
  print(f"Model dtype: {next(pipe.model.parameters()).dtype}")
369
 
370
+ # Move model to evaluation mode
371
  pipe.model.eval()
 
 
 
 
 
 
372
 
373
  # Use predict_quantiles with proper formatting
374
  quantiles, mean = pipe.predict_quantiles(