abhi.2000 / app.py
Abhisesh7's picture
Update app.py
f2f0925 verified
raw
history blame
4.31 kB
import yfinance as yf
import pandas as pd
def get_stock_data(ticker, start_date, end_date):
    """Download daily price history for a single ticker from Yahoo Finance.

    Args:
        ticker: Symbol passed straight to ``yf.download`` (e.g. ``'AAPL'``).
        start_date: Start of the range, forwarded as ``start=``.
        end_date: End of the range, forwarded as ``end=``.

    Returns:
        The DataFrame returned by ``yf.download``.

    Raises:
        ValueError: If the download comes back empty (unknown ticker or a
            range with no trading data).
    """
    history = yf.download(ticker, start=start_date, end=end_date)
    if not history.empty:
        return history
    raise ValueError("No data found for the given ticker and date range.")
# Example usage: fetch one year of Apple prices and show the first rows.
# NOTE(review): this runs at import time, before the Gradio UI is built, so
# app startup depends on network access to Yahoo Finance.
ticker, start_date, end_date = 'AAPL', '2023-01-01', '2024-01-01'
data = get_stock_data(ticker=ticker, start_date=start_date, end_date=end_date)
print(data.head())
from sklearn.preprocessing import MinMaxScaler
def preprocess_data(stock_data):
    """Min-max scale the ``Close`` column of ``stock_data`` into [0, 1].

    Args:
        stock_data: DataFrame with a ``Close`` column (as returned by
            ``get_stock_data``).

    Returns:
        ``(scaled, scaler)``. ``scaled`` is the 2-D array produced by
        ``fit_transform`` on the ``Close`` column. ``scaler`` is the fitted
        ``MinMaxScaler``, kept so predictions can be mapped back to prices.
    """
    close_prices = stock_data[['Close']].values
    fitted_scaler = MinMaxScaler(feature_range=(0, 1))
    return fitted_scaler.fit_transform(close_prices), fitted_scaler
# Scale the example Close prices into [0, 1]. The returned `scaler` is kept
# so predictions can be converted back to price units later.
scaled_data, scaler = preprocess_data(stock_data=data)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
def create_lstm_model(input_shape):
    """Build and compile a two-layer LSTM regressor.

    Architecture: LSTM(50, returning full sequences) -> LSTM(50) -> Dense(1).
    The single output unit is the predicted (scaled) next price.

    Args:
        input_shape: Per-sample input shape, e.g. ``(window_size, 1)``.

    Returns:
        A Keras ``Sequential`` model compiled with Adam and mean squared error.
    """
    model = Sequential([
        # The first LSTM returns the whole sequence so the second LSTM can consume it.
        LSTM(units=50, return_sequences=True, input_shape=input_shape),
        LSTM(units=50),
        Dense(units=1),
    ])
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model
def train_model(stock_data, window_size=60, epochs=10, batch_size=32):
    """Train an LSTM to predict the next value from the previous ``window_size``.

    Each training sample is ``window_size`` consecutive values from column 0
    of ``stock_data``. Its target is the value that immediately follows.

    Args:
        stock_data: 2-D array-like (e.g. the scaled output of
            ``preprocess_data``); only column 0 is used.
        window_size: Number of past steps fed to the model per sample.
        epochs: Passed to ``model.fit``.
        batch_size: Passed to ``model.fit``.

    Returns:
        The fitted model from ``create_lstm_model``.

    Raises:
        ValueError: If ``window_size`` < 1, or if the series has no more than
            ``window_size`` rows (no complete window plus target exists).
    """
    # Import locally: this function is first called before the module-level
    # `import numpy as np` has executed, so relying on the global raised NameError.
    import numpy as np

    if window_size < 1:
        raise ValueError(f"window_size must be at least 1; got {window_size}.")
    series = np.asarray(stock_data)[:, 0]
    if len(series) <= window_size:
        raise ValueError(
            f"Need more than {window_size} observations to build a training "
            f"window; got {len(series)}."
        )

    # Sliding windows: X[k] = series[k : k + window_size], y[k] = series[k + window_size].
    X_train = np.array(
        [series[i - window_size:i] for i in range(window_size, len(series))]
    )
    y_train = series[window_size:]
    # Keras LSTMs expect (samples, timesteps, features).
    X_train = X_train.reshape((X_train.shape[0], window_size, 1))

    model = create_lstm_model((window_size, 1))
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)
    return model
