sebasfb99 commited on
Commit
e3003c7
·
verified ·
1 Parent(s): 7526843

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from chronos import ChronosPipeline
4
+ import yfinance as yf
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.dates as mdates
9
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
10
+
11
+ def get_popular_tickers():
12
+ return [
13
+ "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
14
+ "JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC"
15
+ ]
16
+
17
+ def predict_stock(ticker, train_data_points, prediction_days):
18
+ try:
19
+ # Asegurar que los parámetros sean enteros
20
+ train_data_points = int(train_data_points)
21
+ prediction_days = int(prediction_days)
22
+
23
+ # Configurar el pipeline
24
+ pipeline = ChronosPipeline.from_pretrained(
25
+ "amazon/chronos-t5-mini",
26
+ device_map="cpu",
27
+ torch_dtype=torch.float32
28
+ )
29
+
30
+ # Obtener la cantidad máxima de datos disponibles
31
+ stock = yf.Ticker(ticker)
32
+ hist = stock.history(period="max")
33
+ stock_prices = hist[['Close']].reset_index()
34
+ df = stock_prices.rename(columns={'Date': 'Date', 'Close': f'{ticker}_Close'})
35
+
36
+ total_points = len(df)
37
+
38
+ # Asegurar que el número de datos de entrenamiento no exceda el total disponible
39
+ train_data_points = min(train_data_points, total_points)
40
+
41
+ # Crear el contexto para entrenamiento
42
+ context = torch.tensor(df[f'{ticker}_Close'][:train_data_points].values, dtype=torch.float32)
43
+
44
+ # Realizar predicción
45
+ forecast = pipeline.predict(context, prediction_days, limit_prediction_length=False)
46
+ low, median, high = np.quantile(forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0)
47
+
48
+ plt.figure(figsize=(20, 10))
49
+ plt.clf()
50
+
51
+ # Determinar el rango de fechas para mostrar en el gráfico
52
+ context_days = min(10, train_data_points)
53
+ start_index = max(0, train_data_points - context_days)
54
+ end_index = min(train_data_points + prediction_days, total_points)
55
+
56
+ # Plotear datos históricos incluyendo datos reales después del entrenamiento
57
+ historical_dates = df['Date'][start_index:end_index]
58
+ historical_data = df[f'{ticker}_Close'][start_index:end_index].values
59
+ plt.plot(historical_dates,
60
+ historical_data,
61
+ color='blue',
62
+ linewidth=2,
63
+ label='Datos Reales')
64
+
65
+ # Crear fechas para la predicción considerando solo días hábiles
66
+ if train_data_points < total_points:
67
+ # Si hay más datos después del entrenamiento
68
+ prediction_start_date = df['Date'].iloc[train_data_points]
69
+ else:
70
+ # Si estamos en el último punto, generar fechas futuras
71
+ last_date = df['Date'].iloc[-1]
72
+ prediction_start_date = last_date + pd.Timedelta(days=1)
73
+
74
+ # Generar fechas de predicción solo en días hábiles
75
+ prediction_dates = pd.date_range(start=prediction_start_date, periods=prediction_days, freq='B')
76
+
77
+ # Plotear predicción
78
+ plt.plot(prediction_dates,
79
+ median,
80
+ color='black',
81
+ linewidth=2,
82
+ linestyle='-',
83
+ label='Predicción')
84
+
85
+ # Área de confianza
86
+ plt.fill_between(prediction_dates, low, high,
87
+ color='gray', alpha=0.2,
88
+ label='Intervalo de Confianza')
89
+
90
+ # Calcular métricas si hay datos reales para comparar
91
+ overlap_end_index = train_data_points + prediction_days
92
+ if overlap_end_index <= total_points:
93
+ real_future_dates = df['Date'][train_data_points:overlap_end_index]
94
+ real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
95
+
96
+ # Asegurar que las fechas de predicción y las reales coincidan
97
+ matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
98
+ matching_indices = matching_dates.index - train_data_points
99
+ plt.plot(matching_dates,
100
+ real_future_data[matching_indices],
101
+ color='red',
102
+ linewidth=2,
103
+ linestyle='--',
104
+ label='Datos Reales de Validación')
105
+
106
+ # Filtrar las predicciones que coinciden con las fechas reales
107
+ predicted_data = median[:len(matching_indices)]
108
+ mae = mean_absolute_error(real_future_data[matching_indices], predicted_data)
109
+ rmse = np.sqrt(mean_squared_error(real_future_data[matching_indices], predicted_data))
110
+ mape = np.mean(np.abs((real_future_data[matching_indices] - predicted_data) / real_future_data[matching_indices])) * 100
111
+ plt.title(f"Predicción del Precio de {ticker}\nMAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%",
112
+ fontsize=14, pad=20)
113
+ else:
114
+ plt.title(f"Predicción Futura del Precio de {ticker}",
115
+ fontsize=14, pad=20)
116
+
117
+ plt.legend(loc="upper left", fontsize=12)
118
+ plt.xlabel("Fecha", fontsize=12)
119
+ plt.ylabel("Precio", fontsize=12)
120
+
121
+ # Habilitar líneas de referencia diarias en el gráfico
122
+ plt.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
123
+
124
+ # Formatear el eje x para mostrar las fechas correctamente y agregar líneas de referencia diarias
125
+ ax = plt.gca()
126
+ locator = mdates.DayLocator()
127
+ formatter = mdates.DateFormatter('%Y-%m-%d')
128
+ ax.xaxis.set_major_locator(locator)
129
+ ax.xaxis.set_major_formatter(formatter)
130
+
131
+ # Rotar las etiquetas de fecha
132
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
133
+
134
+ plt.tight_layout()
135
+
136
+ # Crear DataFrame para descarga
137
+ prediction_df = pd.DataFrame({
138
+ 'Date': prediction_dates,
139
+ 'Predicted_Price': median,
140
+ 'Lower_Bound': low,
141
+ 'Upper_Bound': high
142
+ })
143
+
144
+ # Agregar datos reales si están disponibles y coinciden con las fechas de predicción
145
+ if overlap_end_index <= total_points:
146
+ real_future_dates = df['Date'][train_data_points:overlap_end_index]
147
+ real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
148
+ matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
149
+ prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
150
+ prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
151
+
152
+ # Retornar el gráfico y los datos de predicción
153
+ return plt, gr.File.update(value=prediction_df.to_csv(index=False), filename=f"predictions_{ticker}.csv")
154
+
155
+ except Exception as e:
156
+ print(f"Error: {str(e)}")
157
+ raise gr.Error(f"Error al procesar {ticker}: {str(e)}")
158
+
159
+ # Crear la interfaz de Gradio
160
+ with gr.Blocks() as demo:
161
+ gr.Markdown("# Aplicación de Predicción de Precios de Acciones")
162
+
163
+ with gr.Row():
164
+ with gr.Column(scale=1):
165
+ ticker = gr.Dropdown(
166
+ choices=get_popular_tickers(),
167
+ label="Selecciona el Símbolo de la Acción"
168
+ )
169
+ train_data_points = gr.Slider(
170
+ minimum=50,
171
+ maximum=5000,
172
+ value=1000,
173
+ step=1,
174
+ label="Número de Datos para Entrenamiento"
175
+ )
176
+ prediction_days = gr.Slider(
177
+ minimum=1,
178
+ maximum=60,
179
+ value=5,
180
+ step=1,
181
+ label="Número de Días a Predecir"
182
+ )
183
+ predict_btn = gr.Button("Predecir")
184
+
185
+ with gr.Column():
186
+ plot_output = gr.Plot(label="Gráfico de Predicción")
187
+ download_btn = gr.File(label="Descargar Predicciones")
188
+
189
+ def update_train_data_points(ticker):
190
+ # Actualizar el máximo de puntos de entrenamiento basándose en los datos disponibles
191
+ stock = yf.Ticker(ticker)
192
+ hist = stock.history(period="max")
193
+ total_points = len(hist)
194
+ # Actualizar el deslizador para reflejar el número total de puntos disponibles
195
+ return gr.Slider.update(maximum=total_points, value=min(1000, total_points))
196
+
197
+ ticker.change(
198
+ fn=update_train_data_points,
199
+ inputs=[ticker],
200
+ outputs=[train_data_points]
201
+ )
202
+
203
+ predict_btn.click(
204
+ fn=predict_stock,
205
+ inputs=[ticker, train_data_points, prediction_days],
206
+ outputs=[plot_output, download_btn]
207
+ )
208
+
209
+ demo.launch()