dschandra commited on
Commit
bfaa7c5
·
verified ·
1 Parent(s): 7a5c36d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -157
app.py CHANGED
@@ -1,173 +1,91 @@
1
- import os
 
2
  import yfinance as yf
3
  import pandas as pd
4
- import numpy as np
5
- import tensorflow as tf
6
- from tensorflow.keras.models import Sequential, load_model
7
- from tensorflow.keras.layers import LSTM, Dense, Dropout
8
  import matplotlib.pyplot as plt
9
- import gradio as gr
10
  from datetime import datetime
11
- import requests # To get the exchange rate
12
-
13
- # Disable GPU usage and oneDNN optimizations
14
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
15
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
16
 
17
- # Helper function to get the current USD to INR exchange rate
18
- def get_usd_to_inr_rate():
19
- try:
20
- response = requests.get('https://api.exchangerate-api.com/v4/latest/USD')
21
- data = response.json()
22
- return data['rates']['INR']
23
- except Exception as e:
24
- print(f"Error fetching exchange rate: {e}")
25
- return 82.0 # Use a fallback conversion rate (adjust if necessary)
26
-
27
- # Helper function to handle date adjustments and retries if data not found
28
- def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
29
- retries = 3 # Number of retries for fetching data
30
- while stock_data.empty and retries > 0:
31
- start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
32
- end_date = (datetime.strptime(end_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
33
- stock_data = yf.download(ticker, start=start_date, end=end_date)
34
- retries -= 1
35
- return stock_data, start_date, end_date
36
-
37
- # Define function to validate stock ticker and get stock data
38
  def get_stock_data(ticker, start_date, end_date):
39
- try:
40
- stock_data = yf.download(ticker, start=start_date, end=end_date)
41
- except Exception as e:
42
- print(f"Error fetching data: {e}")
43
- return None, None, None
44
-
45
- # If stock data is empty, attempt to adjust the date range
46
- if stock_data.empty:
47
- stock_data, adjusted_start, adjusted_end = adjust_date_range_if_needed(stock_data, ticker, start_date, end_date)
48
- if stock_data.empty:
49
- return None, None, None # If still empty after retries, return None
50
- return stock_data, adjusted_start, adjusted_end
51
- return stock_data, start_date, end_date
52
 
53
- # Preprocess the data for the LSTM model
54
- def preprocess_data(stock_data):
55
- close_prices = stock_data['Close'].values
56
- close_prices = close_prices.reshape(-1, 1)
57
-
58
- # Normalize the data
59
- from sklearn.preprocessing import MinMaxScaler
60
- scaler = MinMaxScaler(feature_range=(0, 1))
61
- scaled_data = scaler.fit_transform(close_prices)
 
62
 
63
- return scaled_data, scaler
 
 
 
64
 
65
- # Build the LSTM model
66
- def build_model():
67
- model = Sequential()
68
- model.add(LSTM(units=50, return_sequences=True, input_shape=(60, 1)))
69
- model.add(Dropout(0.2))
70
- model.add(LSTM(units=50, return_sequences=False))
71
- model.add(Dropout(0.2))
72
- model.add(Dense(units=1)) # Predicting the next closing price
73
- model.compile(optimizer='adam', loss='mean_squared_error')
74
- return model
75
 
76
- # Save and load models to avoid re-training
77
- def save_model(model, file_name):
78
- model.save(file_name)
 
 
 
 
 
 
 
 
 
 
79
 
80
- def load_trained_model(file_name):
81
- if os.path.exists(file_name):
82
- return load_model(file_name)
83
- else:
84
- return None
 
85
 
86
- # Train and make predictions
87
- def predict_stock(stock_data, scaler, model):
88
- last_60_days = stock_data[-60:]
89
- last_60_days_scaled = scaler.transform(last_60_days)
90
-
91
- # Prepare the input for prediction
92
- X_test = []
93
- X_test.append(last_60_days_scaled)
94
- X_test = np.array(X_test)
95
- X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
96
-
97
- # Predict
98
- predicted_price = model.predict(X_test)
99
- predicted_price = scaler.inverse_transform(predicted_price)
100
-
101
- return predicted_price
102
 
103
- # Main app function
104
- def stock_predictor(ticker, start_date, end_date):
105
- usd_to_inr = get_usd_to_inr_rate() # Get the USD to INR conversion rate
106
-
107
- # Get stock data
108
- stock_data, adjusted_start, adjusted_end = get_stock_data(ticker, start_date, end_date)
109
-
110
- if stock_data is None or stock_data.empty:
111
- return f"No data found for {ticker} in the selected or adjusted date range."
112
-
113
- # Preprocess the data
114
- scaled_data, scaler = preprocess_data(stock_data)
115
-
116
- # Try to load a pre-trained model
117
- model_file = f"{ticker}_model.h5"
118
- model = load_trained_model(model_file)
119
-
120
- if model is None:
121
- # Train the model if pre-trained model is not found
122
- model = build_model()
123
- X_train, y_train = [], []
124
- for i in range(60, len(scaled_data)):
125
- X_train.append(scaled_data[i-60:i, 0])
126
- y_train.append(scaled_data[i, 0])
127
- X_train, y_train = np.array(X_train), np.array(y_train)
128
- X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
129
-
130
- # Train the model (reduced epochs for faster processing)
131
- model.fit(X_train, y_train, epochs=2, batch_size=32) # Reduced epochs
132
-
133
- # Save the trained model
134
- save_model(model, model_file)
135
-
136
- # Predict stock price for tomorrow
137
- predicted_price = predict_stock(scaled_data, scaler, model)
138
-
139
- # Convert predicted price to INR
140
- predicted_price_inr = predicted_price[0][0] * usd_to_inr
141
-
142
- # Historical vs Predicted Graph
143
- plt.figure(figsize=(14, 7))
144
- plt.plot(stock_data['Close'], color="blue", label="Historical Prices (USD)")
145
- plt.scatter(len(stock_data), predicted_price[0], color="red", label="Predicted Price (USD)")
146
- plt.title(f"{ticker} Stock Price Prediction")
147
- plt.xlabel('Date')
148
- plt.ylabel('Price (USD)')
149
- plt.legend()
150
- plt.show()
151
 
152
- # Return the predicted price in INR
153
- return f"Predicted Stock Price for {ticker} tomorrow: ₹{predicted_price_inr:.2f} (INR)"
154
 
155
- # Gradio UI
156
- def build_ui():
157
- stock_tickers = ["AAPL", "TSLA", "AMZN", "MSFT", "GOOGL", "FB", "NFLX", "NVDA", "BABA", "JPM"]
158
-
159
- # Use Textbox for manual date input (format: YYYY-MM-DD)
160
- gr_interface = gr.Interface(
161
- fn=stock_predictor,
162
- inputs=[
163
- gr.Dropdown(stock_tickers, label="Stock Ticker"),
164
- gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2022-01-01"), # Manual start date
165
- gr.Textbox(label="End Date (YYYY-MM-DD)", value=datetime.today().strftime("%Y-%m-%d")) # Manual end date
166
- ],
167
- outputs="text",
168
- title="Stock Price Prediction for Tomorrow"
169
- )
170
- gr_interface.launch()
 
 
 
171
 
172
- # Run the app
173
- build_ui()
 
1
+ # Import necessary libraries
2
+ import gradio as gr
3
  import yfinance as yf
4
  import pandas as pd
 
 
 
 
5
  import matplotlib.pyplot as plt
6
+ from neuralprophet import NeuralProphet
7
  from datetime import datetime
 
 
 
 
 
8
 
9
+ # Function to fetch historical stock data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def get_stock_data(ticker, start_date, end_date):
11
+ stock_data = yf.download(ticker, start=start_date, end=end_date)
12
+ stock_data.reset_index(inplace=True) # Reset index to use dates properly
13
+ return stock_data
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Function to preprocess data for NeuralProphet model
16
+ def prepare_data_for_neuralprophet(stock_data):
17
+ df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
18
+ return df
19
+
20
+ # Function to train NeuralProphet model and make predictions
21
+ def predict_stock(stock_data, period):
22
+ df = prepare_data_for_neuralprophet(stock_data)
23
+ model = NeuralProphet() # Initialize NeuralProphet model
24
+ model.fit(df) # Fit the model with the historical stock data
25
 
26
+ # Make future predictions
27
+ future = model.make_future_dataframe(df, periods=period) # Create a future dataframe for predictions
28
+ forecast = model.predict(future) # Predict future stock prices
29
+ return forecast[['ds', 'yhat1']]
30
 
31
+ # Function to get buy/sell recommendation based on percentage change
32
+ def get_recommendation(stock_data):
33
+ change_percent = ((stock_data['Close'].iloc[-1] - stock_data['Close'].iloc[0]) / stock_data['Close'].iloc[0]) * 100
34
+ if change_percent > 0:
35
+ return "Buy"
36
+ else:
37
+ return "Sell"
 
 
 
38
 
39
+ # Function to plot stock data
40
+ def plot_stock(stock_data, forecast):
41
+ plt.figure(figsize=(10, 5))
42
+ plt.plot(stock_data['Date'], stock_data['Close'], label='Historical Closing Price')
43
+ plt.plot(forecast['ds'], forecast['yhat1'], label='Predicted Closing Price')
44
+ plt.xlabel("Date")
45
+ plt.ylabel("Stock Price")
46
+ plt.title("Stock Price Prediction")
47
+ plt.legend()
48
+ plt.grid(True)
49
+ plt.savefig("stock_prediction_plot.png") # Save the plot as an image
50
+ plt.close()
51
+ return "stock_prediction_plot.png"
52
 
53
+ # Main function to handle user inputs and return results
54
+ def stock_prediction_app(ticker, start_date, end_date, prediction_period):
55
+ stock_data = get_stock_data(ticker, start_date, end_date) # Fetch historical stock data
56
+ forecast = predict_stock(stock_data, prediction_period) # Predict future prices
57
+ recommendation = get_recommendation(stock_data) # Get buy/sell recommendation
58
+ plot_file = plot_stock(stock_data, forecast) # Plot stock data and predictions
59
 
60
+ # Get the highest and lowest closing prices in the historical data
61
+ high = stock_data['Close'].max()
62
+ low = stock_data['Close'].min()
63
+ percentage_change = ((stock_data['Close'].iloc[-1] - stock_data['Close'].iloc[0]) / stock_data['Close'].iloc[0]) * 100
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ return high, low, percentage_change, recommendation, plot_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # Define the stock tickers for the dropdown
68
+ tickers = ['AAPL', 'GOOGL', 'AMZN', 'MSFT', 'TSLA', 'NFLX', 'NVDA', 'INTC', 'AMD', 'FB']
69
 
70
+ # Create the Gradio interface using the latest Gradio API
71
+ app = gr.Interface(
72
+ fn=stock_prediction_app,
73
+ inputs=[
74
+ gr.Dropdown(choices=tickers, label="Stock Ticker"),
75
+ gr.Textbox(label="Start Date (YYYY-MM-DD)"),
76
+ gr.Textbox(label="End Date (YYYY-MM-DD)"),
77
+ gr.Slider(1, 365, label="Prediction Period (Days)")
78
+ ],
79
+ outputs=[
80
+ gr.Textbox(label="Highest Value"),
81
+ gr.Textbox(label="Lowest Value"),
82
+ gr.Textbox(label="Percentage Change"),
83
+ gr.Textbox(label="Buy/Sell Recommendation"),
84
+ gr.Image(type="filepath", label="Stock Performance and Prediction Graph")
85
+ ],
86
+ title="AI-Powered Stock Prediction App",
87
+ description="Predict future stock prices, calculate highest and lowest prices, percentage change, and get a buy/sell recommendation based on historical data."
88
+ )
89
 
90
+ # Launch the Gradio app
91
+ app.launch()