Tonic commited on
Commit
db1cbec
·
1 Parent(s): 1702ce4

limit predictions to 64 days , adds fallbacks, large number of days algorithm to produce financial metrics

Browse files
Files changed (1) hide show
  1. app.py +71 -27
app.py CHANGED
@@ -12,24 +12,47 @@ import plotly.express as px
12
  from typing import Dict, List, Tuple, Optional
13
  import json
14
  import spaces
 
15
 
16
  # Initialize global variables
17
  pipeline = None
18
  scaler = MinMaxScaler(feature_range=(-1, 1))
19
  scaler.fit_transform([[-1, 1]])
20
 
 
 
 
 
 
 
21
  @spaces.GPU
22
  def load_pipeline():
23
  """Load the Chronos model with GPU configuration"""
24
  global pipeline
25
- if pipeline is None:
26
- pipeline = ChronosPipeline.from_pretrained(
27
- "amazon/chronos-t5-large", # Using the largest model for best performance
28
- device_map="cuda", # Using CUDA for GPU acceleration
29
- torch_dtype=torch.bfloat16 # Using bfloat16 for better memory efficiency
30
- )
31
- pipeline.model = pipeline.model.eval()
32
- return pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def get_historical_data(symbol: str, timeframe: str = "1d", lookback_days: int = 365) -> pd.DataFrame:
35
  """
@@ -140,24 +163,42 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
140
  df = get_historical_data(symbol, timeframe)
141
 
142
  if strategy == "chronos":
143
- # Prepare data for Chronos
144
- returns = df['Returns'].values
145
- normalized_returns = (returns - returns.mean()) / returns.std()
146
- context = torch.tensor(normalized_returns.reshape(-1, 1), dtype=torch.float32)
147
-
148
- # Make prediction with GPU acceleration
149
- pipe = load_pipeline()
150
- with torch.inference_mode():
151
- prediction = pipe.predict(
152
- context=context,
153
- prediction_length=prediction_days,
154
- num_samples=100
155
- ).detach().cpu().numpy()
156
-
157
- mean_pred = prediction.mean(axis=0)
158
- std_pred = prediction.std(axis=0)
159
-
160
- elif strategy == "technical":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # Technical analysis based prediction
162
  last_price = df['Close'].iloc[-1]
163
  rsi = df['RSI'].iloc[-1]
@@ -251,13 +292,16 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
251
  "symbol": symbol,
252
  "prediction": mean_pred.tolist(),
253
  "confidence": std_pred.tolist(),
254
- "dates": pred_dates.strftime('%Y-%m-%d').tolist()
 
255
  })
256
 
257
  return signals, fig
258
 
259
  except Exception as e:
260
  raise Exception(f"Prediction error: {str(e)}")
 
 
261
 
262
  def calculate_trading_signals(df: pd.DataFrame) -> Dict:
263
  """Calculate trading signals based on technical indicators"""
 
12
  from typing import Dict, List, Tuple, Optional
13
  import json
14
  import spaces
15
+ import gc
16
 
17
  # Initialize global variables
18
  pipeline = None
19
  scaler = MinMaxScaler(feature_range=(-1, 1))
20
  scaler.fit_transform([[-1, 1]])
21
 
22
+ def clear_gpu_memory():
23
+ """Clear GPU memory cache"""
24
+ if torch.cuda.is_available():
25
+ torch.cuda.empty_cache()
26
+ gc.collect()
27
+
28
  @spaces.GPU
29
  def load_pipeline():
30
  """Load the Chronos model with GPU configuration"""
31
  global pipeline
32
+ try:
33
+ if pipeline is None:
34
+ clear_gpu_memory()
35
+ pipeline = ChronosPipeline.from_pretrained(
36
+ "amazon/chronos-t5-large",
37
+ device_map="gpu", # Let the model decide the best device mapping
38
+ torch_dtype=torch.float16,
39
+ low_cpu_mem_usage=True
40
+ )
41
+ pipeline.model = pipeline.model.eval()
42
+ return pipeline
43
+ except Exception as e:
44
+ print(f"Error loading pipeline: {str(e)}")
45
+ # Fallback to CPU if GPU fails
46
+ if "cuda" in str(e).lower():
47
+ print("Falling back to CPU mode")
48
+ pipeline = ChronosPipeline.from_pretrained(
49
+ "amazon/chronos-t5-large",
50
+ device_map="cpu",
51
+ torch_dtype=torch.float32,
52
+ low_cpu_mem_usage=True
53
+ )
54
+ pipeline.model = pipeline.model.eval()
55
+ return pipeline
56
 
57
  def get_historical_data(symbol: str, timeframe: str = "1d", lookback_days: int = 365) -> pd.DataFrame:
58
  """
 
163
  df = get_historical_data(symbol, timeframe)
164
 
165
  if strategy == "chronos":
166
+ try:
167
+ # Prepare data for Chronos
168
+ returns = df['Returns'].values
169
+ normalized_returns = (returns - returns.mean()) / returns.std()
170
+ context = torch.tensor(normalized_returns.reshape(-1, 1), dtype=torch.float32)
171
+
172
+ # Make prediction with GPU acceleration
173
+ pipe = load_pipeline()
174
+
175
+ # Limit prediction length to avoid memory issues
176
+ actual_prediction_days = min(prediction_days, 64)
177
+
178
+ with torch.inference_mode():
179
+ prediction = pipe.predict(
180
+ context=context,
181
+ prediction_length=actual_prediction_days,
182
+ num_samples=100
183
+ ).detach().cpu().numpy()
184
+
185
+ mean_pred = prediction.mean(axis=0)
186
+ std_pred = prediction.std(axis=0)
187
+
188
+ # If we had to limit the prediction days, extend the prediction
189
+ if actual_prediction_days < prediction_days:
190
+ last_pred = mean_pred[-1]
191
+ last_std = std_pred[-1]
192
+ extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days - actual_prediction_days))])
193
+ mean_pred = np.concatenate([mean_pred, extension])
194
+ std_pred = np.concatenate([std_pred, np.full(prediction_days - actual_prediction_days, last_std)])
195
+
196
+ except Exception as e:
197
+ print(f"Chronos prediction failed: {str(e)}")
198
+ print("Falling back to technical analysis")
199
+ strategy = "technical"
200
+
201
+ if strategy == "technical":
202
  # Technical analysis based prediction
203
  last_price = df['Close'].iloc[-1]
204
  rsi = df['RSI'].iloc[-1]
 
292
  "symbol": symbol,
293
  "prediction": mean_pred.tolist(),
294
  "confidence": std_pred.tolist(),
295
+ "dates": pred_dates.strftime('%Y-%m-%d').tolist(),
296
+ "strategy_used": strategy
297
  })
298
 
299
  return signals, fig
300
 
301
  except Exception as e:
302
  raise Exception(f"Prediction error: {str(e)}")
303
+ finally:
304
+ clear_gpu_memory()
305
 
306
  def calculate_trading_signals(df: pd.DataFrame) -> Dict:
307
  """Calculate trading signals based on technical indicators"""