Tonic commited on
Commit
cdc40a2
·
1 Parent(s): db47567

correctly pass tensor shapes, adds validation checks adds dtype matching and validation

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -360,8 +360,15 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
360
  # First try with predict_quantiles
361
  try:
362
  print("Trying predict_quantiles...")
 
 
 
 
 
 
 
363
  quantiles, mean = pipe.predict_quantiles(
364
- context=context,
365
  prediction_length=actual_prediction_length,
366
  quantile_levels=[0.1, 0.5, 0.9] # 10th, 50th, and 90th percentiles
367
  )
@@ -388,8 +395,11 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
388
  print("Falling back to predict...")
389
 
390
  # Fallback to predict if predict_quantiles fails
 
 
 
391
  prediction = pipe.predict(
392
- context=context,
393
  prediction_length=actual_prediction_length,
394
  num_samples=100
395
  )
 
360
  # First try with predict_quantiles
361
  try:
362
  print("Trying predict_quantiles...")
363
+ # Convert context to the correct format
364
+ context_tensor = context.to(device=pipe.model.device, dtype=torch.float16)
365
+
366
+ # Ensure context is properly formatted
367
+ if len(context_tensor.shape) != 3:
368
+ raise ValueError(f"Expected 3D tensor, got shape {context_tensor.shape}")
369
+
370
  quantiles, mean = pipe.predict_quantiles(
371
+ context=context_tensor,
372
  prediction_length=actual_prediction_length,
373
  quantile_levels=[0.1, 0.5, 0.9] # 10th, 50th, and 90th percentiles
374
  )
 
395
  print("Falling back to predict...")
396
 
397
  # Fallback to predict if predict_quantiles fails
398
+ # Convert context to the correct format
399
+ context_tensor = context.to(device=pipe.model.device, dtype=torch.float16)
400
+
401
  prediction = pipe.predict(
402
+ context=context_tensor,
403
  prediction_length=actual_prediction_length,
404
  num_samples=100
405
  )