Update app.py
Browse files
app.py
CHANGED
@@ -8,18 +8,28 @@ from tensorflow.keras.layers import LSTM, Dense, Dropout
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import gradio as gr
|
10 |
from datetime import datetime
|
|
|
11 |
|
12 |
# Disable GPU usage and oneDNN optimizations
|
13 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
14 |
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# Helper function to handle date adjustments and retries if data not found
|
17 |
def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
|
18 |
retries = 3 # Number of retries for fetching data
|
19 |
while stock_data.empty and retries > 0:
|
20 |
start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
|
21 |
end_date = (datetime.strptime(end_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
|
22 |
-
print(f"Retrying with adjusted dates: {start_date} to {end_date}") # Debugging output
|
23 |
stock_data = yf.download(ticker, start=start_date, end=end_date)
|
24 |
retries -= 1
|
25 |
return stock_data, start_date, end_date
|
@@ -28,14 +38,12 @@ def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
|
|
28 |
def get_stock_data(ticker, start_date, end_date):
|
29 |
try:
|
30 |
stock_data = yf.download(ticker, start=start_date, end=end_date)
|
31 |
-
print(f"Stock data downloaded: {stock_data.shape}") # Debugging output to check data download
|
32 |
except Exception as e:
|
33 |
-
print(f"Error fetching data: {e}")
|
34 |
return None, None, None
|
35 |
|
36 |
# If stock data is empty, attempt to adjust the date range
|
37 |
if stock_data.empty:
|
38 |
-
print("No data found for the original date range.") # Debugging output
|
39 |
stock_data, adjusted_start, adjusted_end = adjust_date_range_if_needed(stock_data, ticker, start_date, end_date)
|
40 |
if stock_data.empty:
|
41 |
return None, None, None # If still empty after retries, return None
|
@@ -44,7 +52,6 @@ def get_stock_data(ticker, start_date, end_date):
|
|
44 |
|
45 |
# Preprocess the data for the LSTM model
|
46 |
def preprocess_data(stock_data):
|
47 |
-
# Closing prices
|
48 |
close_prices = stock_data['Close'].values
|
49 |
close_prices = close_prices.reshape(-1, 1)
|
50 |
|
@@ -78,7 +85,6 @@ def load_trained_model(file_name):
|
|
78 |
|
79 |
# Train and make predictions
|
80 |
def predict_stock(stock_data, scaler, model):
|
81 |
-
# Get the last 60 days of stock data for prediction
|
82 |
last_60_days = stock_data[-60:]
|
83 |
last_60_days_scaled = scaler.transform(last_60_days)
|
84 |
|
@@ -92,14 +98,11 @@ def predict_stock(stock_data, scaler, model):
|
|
92 |
predicted_price = model.predict(X_test)
|
93 |
predicted_price = scaler.inverse_transform(predicted_price)
|
94 |
|
95 |
-
print(f"Predicted price: {predicted_price}") # Debugging output to verify predictions
|
96 |
return predicted_price
|
97 |
|
98 |
# Main app function
|
99 |
def stock_predictor(ticker, start_date, end_date):
|
100 |
-
#
|
101 |
-
if not ticker:
|
102 |
-
return f"Invalid stock ticker: {ticker}"
|
103 |
|
104 |
# Get stock data
|
105 |
stock_data, adjusted_start, adjusted_end = get_stock_data(ticker, start_date, end_date)
|
@@ -116,7 +119,6 @@ def stock_predictor(ticker, start_date, end_date):
|
|
116 |
|
117 |
if model is None:
|
118 |
# Train the model if pre-trained model is not found
|
119 |
-
print("Training the model...") # Debugging output for training
|
120 |
model = build_model()
|
121 |
X_train, y_train = [], []
|
122 |
for i in range(60, len(scaled_data)):
|
@@ -134,17 +136,21 @@ def stock_predictor(ticker, start_date, end_date):
|
|
134 |
# Predict stock price for tomorrow
|
135 |
predicted_price = predict_stock(scaled_data, scaler, model)
|
136 |
|
137 |
-
#
|
|
|
|
|
|
|
138 |
plt.figure(figsize=(14, 7))
|
139 |
-
plt.plot(stock_data['Close'], color="blue", label="Historical Prices")
|
140 |
-
plt.scatter(len(stock_data), predicted_price, color="red", label="Predicted Price
|
141 |
plt.title(f"{ticker} Stock Price Prediction")
|
142 |
plt.xlabel('Date')
|
143 |
-
plt.ylabel('Price')
|
144 |
plt.legend()
|
145 |
plt.show()
|
146 |
|
147 |
-
|
|
|
148 |
|
149 |
# Gradio UI
|
150 |
def build_ui():
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import gradio as gr
|
10 |
from datetime import datetime
|
11 |
+
import requests # To get the exchange rate
|
12 |
|
13 |
# Disable GPU usage and oneDNN optimizations
|
14 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
15 |
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
16 |
|
17 |
+
# Helper function to get the current USD to INR exchange rate
|
18 |
+
def get_usd_to_inr_rate():
|
19 |
+
try:
|
20 |
+
response = requests.get('https://api.exchangerate-api.com/v4/latest/USD')
|
21 |
+
data = response.json()
|
22 |
+
return data['rates']['INR']
|
23 |
+
except Exception as e:
|
24 |
+
print(f"Error fetching exchange rate: {e}")
|
25 |
+
return 82.0 # Use a fallback conversion rate (adjust if necessary)
|
26 |
+
|
27 |
# Helper function to handle date adjustments and retries if data not found
|
28 |
def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
|
29 |
retries = 3 # Number of retries for fetching data
|
30 |
while stock_data.empty and retries > 0:
|
31 |
start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
|
32 |
end_date = (datetime.strptime(end_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
|
|
|
33 |
stock_data = yf.download(ticker, start=start_date, end=end_date)
|
34 |
retries -= 1
|
35 |
return stock_data, start_date, end_date
|
|
|
38 |
def get_stock_data(ticker, start_date, end_date):
|
39 |
try:
|
40 |
stock_data = yf.download(ticker, start=start_date, end=end_date)
|
|
|
41 |
except Exception as e:
|
42 |
+
print(f"Error fetching data: {e}")
|
43 |
return None, None, None
|
44 |
|
45 |
# If stock data is empty, attempt to adjust the date range
|
46 |
if stock_data.empty:
|
|
|
47 |
stock_data, adjusted_start, adjusted_end = adjust_date_range_if_needed(stock_data, ticker, start_date, end_date)
|
48 |
if stock_data.empty:
|
49 |
return None, None, None # If still empty after retries, return None
|
|
|
52 |
|
53 |
# Preprocess the data for the LSTM model
|
54 |
def preprocess_data(stock_data):
|
|
|
55 |
close_prices = stock_data['Close'].values
|
56 |
close_prices = close_prices.reshape(-1, 1)
|
57 |
|
|
|
85 |
|
86 |
# Train and make predictions
|
87 |
def predict_stock(stock_data, scaler, model):
|
|
|
88 |
last_60_days = stock_data[-60:]
|
89 |
last_60_days_scaled = scaler.transform(last_60_days)
|
90 |
|
|
|
98 |
predicted_price = model.predict(X_test)
|
99 |
predicted_price = scaler.inverse_transform(predicted_price)
|
100 |
|
|
|
101 |
return predicted_price
|
102 |
|
103 |
# Main app function
|
104 |
def stock_predictor(ticker, start_date, end_date):
|
105 |
+
usd_to_inr = get_usd_to_inr_rate() # Get the USD to INR conversion rate
|
|
|
|
|
106 |
|
107 |
# Get stock data
|
108 |
stock_data, adjusted_start, adjusted_end = get_stock_data(ticker, start_date, end_date)
|
|
|
119 |
|
120 |
if model is None:
|
121 |
# Train the model if pre-trained model is not found
|
|
|
122 |
model = build_model()
|
123 |
X_train, y_train = [], []
|
124 |
for i in range(60, len(scaled_data)):
|
|
|
136 |
# Predict stock price for tomorrow
|
137 |
predicted_price = predict_stock(scaled_data, scaler, model)
|
138 |
|
139 |
+
# Convert predicted price to INR
|
140 |
+
predicted_price_inr = predicted_price[0][0] * usd_to_inr
|
141 |
+
|
142 |
+
# Historical vs Predicted Graph
|
143 |
plt.figure(figsize=(14, 7))
|
144 |
+
plt.plot(stock_data['Close'], color="blue", label="Historical Prices (USD)")
|
145 |
+
plt.scatter(len(stock_data), predicted_price[0], color="red", label="Predicted Price (USD)")
|
146 |
plt.title(f"{ticker} Stock Price Prediction")
|
147 |
plt.xlabel('Date')
|
148 |
+
plt.ylabel('Price (USD)')
|
149 |
plt.legend()
|
150 |
plt.show()
|
151 |
|
152 |
+
# Return the predicted price in INR
|
153 |
+
return f"Predicted Stock Price for {ticker} tomorrow: ₹{predicted_price_inr:.2f} (INR)"
|
154 |
|
155 |
# Gradio UI
|
156 |
def build_ui():
|