Tonic commited on
Commit
17ea549
·
1 Parent(s): 0bddd1c

add rolling window predictions extension

Browse files
Files changed (1) hide show
  1. app.py +53 -5
app.py CHANGED
@@ -533,11 +533,59 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
533
 
534
  # If we had to limit the prediction length, extend the prediction
535
  if actual_prediction_length < prediction_days:
536
- last_pred = mean_pred[-1]
537
- last_std = std_pred[-1]
538
- extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days - actual_prediction_length))])
539
- mean_pred = np.concatenate([mean_pred, extension])
540
- std_pred = np.concatenate([std_pred, np.full(prediction_days - actual_prediction_length, last_std)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
 
542
  except Exception as e:
543
  print(f"Chronos prediction error: {str(e)}")
 
533
 
534
  # If we had to limit the prediction length, extend the prediction
535
  if actual_prediction_length < prediction_days:
536
+ # Initialize arrays for extended predictions
537
+ extended_mean_pred = mean_pred.copy()
538
+ extended_std_pred = std_pred.copy()
539
+
540
+ # Calculate the number of extension steps needed
541
+ remaining_days = prediction_days - actual_prediction_length
542
+ steps_needed = (remaining_days + actual_prediction_length - 1) // actual_prediction_length
543
+
544
+ for step in range(steps_needed):
545
+ # Use the last window_size points as context for next prediction
546
+ window_size = min(64, len(extended_mean_pred))
547
+ context_window = extended_mean_pred[-window_size:]
548
+
549
+ # Normalize the context window
550
+ normalized_context = scaler.fit_transform(context_window.reshape(-1, 1)).flatten()
551
+
552
+ # Convert to tensor and ensure proper shape
553
+ context = torch.tensor(normalized_context, dtype=dtype, device=device)
554
+ if len(context.shape) == 1:
555
+ context = context.unsqueeze(0)
556
+
557
+ # Make prediction for next window
558
+ with torch.amp.autocast('cuda'):
559
+ next_quantiles, next_mean = pipe.predict_quantiles(
560
+ context=context,
561
+ prediction_length=min(actual_prediction_length, remaining_days),
562
+ quantile_levels=[0.1, 0.5, 0.9]
563
+ )
564
+
565
+ # Convert predictions to numpy and denormalize
566
+ next_mean = next_mean.detach().cpu().numpy()
567
+ next_quantiles = next_quantiles.detach().cpu().numpy()
568
+
569
+ # Denormalize predictions
570
+ next_mean_pred = scaler.inverse_transform(next_mean.reshape(-1, 1)).flatten()
571
+ next_lower = scaler.inverse_transform(next_quantiles[0, :, 0].reshape(-1, 1)).flatten()
572
+ next_upper = scaler.inverse_transform(next_quantiles[0, :, 2].reshape(-1, 1)).flatten()
573
+
574
+ # Calculate standard deviation
575
+ next_std_pred = (next_upper - next_lower) / (2 * 1.645)
576
+
577
+ # Append predictions
578
+ extended_mean_pred = np.concatenate([extended_mean_pred, next_mean_pred])
579
+ extended_std_pred = np.concatenate([extended_std_pred, next_std_pred])
580
+
581
+ # Update remaining days
582
+ remaining_days -= len(next_mean_pred)
583
+ if remaining_days <= 0:
584
+ break
585
+
586
+ # Trim to exact prediction length if needed
587
+ mean_pred = extended_mean_pred[:prediction_days]
588
+ std_pred = extended_std_pred[:prediction_days]
589
 
590
  except Exception as e:
591
  print(f"Chronos prediction error: {str(e)}")