# abhi.2000 / app.py
# Last change: "Update app.py" by Abhisesh7 (commit 6e4a9f8, verified)
# File size: 3.51 kB
import datetime

import yfinance as yf
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import gradio as gr
import matplotlib.pyplot as plt
# Stock tickers offered in the dropdown.
# Meta Platforms has traded as META since June 2022; the retired 'FB' symbol
# no longer yields price history, so it is listed under its current ticker.
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'META', 'NFLX', 'NVDA', 'INTC', 'IBM']
# Number of past closing prices the LSTM sees for each prediction.
_TIME_STEP = 60
# File the prediction chart is written to; its path is handed back to Gradio.
_PLOT_PATH = 'stock_prediction_plot.png'


def _parse_iso_date(text):
    """Return *text* parsed as a YYYY-MM-DD date, or None if it is not one."""
    try:
        return datetime.date.fromisoformat(str(text).strip())
    except ValueError:
        return None


def _create_dataset(data, time_step=_TIME_STEP):
    """Turn an (n, 1) scaled series into supervised LSTM samples.

    Window ``i`` is ``data[i:i + time_step, 0]`` and its target is the next
    value ``data[i + time_step, 0]``. Every window that has a following value
    is used, so ``max(n - time_step, 0)`` samples are produced.

    Returns:
        ``(X, y)`` with shapes ``(n_samples, time_step)`` and ``(n_samples,)``.
    """
    series = np.asarray(data)[:, 0]
    n_samples = max(len(series) - time_step, 0)
    windows = [series[i:i + time_step] for i in range(n_samples)]
    X = np.array(windows).reshape(n_samples, time_step)
    y = series[time_step:time_step + n_samples]
    return X, y


def _build_lstm_model(time_step=_TIME_STEP):
    """Build and compile the two-layer LSTM regressor (MSE loss, Adam)."""
    model = tf.keras.Sequential([
        # Input shape is tied to the window length used by _create_dataset.
        tf.keras.Input(shape=(time_step, 1)),
        tf.keras.layers.LSTM(50, return_sequences=True),
        tf.keras.layers.LSTM(50, return_sequences=False),
        tf.keras.layers.Dense(25),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model


def _save_prediction_plot(ticker, actual, predictions, time_step=_TIME_STEP):
    """Plot actual vs. predicted prices, save to ``_PLOT_PATH``, return the path.

    ``predictions[k]`` estimates the price on day ``time_step + k``, so the
    predicted curve starts at that x-position to line up with ``actual``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(np.arange(len(actual)), actual, label='Actual Stock Price')
        ax.plot(np.arange(time_step, time_step + len(predictions)), predictions,
                label='Predicted Stock Price')
        ax.set_title(f'{ticker} Stock Price Prediction')
        ax.set_xlabel('Days')
        ax.set_ylabel('Stock Price')
        ax.legend()
        fig.savefig(_PLOT_PATH)
    finally:
        # Without this, every request leaves an open figure in pyplot's registry.
        plt.close(fig)
    return _PLOT_PATH


def stock_prediction_app(ticker, start_date, end_date):
    """Fetch closing prices for *ticker*, fit an LSTM on them, and chart the fit.

    The model is trained and evaluated on the same window of data, so the
    chart shows in-sample fit (a demonstration), not a forecast.

    Args:
        ticker: Symbol selected in the dropdown; None/empty if nothing selected.
        start_date: Start of the range as a YYYY-MM-DD string.
        end_date: End of the range as a YYYY-MM-DD string.

    Returns:
        ``(message, plot_path)``; ``plot_path`` is None whenever no chart was
        produced (invalid input, no data, or too little data).
    """
    # Validate UI input before touching the network.
    if not ticker:
        return "Please select a stock ticker.", None
    start = _parse_iso_date(start_date)
    end = _parse_iso_date(end_date)
    if start is None or end is None:
        return "Please enter dates in YYYY-MM-DD format.", None
    if start >= end:
        return "Start date must be before end date.", None

    # Fetch historical stock data from Yahoo Finance.
    stock_data = yf.download(ticker, start=start.isoformat(), end=end.isoformat())
    if stock_data.empty:
        return "No data available for the selected date range.", None

    # Only the closing price is modelled; missing rows would turn the loss into NaN.
    close_values = stock_data[['Close']].dropna().to_numpy(dtype=float)
    if len(close_values) <= _TIME_STEP:
        return (f"Need more than {_TIME_STEP} trading days of data, got "
                f"{len(close_values)}. Please widen the date range."), None

    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(close_values)

    # LSTM input layout: [samples, time steps, features].
    X_train, y_train = _create_dataset(scaled_data)
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)

    lstm_model = _build_lstm_model()
    lstm_model.fit(X_train, y_train, batch_size=1, epochs=1)

    # Predict on the training windows and map back to price units.
    predictions = scaler.inverse_transform(lstm_model.predict(X_train))

    plot_path = _save_prediction_plot(ticker, close_values.ravel(), predictions.ravel())
    return (f"Prediction complete for {ticker} from {start.isoformat()} "
            f"to {end.isoformat()}"), plot_path
# Gradio interface: choose a ticker and a date range, then run the LSTM demo.
with gr.Blocks() as app:
    gr.Markdown("# Stock Buy/Sell Prediction App")

    # Inputs: a ticker from the predefined list and two free-text dates.
    ticker_input = gr.Dropdown(tickers, label="Select Stock Ticker")
    start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)")
    end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)")

    predict_button = gr.Button("Predict")

    # Outputs: the status message and the chart image returned by the handler.
    result_box = gr.Textbox(label="Prediction Result")
    chart_image = gr.Image(label="Stock Price Graph")

    # Clicking "Predict" calls stock_prediction_app(ticker, start, end) and
    # routes its (message, image path) pair to the two output components.
    predict_button.click(
        fn=stock_prediction_app,
        inputs=[ticker_input, start_date_input, end_date_input],
        outputs=[result_box, chart_image],
    )

# Start the Gradio web server.
app.launch()