StonksPredictor / app.py
sebasfb99's picture
Update app.py
c3585da verified
raw
history blame
8.95 kB
import gradio as gr
import torch
from chronos import ChronosPipeline
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from sklearn.metrics import mean_absolute_error, mean_squared_error
import tempfile
def get_popular_tickers():
return [
"AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
"JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC"
]
def predict_stock(ticker, train_data_points, prediction_days):
try:
# Asegurar que los parámetros sean enteros
train_data_points = int(train_data_points)
prediction_days = int(prediction_days)
# Configurar el pipeline
pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-mini",
device_map="cpu",
torch_dtype=torch.float32
)
# Obtener la cantidad máxima de datos disponibles
stock = yf.Ticker(ticker)
hist = stock.history(period="max")
stock_prices = hist[['Close']].reset_index()
df = stock_prices.rename(columns={'Date': 'Date', 'Close': f'{ticker}_Close'})
total_points = len(df)
# Asegurar que el número de datos de entrenamiento no exceda el total disponible
train_data_points = min(train_data_points, total_points)
# Crear el contexto para entrenamiento
context = torch.tensor(df[f'{ticker}_Close'][:train_data_points].values, dtype=torch.float32)
# Realizar predicción
forecast = pipeline.predict(context, prediction_days, limit_prediction_length=False)
low, median, high = np.quantile(forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0)
plt.figure(figsize=(20, 10))
plt.clf()
# Determinar el rango de fechas para mostrar en el gráfico
context_days = min(10, train_data_points)
start_index = max(0, train_data_points - context_days)
end_index = min(train_data_points + prediction_days, total_points)
# Plotear datos históricos incluyendo datos reales después del entrenamiento
historical_dates = df['Date'][start_index:end_index]
historical_data = df[f'{ticker}_Close'][start_index:end_index].values
plt.plot(historical_dates,
historical_data,
color='blue',
linewidth=2,
label='Datos Reales')
# Crear fechas para la predicción considerando solo días hábiles
if train_data_points < total_points:
# Si hay más datos después del entrenamiento
prediction_start_date = df['Date'].iloc[train_data_points]
else:
# Si estamos en el último punto, generar fechas futuras
last_date = df['Date'].iloc[-1]
prediction_start_date = last_date + pd.Timedelta(days=1)
# Generar fechas de predicción solo en días hábiles
prediction_dates = pd.date_range(start=prediction_start_date, periods=prediction_days, freq='B')
# Plotear predicción
plt.plot(prediction_dates,
median,
color='black',
linewidth=2,
linestyle='-',
label='Predicción')
# Área de confianza
plt.fill_between(prediction_dates, low, high,
color='gray', alpha=0.2,
label='Intervalo de Confianza')
# Calcular métricas si hay datos reales para comparar
overlap_end_index = train_data_points + prediction_days
if overlap_end_index <= total_points:
real_future_dates = df['Date'][train_data_points:overlap_end_index]
real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
# Asegurar que las fechas de predicción y las reales coincidan
matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
matching_indices = matching_dates.index - train_data_points
plt.plot(matching_dates,
real_future_data[matching_indices],
color='red',
linewidth=2,
linestyle='--',
label='Datos Reales de Validación')
# Filtrar las predicciones que coinciden con las fechas reales
predicted_data = median[:len(matching_indices)]
mae = mean_absolute_error(real_future_data[matching_indices], predicted_data)
rmse = np.sqrt(mean_squared_error(real_future_data[matching_indices], predicted_data))
mape = np.mean(np.abs((real_future_data[matching_indices] - predicted_data) / real_future_data[matching_indices])) * 100
plt.title(f"Predicción del Precio de {ticker}\nMAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%",
fontsize=14, pad=20)
else:
plt.title(f"Predicción Futura del Precio de {ticker}",
fontsize=14, pad=20)
plt.legend(loc="upper left", fontsize=12)
plt.xlabel("Fecha", fontsize=12)
plt.ylabel("Precio", fontsize=12)
# Habilitar líneas de referencia diarias en el gráfico
plt.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
# Formatear el eje x para mostrar las fechas correctamente y agregar líneas de referencia diarias
ax = plt.gca()
locator = mdates.DayLocator()
formatter = mdates.DateFormatter('%Y-%m-%d')
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(formatter)
# Rotar las etiquetas de fecha
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
# Crear un archivo temporal para el CSV
temp_csv = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
prediction_df = pd.DataFrame({
'Date': prediction_dates,
'Predicted_Price': median,
'Lower_Bound': low,
'Upper_Bound': high
})
# Agregar datos reales si están disponibles y coinciden con las fechas de predicción
if overlap_end_index <= total_points:
real_future_dates = df['Date'][train_data_points:overlap_end_index]
real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
# Guardar el DataFrame en el archivo temporal
prediction_df.to_csv(temp_csv.name, index=False)
temp_csv.close()
# Retornar el gráfico y la ruta del archivo CSV
return plt, temp_csv.name
except Exception as e:
print(f"Error: {str(e)}")
raise gr.Error(f"Error al procesar {ticker}: {str(e)}")
# Crear la interfaz de Gradio
with gr.Blocks() as demo:
gr.Markdown("# Aplicación de Predicción de Precios de Acciones")
with gr.Row():
with gr.Column(scale=1):
ticker = gr.Dropdown(
choices=get_popular_tickers(),
label="Selecciona el Símbolo de la Acción"
)
train_data_points = gr.Slider(
minimum=50,
maximum=5000,
value=1000,
step=1,
label="Número de Datos para Entrenamiento"
)
prediction_days = gr.Slider(
minimum=1,
maximum=60,
value=5,
step=1,
label="Número de Días a Predecir"
)
predict_btn = gr.Button("Predecir")
with gr.Column():
plot_output = gr.Plot(label="Gráfico de Predicción")
download_btn = gr.File(label="Descargar Predicciones")
def update_train_data_points(ticker):
# Actualizar el máximo de puntos de entrenamiento basándose en los datos disponibles
stock = yf.Ticker(ticker)
hist = stock.history(period="max")
total_points = len(hist)
# Actualizar el deslizador para reflejar el número total de puntos disponibles
return gr.Slider.update(maximum=total_points, value=min(1000, total_points))
ticker.change(
fn=update_train_data_points,
inputs=[ticker],
outputs=[train_data_points]
)
predict_btn.click(
fn=predict_stock,
inputs=[ticker, train_data_points, prediction_days],
outputs=[plot_output, download_btn]
)
demo.launch()