Update app.py
Browse files
app.py
CHANGED
@@ -5,9 +5,10 @@ import tensorflow as tf
|
|
5 |
from sklearn.preprocessing import MinMaxScaler
|
6 |
import gradio as gr
|
7 |
import matplotlib.pyplot as plt
|
|
|
8 |
|
9 |
# Define stock tickers for the dropdown
|
10 |
-
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', '
|
11 |
|
12 |
# Function to fetch stock data and make predictions
|
13 |
def stock_prediction_app(ticker, start_date, end_date):
|
@@ -38,7 +39,7 @@ def stock_prediction_app(ticker, start_date, end_date):
|
|
38 |
|
39 |
# Define LSTM model
|
40 |
lstm_model = tf.keras.Sequential([
|
41 |
-
tf.keras.layers.LSTM(50, return_sequences=True, input_shape=(
|
42 |
tf.keras.layers.LSTM(50, return_sequences=False),
|
43 |
tf.keras.layers.Dense(25),
|
44 |
tf.keras.layers.Dense(1)
|
@@ -48,7 +49,7 @@ def stock_prediction_app(ticker, start_date, end_date):
|
|
48 |
lstm_model.compile(optimizer='adam', loss='mean_squared_error')
|
49 |
|
50 |
# Train the model
|
51 |
-
lstm_model.fit(X_train, y_train, batch_size=1, epochs=1)
|
52 |
|
53 |
# Predict on the same data (just for demonstration)
|
54 |
predictions = lstm_model.predict(X_train)
|
@@ -56,18 +57,20 @@ def stock_prediction_app(ticker, start_date, end_date):
|
|
56 |
|
57 |
# Create a plot to show predictions
|
58 |
plt.figure(figsize=(10, 5))
|
59 |
-
plt.plot(df_close.values, label='Actual Stock Price')
|
60 |
-
plt.plot(predictions, label='Predicted Stock Price')
|
61 |
plt.title(f'{ticker} Stock Price Prediction')
|
62 |
plt.xlabel('Days')
|
63 |
plt.ylabel('Stock Price')
|
64 |
plt.legend()
|
65 |
|
66 |
# Save the plot to display in Gradio app
|
67 |
-
|
|
|
|
|
68 |
|
69 |
# Return a message and the path to the saved plot
|
70 |
-
return f"Prediction complete for {ticker} from {start_date} to {end_date}",
|
71 |
|
72 |
# Create the Gradio UI for the app
|
73 |
app = gr.Blocks()
|
@@ -86,7 +89,7 @@ with app:
|
|
86 |
predict_button = gr.Button("Predict")
|
87 |
|
88 |
# Output fields for text and image
|
89 |
-
output_text = gr.Textbox(label="Prediction Result")
|
90 |
output_image = gr.Image(label="Stock Price Graph")
|
91 |
|
92 |
# Set up button click event to run the prediction function
|
|
|
5 |
from sklearn.preprocessing import MinMaxScaler
|
6 |
import gradio as gr
|
7 |
import matplotlib.pyplot as plt
|
8 |
+
import os
|
9 |
|
10 |
# Define stock tickers for the dropdown
|
11 |
+
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'META', 'NFLX', 'NVDA', 'INTC', 'IBM']
|
12 |
|
13 |
# Function to fetch stock data and make predictions
|
14 |
def stock_prediction_app(ticker, start_date, end_date):
|
|
|
39 |
|
40 |
# Define LSTM model
|
41 |
lstm_model = tf.keras.Sequential([
|
42 |
+
tf.keras.layers.LSTM(50, return_sequences=True, input_shape=(X_train.shape[1], 1)),
|
43 |
tf.keras.layers.LSTM(50, return_sequences=False),
|
44 |
tf.keras.layers.Dense(25),
|
45 |
tf.keras.layers.Dense(1)
|
|
|
49 |
lstm_model.compile(optimizer='adam', loss='mean_squared_error')
|
50 |
|
51 |
# Train the model
|
52 |
+
lstm_model.fit(X_train, y_train, batch_size=1, epochs=1, verbose=0)
|
53 |
|
54 |
# Predict on the same data (just for demonstration)
|
55 |
predictions = lstm_model.predict(X_train)
|
|
|
57 |
|
58 |
# Create a plot to show predictions
|
59 |
plt.figure(figsize=(10, 5))
|
60 |
+
plt.plot(df_close.values, label='Actual Stock Price', color='blue')
|
61 |
+
plt.plot(predictions, label='Predicted Stock Price', color='orange')
|
62 |
plt.title(f'{ticker} Stock Price Prediction')
|
63 |
plt.xlabel('Days')
|
64 |
plt.ylabel('Stock Price')
|
65 |
plt.legend()
|
66 |
|
67 |
# Save the plot to display in Gradio app
|
68 |
+
plot_path = 'stock_prediction_plot.png'
|
69 |
+
plt.savefig(plot_path)
|
70 |
+
plt.close() # Close the plot to prevent display issues in some environments
|
71 |
|
72 |
# Return a message and the path to the saved plot
|
73 |
+
return f"Prediction complete for {ticker} from {start_date} to {end_date}", plot_path
|
74 |
|
75 |
# Create the Gradio UI for the app
|
76 |
app = gr.Blocks()
|
|
|
89 |
predict_button = gr.Button("Predict")
|
90 |
|
91 |
# Output fields for text and image
|
92 |
+
output_text = gr.Textbox(label="Prediction Result", interactive=False)
|
93 |
output_image = gr.Image(label="Stock Price Graph")
|
94 |
|
95 |
# Set up button click event to run the prediction function
|