Tonic committed on
Commit
0695ae5
·
1 Parent(s): df524ae

move all model and components to cuda

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -352,11 +352,30 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
352
  # Move model to evaluation mode
353
  pipe.model.eval()
354
 
355
- # Move the entire model to GPU
356
  pipe.model = pipe.model.to(device)
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  # Use predict_quantiles with proper formatting
359
  with torch.amp.autocast('cuda'):
 
 
360
  quantiles, mean = pipe.predict_quantiles(
361
  context=context,
362
  prediction_length=actual_prediction_length,
 
352
  # Move model to evaluation mode
353
  pipe.model.eval()
354
 
355
+ # Move the entire model and all its components to GPU
356
  pipe.model = pipe.model.to(device)
357
 
358
+ # Ensure all model parameters and buffers are on GPU
359
+ for param in pipe.model.parameters():
360
+ param.data = param.data.to(device)
361
+ for buffer in pipe.model.buffers():
362
+ buffer.data = buffer.data.to(device)
363
+
364
+ # Move any registered buffers or parameters in submodules
365
+ for module in pipe.model.modules():
366
+ if hasattr(module, 'register_buffer'):
367
+ for name, buffer in module._buffers.items():
368
+ if buffer is not None:
369
+ module._buffers[name] = buffer.to(device)
370
+ if hasattr(module, 'register_parameter'):
371
+ for name, param in module._parameters.items():
372
+ if param is not None:
373
+ module._parameters[name] = param.to(device)
374
+
375
  # Use predict_quantiles with proper formatting
376
  with torch.amp.autocast('cuda'):
377
+ # Ensure all inputs are on GPU
378
+ context = context.to(device)
379
  quantiles, mean = pipe.predict_quantiles(
380
  context=context,
381
  prediction_length=actual_prediction_length,