Tonic commited on
Commit
248f443
·
1 Parent(s): 165a173

improves error handling by returning chronos errors

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -58,6 +58,7 @@ def load_pipeline():
58
  try:
59
  if pipeline is None:
60
  clear_gpu_memory()
 
61
  pipeline = ChronosPipeline.from_pretrained(
62
  "amazon/chronos-t5-large",
63
  device_map="auto", # Let the machine choose the best device
@@ -65,6 +66,7 @@ def load_pipeline():
65
  low_cpu_mem_usage=True
66
  )
67
  pipeline.model = pipeline.model.eval()
 
68
  return pipeline
69
  except Exception as e:
70
  print(f"Error loading pipeline: {str(e)}")
@@ -240,7 +242,7 @@ def get_historical_data(symbol: str, timeframe: str = "1d", lookback_days: int =
240
  def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
241
  """Calculate Relative Strength Index"""
242
  # Handle None values by forward filling
243
- prices = prices.fillna(method='ffill').fillna(method='bfill')
244
  delta = prices.diff()
245
  gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
246
  loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
@@ -250,7 +252,7 @@ def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
250
  def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> Tuple[pd.Series, pd.Series]:
251
  """Calculate MACD and Signal line"""
252
  # Handle None values by forward filling
253
- prices = prices.fillna(method='ffill').fillna(method='bfill')
254
  exp1 = prices.ewm(span=fast, adjust=False).mean()
255
  exp2 = prices.ewm(span=slow, adjust=False).mean()
256
  macd = exp1 - exp2
@@ -260,7 +262,7 @@ def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: in
260
  def calculate_bollinger_bands(prices: pd.Series, period: int = 20, std_dev: int = 2) -> Tuple[pd.Series, pd.Series, pd.Series]:
261
  """Calculate Bollinger Bands"""
262
  # Handle None values by forward filling
263
- prices = prices.fillna(method='ffill').fillna(method='bfill')
264
  middle_band = prices.rolling(window=period).mean()
265
  std = prices.rolling(window=period).std()
266
  upper_band = middle_band + (std * std_dev)
@@ -330,23 +332,31 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
330
  actual_prediction_length = max(1, actual_prediction_length)
331
 
332
  with torch.inference_mode():
333
- prediction = pipe.predict(
334
- context=context,
335
- prediction_length=actual_prediction_length,
336
- num_samples=100
337
- ).detach().cpu().numpy()
338
-
339
- # Denormalize predictions
340
- mean_pred = scaler.inverse_transform(prediction.mean(axis=0).reshape(-1, 1)).flatten()
341
- std_pred = prediction.std(axis=0) * (scaler.data_max_ - scaler.data_min_)
342
-
343
- # If we had to limit the prediction length, extend the prediction
344
- if actual_prediction_length < prediction_days:
345
- last_pred = mean_pred[-1]
346
- last_std = std_pred[-1]
347
- extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days - actual_prediction_length))])
348
- mean_pred = np.concatenate([mean_pred, extension])
349
- std_pred = np.concatenate([std_pred, np.full(prediction_days - actual_prediction_length, last_std)])
 
 
 
 
 
 
 
 
350
 
351
  except Exception as e:
352
  print(f"Chronos prediction failed: {str(e)}")
 
58
  try:
59
  if pipeline is None:
60
  clear_gpu_memory()
61
+ print("Loading Chronos model...")
62
  pipeline = ChronosPipeline.from_pretrained(
63
  "amazon/chronos-t5-large",
64
  device_map="auto", # Let the machine choose the best device
 
66
  low_cpu_mem_usage=True
67
  )
68
  pipeline.model = pipeline.model.eval()
69
+ print("Chronos model loaded successfully")
70
  return pipeline
71
  except Exception as e:
72
  print(f"Error loading pipeline: {str(e)}")
 
242
  def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
243
  """Calculate Relative Strength Index"""
244
  # Handle None values by forward filling
245
+ prices = prices.ffill().bfill()
246
  delta = prices.diff()
247
  gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
248
  loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
 
252
  def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> Tuple[pd.Series, pd.Series]:
253
  """Calculate MACD and Signal line"""
254
  # Handle None values by forward filling
255
+ prices = prices.ffill().bfill()
256
  exp1 = prices.ewm(span=fast, adjust=False).mean()
257
  exp2 = prices.ewm(span=slow, adjust=False).mean()
258
  macd = exp1 - exp2
 
262
  def calculate_bollinger_bands(prices: pd.Series, period: int = 20, std_dev: int = 2) -> Tuple[pd.Series, pd.Series, pd.Series]:
263
  """Calculate Bollinger Bands"""
264
  # Handle None values by forward filling
265
+ prices = prices.ffill().bfill()
266
  middle_band = prices.rolling(window=period).mean()
267
  std = prices.rolling(window=period).std()
268
  upper_band = middle_band + (std * std_dev)
 
332
  actual_prediction_length = max(1, actual_prediction_length)
333
 
334
  with torch.inference_mode():
335
+ try:
336
+ prediction = pipe.predict(
337
+ context=context,
338
+ prediction_length=actual_prediction_length,
339
+ num_samples=100
340
+ ).detach().cpu().numpy()
341
+
342
+ if prediction is None or prediction.size == 0:
343
+ raise ValueError("Chronos returned empty prediction")
344
+
345
+ # Denormalize predictions
346
+ mean_pred = scaler.inverse_transform(prediction.mean(axis=0).reshape(-1, 1)).flatten()
347
+ std_pred = prediction.std(axis=0) * (scaler.data_max_ - scaler.data_min_)
348
+
349
+ # If we had to limit the prediction length, extend the prediction
350
+ if actual_prediction_length < prediction_days:
351
+ last_pred = mean_pred[-1]
352
+ last_std = std_pred[-1]
353
+ extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days - actual_prediction_length))])
354
+ mean_pred = np.concatenate([mean_pred, extension])
355
+ std_pred = np.concatenate([std_pred, np.full(prediction_days - actual_prediction_length, last_std)])
356
+
357
+ except Exception as e:
358
+ print(f"Chronos prediction error: {str(e)}")
359
+ raise
360
 
361
  except Exception as e:
362
  print(f"Chronos prediction failed: {str(e)}")