Abhisesh7 commited on
Commit
2276894
·
verified ·
1 Parent(s): df90ba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -5,9 +5,10 @@ import tensorflow as tf
5
  from sklearn.preprocessing import MinMaxScaler
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
 
8
 
9
  # Define stock tickers for the dropdown
10
- tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'FB', 'NFLX', 'NVDA', 'INTC', 'IBM']
11
 
12
  # Function to fetch stock data and make predictions
13
  def stock_prediction_app(ticker, start_date, end_date):
@@ -38,7 +39,7 @@ def stock_prediction_app(ticker, start_date, end_date):
38
 
39
  # Define LSTM model
40
  lstm_model = tf.keras.Sequential([
41
- tf.keras.layers.LSTM(50, return_sequences=True, input_shape=(60, 1)),
42
  tf.keras.layers.LSTM(50, return_sequences=False),
43
  tf.keras.layers.Dense(25),
44
  tf.keras.layers.Dense(1)
@@ -48,7 +49,7 @@ def stock_prediction_app(ticker, start_date, end_date):
48
  lstm_model.compile(optimizer='adam', loss='mean_squared_error')
49
 
50
  # Train the model
51
- lstm_model.fit(X_train, y_train, batch_size=1, epochs=1)
52
 
53
  # Predict on the same data (just for demonstration)
54
  predictions = lstm_model.predict(X_train)
@@ -56,18 +57,20 @@ def stock_prediction_app(ticker, start_date, end_date):
56
 
57
  # Create a plot to show predictions
58
  plt.figure(figsize=(10, 5))
59
- plt.plot(df_close.values, label='Actual Stock Price')
60
- plt.plot(predictions, label='Predicted Stock Price')
61
  plt.title(f'{ticker} Stock Price Prediction')
62
  plt.xlabel('Days')
63
  plt.ylabel('Stock Price')
64
  plt.legend()
65
 
66
  # Save the plot to display in Gradio app
67
- plt.savefig('stock_prediction_plot.png')
 
 
68
 
69
  # Return a message and the path to the saved plot
70
- return f"Prediction complete for {ticker} from {start_date} to {end_date}", 'stock_prediction_plot.png'
71
 
72
  # Create the Gradio UI for the app
73
  app = gr.Blocks()
@@ -86,7 +89,7 @@ with app:
86
  predict_button = gr.Button("Predict")
87
 
88
  # Output fields for text and image
89
- output_text = gr.Textbox(label="Prediction Result")
90
  output_image = gr.Image(label="Stock Price Graph")
91
 
92
  # Set up button click event to run the prediction function
 
5
  from sklearn.preprocessing import MinMaxScaler
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
+ import os
9
 
10
  # Define stock tickers for the dropdown
11
+ tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'META', 'NFLX', 'NVDA', 'INTC', 'IBM']
12
 
13
  # Function to fetch stock data and make predictions
14
  def stock_prediction_app(ticker, start_date, end_date):
 
39
 
40
  # Define LSTM model
41
  lstm_model = tf.keras.Sequential([
42
+ tf.keras.layers.LSTM(50, return_sequences=True, input_shape=(X_train.shape[1], 1)),
43
  tf.keras.layers.LSTM(50, return_sequences=False),
44
  tf.keras.layers.Dense(25),
45
  tf.keras.layers.Dense(1)
 
49
  lstm_model.compile(optimizer='adam', loss='mean_squared_error')
50
 
51
  # Train the model
52
+ lstm_model.fit(X_train, y_train, batch_size=1, epochs=1, verbose=0)
53
 
54
  # Predict on the same data (just for demonstration)
55
  predictions = lstm_model.predict(X_train)
 
57
 
58
  # Create a plot to show predictions
59
  plt.figure(figsize=(10, 5))
60
+ plt.plot(df_close.values, label='Actual Stock Price', color='blue')
61
+ plt.plot(predictions, label='Predicted Stock Price', color='orange')
62
  plt.title(f'{ticker} Stock Price Prediction')
63
  plt.xlabel('Days')
64
  plt.ylabel('Stock Price')
65
  plt.legend()
66
 
67
  # Save the plot to display in Gradio app
68
+ plot_path = 'stock_prediction_plot.png'
69
+ plt.savefig(plot_path)
70
+ plt.close() # Close the plot to prevent display issues in some environments
71
 
72
  # Return a message and the path to the saved plot
73
+ return f"Prediction complete for {ticker} from {start_date} to {end_date}", plot_path
74
 
75
  # Create the Gradio UI for the app
76
  app = gr.Blocks()
 
89
  predict_button = gr.Button("Predict")
90
 
91
  # Output fields for text and image
92
+ output_text = gr.Textbox(label="Prediction Result", interactive=False)
93
  output_image = gr.Image(label="Stock Price Graph")
94
 
95
  # Set up button click event to run the prediction function