Spaces:
Running
on
Zero
Running
on
Zero
add tabbed interface
Browse files
app.py
CHANGED
@@ -141,7 +141,7 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
141 |
|
142 |
Args:
|
143 |
symbol (str): Stock symbol
|
144 |
-
timeframe (str): Data timeframe
|
145 |
prediction_days (int): Number of days to predict
|
146 |
strategy (str): Prediction strategy to use
|
147 |
|
@@ -162,26 +162,39 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
162 |
# Make prediction with GPU acceleration
|
163 |
pipe = load_pipeline()
|
164 |
|
165 |
-
#
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
with torch.inference_mode():
|
169 |
prediction = pipe.predict(
|
170 |
context=context,
|
171 |
-
prediction_length=
|
172 |
num_samples=100
|
173 |
).detach().cpu().numpy()
|
174 |
|
175 |
mean_pred = prediction.mean(axis=0)
|
176 |
std_pred = prediction.std(axis=0)
|
177 |
|
178 |
-
# If we had to limit the prediction
|
179 |
-
if
|
180 |
last_pred = mean_pred[-1]
|
181 |
last_std = std_pred[-1]
|
182 |
-
extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days -
|
183 |
mean_pred = np.concatenate([mean_pred, extension])
|
184 |
-
std_pred = np.concatenate([std_pred, np.full(prediction_days -
|
185 |
|
186 |
except Exception as e:
|
187 |
print(f"Chronos prediction failed: {str(e)}")
|
@@ -203,9 +216,14 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
203 |
mean_pred = np.array([last_price * (1 + trend * volatility * i) for i in range(1, prediction_days + 1)])
|
204 |
std_pred = np.array([volatility * last_price * i for i in range(1, prediction_days + 1)])
|
205 |
|
206 |
-
# Create prediction dates
|
207 |
last_date = df.index[-1]
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Create visualization
|
211 |
fig = make_subplots(rows=3, cols=1,
|
@@ -265,9 +283,9 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
265 |
row=3, col=1
|
266 |
)
|
267 |
|
268 |
-
# Update layout
|
269 |
fig.update_layout(
|
270 |
-
title=f'{symbol} Analysis and Prediction',
|
271 |
xaxis_title='Date',
|
272 |
yaxis_title='Price',
|
273 |
height=1000,
|
@@ -280,9 +298,10 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
|
|
280 |
# Add prediction information to signals
|
281 |
signals.update({
|
282 |
"symbol": symbol,
|
|
|
283 |
"prediction": mean_pred.tolist(),
|
284 |
"confidence": std_pred.tolist(),
|
285 |
-
"dates": pred_dates.strftime('%Y-%m-%d').tolist(),
|
286 |
"strategy_used": strategy
|
287 |
})
|
288 |
|
@@ -316,54 +335,134 @@ def calculate_trading_signals(df: pd.DataFrame) -> Dict:
|
|
316 |
return signals
|
317 |
|
318 |
def create_interface():
|
319 |
-
"""Create the Gradio interface"""
|
320 |
with gr.Blocks(title="Structured Product Analysis") as demo:
|
321 |
gr.Markdown("# Structured Product Analysis")
|
322 |
gr.Markdown("Analyze stocks for inclusion in structured financial products with extended time horizons.")
|
323 |
|
324 |
-
with gr.
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
|
362 |
-
gr.
|
363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
-
gr.
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
def analyze_stock(symbol, timeframe, prediction_days, lookback_days, strategy):
|
369 |
signals, fig = make_prediction(symbol, timeframe, prediction_days, strategy)
|
@@ -400,10 +499,25 @@ def create_interface():
|
|
400 |
|
401 |
return signals, fig, product_metrics, risk_metrics, sector_metrics
|
402 |
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
)
|
408 |
|
409 |
return demo
|
|
|
141 |
|
142 |
Args:
|
143 |
symbol (str): Stock symbol
|
144 |
+
timeframe (str): Data timeframe ('1d', '1h', '15m')
|
145 |
prediction_days (int): Number of days to predict
|
146 |
strategy (str): Prediction strategy to use
|
147 |
|
|
|
162 |
# Make prediction with GPU acceleration
|
163 |
pipe = load_pipeline()
|
164 |
|
165 |
+
# Adjust prediction length based on timeframe
|
166 |
+
if timeframe == "1d":
|
167 |
+
max_prediction_length = 64 # Maximum 64 days for daily data
|
168 |
+
elif timeframe == "1h":
|
169 |
+
max_prediction_length = 168 # Maximum 7 days (168 hours) for hourly data
|
170 |
+
else: # 15m
|
171 |
+
max_prediction_length = 192 # Maximum 2 days (192 15-minute intervals) for 15m data
|
172 |
+
|
173 |
+
# Convert prediction_days to appropriate intervals
|
174 |
+
if timeframe == "1d":
|
175 |
+
actual_prediction_length = min(prediction_days, max_prediction_length)
|
176 |
+
elif timeframe == "1h":
|
177 |
+
actual_prediction_length = min(prediction_days * 24, max_prediction_length)
|
178 |
+
else: # 15m
|
179 |
+
actual_prediction_length = min(prediction_days * 96, max_prediction_length) # 96 intervals per day
|
180 |
|
181 |
with torch.inference_mode():
|
182 |
prediction = pipe.predict(
|
183 |
context=context,
|
184 |
+
prediction_length=actual_prediction_length,
|
185 |
num_samples=100
|
186 |
).detach().cpu().numpy()
|
187 |
|
188 |
mean_pred = prediction.mean(axis=0)
|
189 |
std_pred = prediction.std(axis=0)
|
190 |
|
191 |
+
# If we had to limit the prediction length, extend the prediction
|
192 |
+
if actual_prediction_length < prediction_days:
|
193 |
last_pred = mean_pred[-1]
|
194 |
last_std = std_pred[-1]
|
195 |
+
extension = np.array([last_pred * (1 + np.random.normal(0, last_std, prediction_days - actual_prediction_length))])
|
196 |
mean_pred = np.concatenate([mean_pred, extension])
|
197 |
+
std_pred = np.concatenate([std_pred, np.full(prediction_days - actual_prediction_length, last_std)])
|
198 |
|
199 |
except Exception as e:
|
200 |
print(f"Chronos prediction failed: {str(e)}")
|
|
|
216 |
mean_pred = np.array([last_price * (1 + trend * volatility * i) for i in range(1, prediction_days + 1)])
|
217 |
std_pred = np.array([volatility * last_price * i for i in range(1, prediction_days + 1)])
|
218 |
|
219 |
+
# Create prediction dates based on timeframe
|
220 |
last_date = df.index[-1]
|
221 |
+
if timeframe == "1d":
|
222 |
+
pred_dates = pd.date_range(start=last_date + timedelta(days=1), periods=prediction_days)
|
223 |
+
elif timeframe == "1h":
|
224 |
+
pred_dates = pd.date_range(start=last_date + timedelta(hours=1), periods=prediction_days * 24)
|
225 |
+
else: # 15m
|
226 |
+
pred_dates = pd.date_range(start=last_date + timedelta(minutes=15), periods=prediction_days * 96)
|
227 |
|
228 |
# Create visualization
|
229 |
fig = make_subplots(rows=3, cols=1,
|
|
|
283 |
row=3, col=1
|
284 |
)
|
285 |
|
286 |
+
# Update layout with timeframe-specific settings
|
287 |
fig.update_layout(
|
288 |
+
title=f'{symbol} {timeframe} Analysis and Prediction',
|
289 |
xaxis_title='Date',
|
290 |
yaxis_title='Price',
|
291 |
height=1000,
|
|
|
298 |
# Add prediction information to signals
|
299 |
signals.update({
|
300 |
"symbol": symbol,
|
301 |
+
"timeframe": timeframe,
|
302 |
"prediction": mean_pred.tolist(),
|
303 |
"confidence": std_pred.tolist(),
|
304 |
+
"dates": pred_dates.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
305 |
"strategy_used": strategy
|
306 |
})
|
307 |
|
|
|
335 |
return signals
|
336 |
|
337 |
def create_interface():
|
338 |
+
"""Create the Gradio interface with separate tabs for different timeframes"""
|
339 |
with gr.Blocks(title="Structured Product Analysis") as demo:
|
340 |
gr.Markdown("# Structured Product Analysis")
|
341 |
gr.Markdown("Analyze stocks for inclusion in structured financial products with extended time horizons.")
|
342 |
|
343 |
+
with gr.Tabs() as tabs:
|
344 |
+
# Daily Analysis Tab
|
345 |
+
with gr.TabItem("Daily Analysis"):
|
346 |
+
with gr.Row():
|
347 |
+
with gr.Column():
|
348 |
+
daily_symbol = gr.Textbox(label="Stock Symbol (e.g., AAPL)", value="AAPL")
|
349 |
+
daily_prediction_days = gr.Slider(
|
350 |
+
minimum=1,
|
351 |
+
maximum=365,
|
352 |
+
value=30,
|
353 |
+
step=1,
|
354 |
+
label="Days to Predict"
|
355 |
+
)
|
356 |
+
daily_lookback_days = gr.Slider(
|
357 |
+
minimum=1,
|
358 |
+
maximum=3650,
|
359 |
+
value=365,
|
360 |
+
step=1,
|
361 |
+
label="Historical Lookback (Days)"
|
362 |
+
)
|
363 |
+
daily_strategy = gr.Dropdown(
|
364 |
+
choices=["chronos", "technical"],
|
365 |
+
label="Prediction Strategy",
|
366 |
+
value="chronos"
|
367 |
+
)
|
368 |
+
daily_predict_btn = gr.Button("Analyze Stock")
|
369 |
+
|
370 |
+
with gr.Column():
|
371 |
+
daily_plot = gr.Plot(label="Analysis and Prediction")
|
372 |
+
daily_signals = gr.JSON(label="Trading Signals")
|
373 |
+
|
374 |
+
with gr.Row():
|
375 |
+
with gr.Column():
|
376 |
+
gr.Markdown("### Structured Product Metrics")
|
377 |
+
daily_metrics = gr.JSON(label="Product Metrics")
|
378 |
+
|
379 |
+
gr.Markdown("### Risk Analysis")
|
380 |
+
daily_risk_metrics = gr.JSON(label="Risk Metrics")
|
381 |
+
|
382 |
+
gr.Markdown("### Sector Analysis")
|
383 |
+
daily_sector_metrics = gr.JSON(label="Sector Metrics")
|
384 |
|
385 |
+
# Hourly Analysis Tab
|
386 |
+
with gr.TabItem("Hourly Analysis"):
|
387 |
+
with gr.Row():
|
388 |
+
with gr.Column():
|
389 |
+
hourly_symbol = gr.Textbox(label="Stock Symbol (e.g., AAPL)", value="AAPL")
|
390 |
+
hourly_prediction_days = gr.Slider(
|
391 |
+
minimum=1,
|
392 |
+
maximum=7, # Limited to 7 days for hourly predictions
|
393 |
+
value=3,
|
394 |
+
step=1,
|
395 |
+
label="Days to Predict"
|
396 |
+
)
|
397 |
+
hourly_lookback_days = gr.Slider(
|
398 |
+
minimum=1,
|
399 |
+
maximum=30, # Limited to 30 days for hourly data
|
400 |
+
value=14,
|
401 |
+
step=1,
|
402 |
+
label="Historical Lookback (Days)"
|
403 |
+
)
|
404 |
+
hourly_strategy = gr.Dropdown(
|
405 |
+
choices=["chronos", "technical"],
|
406 |
+
label="Prediction Strategy",
|
407 |
+
value="chronos"
|
408 |
+
)
|
409 |
+
hourly_predict_btn = gr.Button("Analyze Stock")
|
410 |
+
|
411 |
+
with gr.Column():
|
412 |
+
hourly_plot = gr.Plot(label="Analysis and Prediction")
|
413 |
+
hourly_signals = gr.JSON(label="Trading Signals")
|
414 |
|
415 |
+
with gr.Row():
|
416 |
+
with gr.Column():
|
417 |
+
gr.Markdown("### Structured Product Metrics")
|
418 |
+
hourly_metrics = gr.JSON(label="Product Metrics")
|
419 |
+
|
420 |
+
gr.Markdown("### Risk Analysis")
|
421 |
+
hourly_risk_metrics = gr.JSON(label="Risk Metrics")
|
422 |
+
|
423 |
+
gr.Markdown("### Sector Analysis")
|
424 |
+
hourly_sector_metrics = gr.JSON(label="Sector Metrics")
|
425 |
+
|
426 |
+
# 15-Minute Analysis Tab
|
427 |
+
with gr.TabItem("15-Minute Analysis"):
|
428 |
+
with gr.Row():
|
429 |
+
with gr.Column():
|
430 |
+
min15_symbol = gr.Textbox(label="Stock Symbol (e.g., AAPL)", value="AAPL")
|
431 |
+
min15_prediction_days = gr.Slider(
|
432 |
+
minimum=1,
|
433 |
+
maximum=2, # Limited to 2 days for 15-minute predictions
|
434 |
+
value=1,
|
435 |
+
step=1,
|
436 |
+
label="Days to Predict"
|
437 |
+
)
|
438 |
+
min15_lookback_days = gr.Slider(
|
439 |
+
minimum=1,
|
440 |
+
maximum=5, # Limited to 5 days for 15-minute data
|
441 |
+
value=3,
|
442 |
+
step=1,
|
443 |
+
label="Historical Lookback (Days)"
|
444 |
+
)
|
445 |
+
min15_strategy = gr.Dropdown(
|
446 |
+
choices=["chronos", "technical"],
|
447 |
+
label="Prediction Strategy",
|
448 |
+
value="chronos"
|
449 |
+
)
|
450 |
+
min15_predict_btn = gr.Button("Analyze Stock")
|
451 |
+
|
452 |
+
with gr.Column():
|
453 |
+
min15_plot = gr.Plot(label="Analysis and Prediction")
|
454 |
+
min15_signals = gr.JSON(label="Trading Signals")
|
455 |
|
456 |
+
with gr.Row():
|
457 |
+
with gr.Column():
|
458 |
+
gr.Markdown("### Structured Product Metrics")
|
459 |
+
min15_metrics = gr.JSON(label="Product Metrics")
|
460 |
+
|
461 |
+
gr.Markdown("### Risk Analysis")
|
462 |
+
min15_risk_metrics = gr.JSON(label="Risk Metrics")
|
463 |
+
|
464 |
+
gr.Markdown("### Sector Analysis")
|
465 |
+
min15_sector_metrics = gr.JSON(label="Sector Metrics")
|
466 |
|
467 |
def analyze_stock(symbol, timeframe, prediction_days, lookback_days, strategy):
|
468 |
signals, fig = make_prediction(symbol, timeframe, prediction_days, strategy)
|
|
|
499 |
|
500 |
return signals, fig, product_metrics, risk_metrics, sector_metrics
|
501 |
|
502 |
+
# Daily analysis button click
|
503 |
+
daily_predict_btn.click(
|
504 |
+
fn=lambda s, pd, ld, st: analyze_stock(s, "1d", pd, ld, st),
|
505 |
+
inputs=[daily_symbol, daily_prediction_days, daily_lookback_days, daily_strategy],
|
506 |
+
outputs=[daily_signals, daily_plot, daily_metrics, daily_risk_metrics, daily_sector_metrics]
|
507 |
+
)
|
508 |
+
|
509 |
+
# Hourly analysis button click
|
510 |
+
hourly_predict_btn.click(
|
511 |
+
fn=lambda s, pd, ld, st: analyze_stock(s, "1h", pd, ld, st),
|
512 |
+
inputs=[hourly_symbol, hourly_prediction_days, hourly_lookback_days, hourly_strategy],
|
513 |
+
outputs=[hourly_signals, hourly_plot, hourly_metrics, hourly_risk_metrics, hourly_sector_metrics]
|
514 |
+
)
|
515 |
+
|
516 |
+
# 15-minute analysis button click
|
517 |
+
min15_predict_btn.click(
|
518 |
+
fn=lambda s, pd, ld, st: analyze_stock(s, "15m", pd, ld, st),
|
519 |
+
inputs=[min15_symbol, min15_prediction_days, min15_lookback_days, min15_strategy],
|
520 |
+
outputs=[min15_signals, min15_plot, min15_metrics, min15_risk_metrics, min15_sector_metrics]
|
521 |
)
|
522 |
|
523 |
return demo
|