sebasfb99 commited on
Commit
b996b2e
·
verified ·
1 Parent(s): e3003c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  import matplotlib.dates as mdates
9
  from sklearn.metrics import mean_absolute_error, mean_squared_error
10
-
11
  def get_popular_tickers():
12
  return [
13
  "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
@@ -16,10 +15,6 @@ def get_popular_tickers():
16
 
17
  def predict_stock(ticker, train_data_points, prediction_days):
18
  try:
19
- # Asegurar que los parámetros sean enteros
20
- train_data_points = int(train_data_points)
21
- prediction_days = int(prediction_days)
22
-
23
  # Configurar el pipeline
24
  pipeline = ChronosPipeline.from_pretrained(
25
  "amazon/chronos-t5-mini",
@@ -89,6 +84,7 @@ def predict_stock(ticker, train_data_points, prediction_days):
89
 
90
  # Calcular métricas si hay datos reales para comparar
91
  overlap_end_index = train_data_points + prediction_days
 
92
  if overlap_end_index <= total_points:
93
  real_future_dates = df['Date'][train_data_points:overlap_end_index]
94
  real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
@@ -149,8 +145,10 @@ def predict_stock(ticker, train_data_points, prediction_days):
149
  prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
150
  prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
151
 
152
- # Retornar el gráfico y los datos de predicción
153
- return plt, gr.File.update(value=prediction_df.to_csv(index=False), filename=f"predictions_{ticker}.csv")
 
 
154
 
155
  except Exception as e:
156
  print(f"Error: {str(e)}")
@@ -168,7 +166,7 @@ with gr.Blocks() as demo:
168
  )
169
  train_data_points = gr.Slider(
170
  minimum=50,
171
- maximum=5000,
172
  value=1000,
173
  step=1,
174
  label="Número de Datos para Entrenamiento"
@@ -181,7 +179,7 @@ with gr.Blocks() as demo:
181
  label="Número de Días a Predecir"
182
  )
183
  predict_btn = gr.Button("Predecir")
184
-
185
  with gr.Column():
186
  plot_output = gr.Plot(label="Gráfico de Predicción")
187
  download_btn = gr.File(label="Descargar Predicciones")
@@ -192,7 +190,7 @@ with gr.Blocks() as demo:
192
  hist = stock.history(period="max")
193
  total_points = len(hist)
194
  # Actualizar el deslizador para reflejar el número total de puntos disponibles
195
- return gr.Slider.update(maximum=total_points, value=min(1000, total_points))
196
 
197
  ticker.change(
198
  fn=update_train_data_points,
@@ -206,4 +204,4 @@ with gr.Blocks() as demo:
206
  outputs=[plot_output, download_btn]
207
  )
208
 
209
- demo.launch()
 
7
  import matplotlib.pyplot as plt
8
  import matplotlib.dates as mdates
9
  from sklearn.metrics import mean_absolute_error, mean_squared_error
 
10
  def get_popular_tickers():
11
  return [
12
  "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
 
15
 
16
  def predict_stock(ticker, train_data_points, prediction_days):
17
  try:
 
 
 
 
18
  # Configurar el pipeline
19
  pipeline = ChronosPipeline.from_pretrained(
20
  "amazon/chronos-t5-mini",
 
84
 
85
  # Calcular métricas si hay datos reales para comparar
86
  overlap_end_index = train_data_points + prediction_days
87
+ validation_data = None
88
  if overlap_end_index <= total_points:
89
  real_future_dates = df['Date'][train_data_points:overlap_end_index]
90
  real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
 
145
  prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
146
  prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
147
 
148
+ csv_path = f"predictions_{ticker}.csv"
149
+ prediction_df.to_csv(csv_path, index=False)
150
+
151
+ return plt, csv_path
152
 
153
  except Exception as e:
154
  print(f"Error: {str(e)}")
 
166
  )
167
  train_data_points = gr.Slider(
168
  minimum=50,
169
+ maximum=5000, # Puedes ajustar este valor si lo deseas
170
  value=1000,
171
  step=1,
172
  label="Número de Datos para Entrenamiento"
 
179
  label="Número de Días a Predecir"
180
  )
181
  predict_btn = gr.Button("Predecir")
182
+
183
  with gr.Column():
184
  plot_output = gr.Plot(label="Gráfico de Predicción")
185
  download_btn = gr.File(label="Descargar Predicciones")
 
190
  hist = stock.history(period="max")
191
  total_points = len(hist)
192
  # Actualizar el deslizador para reflejar el número total de puntos disponibles
193
+ return gr.update(maximum=total_points, value=min(1000, total_points))
194
 
195
  ticker.change(
196
  fn=update_train_data_points,
 
204
  outputs=[plot_output, download_btn]
205
  )
206
 
207
+ demo.launch(debug=True)