Spaces:
Running
on
Zero
Running
on
Zero
Attempt to solve GPU error
Browse files
app.py
CHANGED
@@ -366,11 +366,11 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
366 |
for module in pipe.model.modules():
|
367 |
if hasattr(module, 'register_buffer'):
|
368 |
for name, buffer in module._buffers.items():
|
369 |
-
if buffer is not None:
|
370 |
module._buffers[name] = buffer.to(device)
|
371 |
if hasattr(module, 'register_parameter'):
|
372 |
for name, param in module._parameters.items():
|
373 |
-
if param is not None:
|
374 |
module._parameters[name] = param.to(device)
|
375 |
|
376 |
# Use predict_quantiles with proper formatting
|
@@ -386,7 +386,8 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
386 |
|
387 |
# Force all model components to GPU
|
388 |
pipe.model = pipe.model.to(device)
|
389 |
-
pipe
|
|
|
390 |
|
391 |
# Ensure all model states are on GPU
|
392 |
if hasattr(pipe.model, 'state_dict'):
|
@@ -399,7 +400,7 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
399 |
# Ensure all model attributes are on GPU
|
400 |
for attr_name in dir(pipe.model):
|
401 |
attr = getattr(pipe.model, attr_name)
|
402 |
-
if isinstance(attr, torch.Tensor):
|
403 |
setattr(pipe.model, attr_name, attr.to(device))
|
404 |
|
405 |
# Ensure all model submodules are on GPU
|
@@ -409,17 +410,17 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
409 |
|
410 |
# Ensure all model buffers are on GPU
|
411 |
for name, buffer in pipe.model.named_buffers():
|
412 |
-
if buffer is not None:
|
413 |
pipe.model.register_buffer(name, buffer.to(device))
|
414 |
|
415 |
# Ensure all model parameters are on GPU
|
416 |
for name, param in pipe.model.named_parameters():
|
417 |
-
if param is not None:
|
418 |
param.data = param.data.to(device)
|
419 |
|
420 |
# Ensure all model attributes that might contain tensors are on GPU
|
421 |
for name, value in pipe.model.__dict__.items():
|
422 |
-
if isinstance(value, torch.Tensor):
|
423 |
pipe.model.__dict__[name] = value.to(device)
|
424 |
|
425 |
quantiles, mean = pipe.predict_quantiles(
|
|
|
366 |
for module in pipe.model.modules():
|
367 |
if hasattr(module, 'register_buffer'):
|
368 |
for name, buffer in module._buffers.items():
|
369 |
+
if buffer is not None and hasattr(buffer, 'to'):
|
370 |
module._buffers[name] = buffer.to(device)
|
371 |
if hasattr(module, 'register_parameter'):
|
372 |
for name, param in module._parameters.items():
|
373 |
+
if param is not None and hasattr(param, 'to'):
|
374 |
module._parameters[name] = param.to(device)
|
375 |
|
376 |
# Use predict_quantiles with proper formatting
|
|
|
386 |
|
387 |
# Force all model components to GPU
|
388 |
pipe.model = pipe.model.to(device)
|
389 |
+
if hasattr(pipe, 'tokenizer') and hasattr(pipe.tokenizer, 'to'):
|
390 |
+
pipe.tokenizer = pipe.tokenizer.to(device)
|
391 |
|
392 |
# Ensure all model states are on GPU
|
393 |
if hasattr(pipe.model, 'state_dict'):
|
|
|
400 |
# Ensure all model attributes are on GPU
|
401 |
for attr_name in dir(pipe.model):
|
402 |
attr = getattr(pipe.model, attr_name)
|
403 |
+
if isinstance(attr, torch.Tensor) and hasattr(attr, 'to'):
|
404 |
setattr(pipe.model, attr_name, attr.to(device))
|
405 |
|
406 |
# Ensure all model submodules are on GPU
|
|
|
410 |
|
411 |
# Ensure all model buffers are on GPU
|
412 |
for name, buffer in pipe.model.named_buffers():
|
413 |
+
if buffer is not None and hasattr(buffer, 'to'):
|
414 |
pipe.model.register_buffer(name, buffer.to(device))
|
415 |
|
416 |
# Ensure all model parameters are on GPU
|
417 |
for name, param in pipe.model.named_parameters():
|
418 |
+
if param is not None and hasattr(param, 'to'):
|
419 |
param.data = param.data.to(device)
|
420 |
|
421 |
# Ensure all model attributes that might contain tensors are on GPU
|
422 |
for name, value in pipe.model.__dict__.items():
|
423 |
+
if isinstance(value, torch.Tensor) and hasattr(value, 'to'):
|
424 |
pipe.model.__dict__[name] = value.to(device)
|
425 |
|
426 |
quantiles, mean = pipe.predict_quantiles(
|