Spaces:
Running
on
Zero
Running
on
Zero
correctly pass tensor shapes, adds validation checks adds dtype matching and validation
Browse files
app.py
CHANGED
@@ -360,8 +360,15 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
360 |
# First try with predict_quantiles
|
361 |
try:
|
362 |
print("Trying predict_quantiles...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
quantiles, mean = pipe.predict_quantiles(
|
364 |
-
context=
|
365 |
prediction_length=actual_prediction_length,
|
366 |
quantile_levels=[0.1, 0.5, 0.9] # 10th, 50th, and 90th percentiles
|
367 |
)
|
@@ -388,8 +395,11 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
388 |
print("Falling back to predict...")
|
389 |
|
390 |
# Fallback to predict if predict_quantiles fails
|
|
|
|
|
|
|
391 |
prediction = pipe.predict(
|
392 |
-
context=
|
393 |
prediction_length=actual_prediction_length,
|
394 |
num_samples=100
|
395 |
)
|
|
|
360 |
# First try with predict_quantiles
|
361 |
try:
|
362 |
print("Trying predict_quantiles...")
|
363 |
+
# Convert context to the correct format
|
364 |
+
context_tensor = context.to(device=pipe.model.device, dtype=torch.float16)
|
365 |
+
|
366 |
+
# Ensure context is properly formatted
|
367 |
+
if len(context_tensor.shape) != 3:
|
368 |
+
raise ValueError(f"Expected 3D tensor, got shape {context_tensor.shape}")
|
369 |
+
|
370 |
quantiles, mean = pipe.predict_quantiles(
|
371 |
+
context=context_tensor,
|
372 |
prediction_length=actual_prediction_length,
|
373 |
quantile_levels=[0.1, 0.5, 0.9] # 10th, 50th, and 90th percentiles
|
374 |
)
|
|
|
395 |
print("Falling back to predict...")
|
396 |
|
397 |
# Fallback to predict if predict_quantiles fails
|
398 |
+
# Convert context to the correct format
|
399 |
+
context_tensor = context.to(device=pipe.model.device, dtype=torch.float16)
|
400 |
+
|
401 |
prediction = pipe.predict(
|
402 |
+
context=context_tensor,
|
403 |
prediction_length=actual_prediction_length,
|
404 |
num_samples=100
|
405 |
)
|