# Fit an LSTM on the scaled example series using the default window size.
# NOTE(review): this trains a full model at import time, delaying app startup.
lstm_model = train_model(stock_data=scaled_data)
import numpy as np
def predict_future_prices(model, scaler, recent_data, days_to_predict=90, window_size=60):
    """Recursively forecast ``days_to_predict`` future values.

    The loop is seeded with the last ``window_size`` scaled observations. Each
    step feeds the current window to ``model.predict``, appends the prediction
    and drops the oldest value. Later steps therefore build on earlier
    predictions.

    Args:
        model: Object with a Keras-style ``predict``; element ``[0, 0]`` of its
            output is taken as the next scaled value.
        scaler: Fitted scaler whose ``inverse_transform`` maps scaled values
            back to prices (e.g. the one returned by ``preprocess_data``).
        recent_data: Scaled history; its last ``window_size`` entries are
            reshaped to ``(1, window_size, 1)`` to seed the window.
        days_to_predict: Number of forecast steps.
        window_size: Input window length; must match what the model was
            trained with.

    Returns:
        ``scaler.inverse_transform`` of the predictions, shaped
        ``(days_to_predict, 1)`` on input.

    Raises:
        ValueError: If ``recent_data`` has fewer than ``window_size`` entries.
    """
    history = np.asarray(recent_data)
    if len(history) < window_size:
        raise ValueError(
            f"Need at least {window_size} recent observations to seed the "
            f"forecast; got {len(history)}."
        )
    window = history[-window_size:].reshape(1, window_size, 1)

    predictions = []
    for _ in range(days_to_predict):
        # verbose=0 suppresses Keras' progress bar, which would otherwise print once per step.
        next_value = model.predict(window, verbose=0)[0, 0]
        predictions.append(next_value)
        # Slide the window: drop the oldest step and append the prediction.
        # The new element must be 3-D (1, 1, 1); np.append with axis=1 rejects
        # mixing a 3-D window with a 2-D [[x]].
        window = np.append(window[:, 1:, :], [[[next_value]]], axis=1)

    # Inverse transform to get the original prices.
    predicted_prices = scaler.inverse_transform(np.array(predictions).reshape(-1, 1))
    return predicted_prices
# Forecast ahead from the last 60 scaled observations of the example series.
recent_data = scaled_data[-60:]
future_prices = predict_future_prices(model=lstm_model, scaler=scaler, recent_data=recent_data)
import gradio as gr
import matplotlib.pyplot as plt
def stock_prediction_app(ticker, start_date, end_date):
    """Gradio callback: fetch prices, train a fresh LSTM, forecast, and plot.

    Args:
        ticker: Symbol selected in the UI.
        start_date: Start of the history range, forwarded to ``get_stock_data``.
        end_date: End of the history range, forwarded to ``get_stock_data``.

    Returns:
        ``(message, image_path)`` for the Textbox and Image outputs. The plot
        is written to ``'stock_prediction.png'``.

    Raises:
        ValueError: Propagated from ``get_stock_data`` (no data) or from
            training/prediction (range too short for the model window).
    """
    stock_data = get_stock_data(ticker, start_date, end_date)
    scaled_data, scaler = preprocess_data(stock_data)
    # Train the model on the selected stock data.
    model = train_model(scaled_data)
    # Forecast forward from the end of the history (default horizon).
    future_prices = predict_future_prices(model, scaler, scaled_data)

    # Each forecast step is one trading day *after* the last observation.
    # Previously the range ended on the last observed date, which drew the
    # forecast on top of history.
    last_date = stock_data.index[-1]
    future_dates = pd.bdate_range(
        start=last_date + pd.Timedelta(days=1), periods=len(future_prices)
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(stock_data.index, stock_data['Close'], label='Historical Prices')
        ax.plot(future_dates, future_prices, label='Predicted Prices', linestyle='--')
        ax.set_title(f'{ticker} Stock Price Prediction')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        fig.savefig('stock_prediction.png')
    finally:
        # Release the figure; otherwise every request leaks one in the server process.
        plt.close(fig)

    message = (
        f"The predicted stock price for the next {len(future_prices)} "
        "trading days is shown in the graph."
    )
    return message, 'stock_prediction.png'
# Define Gradio interface.
# Meta Platforms has traded as 'META' (formerly 'FB') since June 2022.
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'META', 'NFLX', 'NVDA', 'INTC', 'IBM']
app = gr.Blocks()
with app:
    gr.Markdown("# Stock Buy/Sell Prediction App")
    ticker = gr.Dropdown(tickers, label="Select Stock Ticker")
    # Gradio has no `gr.Date` component, so the original raised AttributeError.
    # ISO 'YYYY-MM-DD' strings are forwarded unchanged to yf.download.
    start_date = gr.Textbox(label="Start Date", placeholder="YYYY-MM-DD")
    end_date = gr.Textbox(label="End Date", placeholder="YYYY-MM-DD")
    predict_button = gr.Button("Predict")
    output_text = gr.Textbox(label="Prediction Result")
    output_image = gr.Image(label="Stock Price Graph")
    predict_button.click(
        fn=stock_prediction_app,
        inputs=[ticker, start_date, end_date],
        outputs=[output_text, output_image],
    )

if __name__ == "__main__":
    app.launch()