Spaces:
Running
on
Zero
Running
on
Zero
add rolling window predictions extension
Browse files
app.py
CHANGED
@@ -533,11 +533,59 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
533 |
|
534 |
# If we had to limit the prediction length, extend the prediction
|
535 |
if actual_prediction_length < prediction_days:
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
|
542 |
except Exception as e:
|
543 |
print(f"Chronos prediction error: {str(e)}")
|
|
|
533 |
|
534 |
# If we had to limit the prediction length, extend the prediction
|
535 |
if actual_prediction_length < prediction_days:
|
536 |
+
# Initialize arrays for extended predictions
|
537 |
+
extended_mean_pred = mean_pred.copy()
|
538 |
+
extended_std_pred = std_pred.copy()
|
539 |
+
|
540 |
+
# Calculate the number of extension steps needed
|
541 |
+
remaining_days = prediction_days - actual_prediction_length
|
542 |
+
steps_needed = (remaining_days + actual_prediction_length - 1) // actual_prediction_length
|
543 |
+
|
544 |
+
for step in range(steps_needed):
|
545 |
+
# Use the last window_size points as context for next prediction
|
546 |
+
window_size = min(64, len(extended_mean_pred))
|
547 |
+
context_window = extended_mean_pred[-window_size:]
|
548 |
+
|
549 |
+
# Normalize the context window
|
550 |
+
normalized_context = scaler.fit_transform(context_window.reshape(-1, 1)).flatten()
|
551 |
+
|
552 |
+
# Convert to tensor and ensure proper shape
|
553 |
+
context = torch.tensor(normalized_context, dtype=dtype, device=device)
|
554 |
+
if len(context.shape) == 1:
|
555 |
+
context = context.unsqueeze(0)
|
556 |
+
|
557 |
+
# Make prediction for next window
|
558 |
+
with torch.amp.autocast('cuda'):
|
559 |
+
next_quantiles, next_mean = pipe.predict_quantiles(
|
560 |
+
context=context,
|
561 |
+
prediction_length=min(actual_prediction_length, remaining_days),
|
562 |
+
quantile_levels=[0.1, 0.5, 0.9]
|
563 |
+
)
|
564 |
+
|
565 |
+
# Convert predictions to numpy and denormalize
|
566 |
+
next_mean = next_mean.detach().cpu().numpy()
|
567 |
+
next_quantiles = next_quantiles.detach().cpu().numpy()
|
568 |
+
|
569 |
+
# Denormalize predictions
|
570 |
+
next_mean_pred = scaler.inverse_transform(next_mean.reshape(-1, 1)).flatten()
|
571 |
+
next_lower = scaler.inverse_transform(next_quantiles[0, :, 0].reshape(-1, 1)).flatten()
|
572 |
+
next_upper = scaler.inverse_transform(next_quantiles[0, :, 2].reshape(-1, 1)).flatten()
|
573 |
+
|
574 |
+
# Calculate standard deviation
|
575 |
+
next_std_pred = (next_upper - next_lower) / (2 * 1.645)
|
576 |
+
|
577 |
+
# Append predictions
|
578 |
+
extended_mean_pred = np.concatenate([extended_mean_pred, next_mean_pred])
|
579 |
+
extended_std_pred = np.concatenate([extended_std_pred, next_std_pred])
|
580 |
+
|
581 |
+
# Update remaining days
|
582 |
+
remaining_days -= len(next_mean_pred)
|
583 |
+
if remaining_days <= 0:
|
584 |
+
break
|
585 |
+
|
586 |
+
# Trim to exact prediction length if needed
|
587 |
+
mean_pred = extended_mean_pred[:prediction_days]
|
588 |
+
std_pred = extended_std_pred[:prediction_days]
|
589 |
|
590 |
except Exception as e:
|
591 |
print(f"Chronos prediction error: {str(e)}")
|