Tonic committed
Commit a65d91d · 1 Parent(s): 17ea549

extend the rolling window predictions to all timescales

Files changed (1)
  1. app.py +44 -11
app.py CHANGED
@@ -323,15 +323,22 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
 
     # Adjust prediction length based on timeframe
     if timeframe == "1d":
-        max_prediction_length = 64
+        max_prediction_length = 64  # Chronos maximum
+        window_size = 64  # Use full context window
     elif timeframe == "1h":
-        max_prediction_length = 168
+        max_prediction_length = 64  # Chronos maximum
+        window_size = 64  # Use full context window
     else:  # 15m
-        max_prediction_length = 192
+        max_prediction_length = 64  # Chronos maximum
+        window_size = 64  # Use full context window
 
-    actual_prediction_length = min(prediction_days, max_prediction_length) if timeframe == "1d" else \
-                               min(prediction_days * 24, max_prediction_length) if timeframe == "1h" else \
-                               min(prediction_days * 96, max_prediction_length)
+    # Calculate actual prediction length based on timeframe
+    if timeframe == "1d":
+        actual_prediction_length = min(prediction_days, max_prediction_length)
+    elif timeframe == "1h":
+        actual_prediction_length = min(prediction_days * 24, max_prediction_length)
+    else:  # 15m
+        actual_prediction_length = min(prediction_days * 96, max_prediction_length)
 
     actual_prediction_length = max(1, actual_prediction_length)
 
@@ -543,7 +550,6 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
 
         for step in range(steps_needed):
             # Use the last window_size points as context for next prediction
-            window_size = min(64, len(extended_mean_pred))
             context_window = extended_mean_pred[-window_size:]
 
             # Normalize the context window
@@ -554,11 +560,19 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
             if len(context.shape) == 1:
                 context = context.unsqueeze(0)
 
+            # Calculate next prediction length based on timeframe
+            if timeframe == "1d":
+                next_length = min(max_prediction_length, remaining_days)
+            elif timeframe == "1h":
+                next_length = min(max_prediction_length, remaining_days * 24)
+            else:  # 15m
+                next_length = min(max_prediction_length, remaining_days * 96)
+
             # Make prediction for next window
             with torch.amp.autocast('cuda'):
                 next_quantiles, next_mean = pipe.predict_quantiles(
                     context=context,
-                    prediction_length=min(actual_prediction_length, remaining_days),
+                    prediction_length=next_length,
                     quantile_levels=[0.1, 0.5, 0.9]
                 )
 
@@ -574,18 +588,37 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
             # Calculate standard deviation
             next_std_pred = (next_upper - next_lower) / (2 * 1.645)
 
+            # Apply exponential smoothing to reduce prediction drift
+            if step > 0:
+                alpha = 0.3  # Smoothing factor
+                next_mean_pred = alpha * next_mean_pred + (1 - alpha) * extended_mean_pred[-len(next_mean_pred):]
+                next_std_pred = alpha * next_std_pred + (1 - alpha) * extended_std_pred[-len(next_std_pred):]
+
             # Append predictions
             extended_mean_pred = np.concatenate([extended_mean_pred, next_mean_pred])
             extended_std_pred = np.concatenate([extended_std_pred, next_std_pred])
 
             # Update remaining days
-            remaining_days -= len(next_mean_pred)
+            if timeframe == "1d":
+                remaining_days -= len(next_mean_pred)
+            elif timeframe == "1h":
+                remaining_days -= len(next_mean_pred) / 24
+            else:  # 15m
+                remaining_days -= len(next_mean_pred) / 96
+
             if remaining_days <= 0:
                 break
 
         # Trim to exact prediction length if needed
-        mean_pred = extended_mean_pred[:prediction_days]
-        std_pred = extended_std_pred[:prediction_days]
+        if timeframe == "1d":
+            mean_pred = extended_mean_pred[:prediction_days]
+            std_pred = extended_std_pred[:prediction_days]
+        elif timeframe == "1h":
+            mean_pred = extended_mean_pred[:prediction_days * 24]
+            std_pred = extended_std_pred[:prediction_days * 24]
+        else:  # 15m
+            mean_pred = extended_mean_pred[:prediction_days * 96]
+            std_pred = extended_std_pred[:prediction_days * 96]
 
     except Exception as e:
         print(f"Chronos prediction error: {str(e)}")
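For context, the hunks above turn the single-shot forecast into a rolling-window loop that works for all three timescales: the requested horizon in days is converted to model steps (1, 24, or 96 per day), predictions are generated at most 64 steps at a time from the last 64-point context, and each new chunk is exponentially smoothed against the tail of what has already been produced before being appended. The sketch below reproduces that pattern in isolation. It is illustrative only: `rolling_forecast`, `predict_fn`, and the toy random-walk model are hypothetical stand-ins (not part of app.py) for the actual `pipe.predict_quantiles` call, while the 64-step cap, the 0.3 smoothing factor, and the per-day step counts come from the diff.

```python
import numpy as np

# Steps per day of forecast horizon for each supported timeframe,
# matching the conversions in the diff (24 hourly bars, 96 fifteen-minute bars).
STEPS_PER_DAY = {"1d": 1, "1h": 24, "15m": 96}
MAX_PREDICTION_LENGTH = 64  # Chronos maximum, per the comments in the diff
WINDOW_SIZE = 64            # context length fed back into the model each step


def rolling_forecast(history, predict_fn, timeframe="1h", prediction_days=5, alpha=0.3):
    """Roll a fixed-horizon forecaster forward until prediction_days is covered.

    predict_fn(context, length) stands in for the real model call
    (pipe.predict_quantiles in app.py) and must return `length` values.
    """
    total_steps = prediction_days * STEPS_PER_DAY[timeframe]
    extended = np.asarray(history, dtype=float)
    produced = 0
    step = 0

    while produced < total_steps:
        # Use the last WINDOW_SIZE points as context for the next chunk.
        context = extended[-WINDOW_SIZE:]
        next_length = min(MAX_PREDICTION_LENGTH, total_steps - produced)
        next_pred = np.asarray(predict_fn(context, next_length), dtype=float)

        # Exponentially smooth the new chunk against the tail of the series
        # produced so far, damping drift between successive windows.
        if step > 0:
            tail = extended[-len(next_pred):]
            next_pred = alpha * next_pred + (1 - alpha) * tail

        extended = np.concatenate([extended, next_pred])
        produced += len(next_pred)
        step += 1

    # Trim to the exact requested horizon.
    return extended[len(history):len(history) + total_steps]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    history = np.cumsum(rng.normal(size=256)) + 100.0

    # Toy stand-in model: continue a random walk from the last context value.
    def toy_predict(context, length):
        return context[-1] + np.cumsum(rng.normal(scale=0.5, size=length))

    forecast = rolling_forecast(history, toy_predict, timeframe="1h", prediction_days=5)
    print(len(forecast))  # 5 days * 24 hourly steps = 120
```

As in the commit, the smoothing blends each new chunk with the most recent points of the existing forecast, which softens discontinuities at window boundaries at the cost of lagging sharp moves between windows.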