dschandra commited on
Commit
7a5c36d
·
verified ·
1 Parent(s): 0d0e41d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -8,18 +8,28 @@ from tensorflow.keras.layers import LSTM, Dense, Dropout
8
  import matplotlib.pyplot as plt
9
  import gradio as gr
10
  from datetime import datetime
 
11
 
12
  # Disable GPU usage and oneDNN optimizations
13
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
14
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
15
 
 
 
 
 
 
 
 
 
 
 
16
  # Helper function to handle date adjustments and retries if data not found
17
  def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
18
  retries = 3 # Number of retries for fetching data
19
  while stock_data.empty and retries > 0:
20
  start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
21
  end_date = (datetime.strptime(end_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
22
- print(f"Retrying with adjusted dates: {start_date} to {end_date}") # Debugging output
23
  stock_data = yf.download(ticker, start=start_date, end=end_date)
24
  retries -= 1
25
  return stock_data, start_date, end_date
@@ -28,14 +38,12 @@ def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
28
  def get_stock_data(ticker, start_date, end_date):
29
  try:
30
  stock_data = yf.download(ticker, start=start_date, end=end_date)
31
- print(f"Stock data downloaded: {stock_data.shape}") # Debugging output to check data download
32
  except Exception as e:
33
- print(f"Error fetching data: {e}") # Debugging output for data download error
34
  return None, None, None
35
 
36
  # If stock data is empty, attempt to adjust the date range
37
  if stock_data.empty:
38
- print("No data found for the original date range.") # Debugging output
39
  stock_data, adjusted_start, adjusted_end = adjust_date_range_if_needed(stock_data, ticker, start_date, end_date)
40
  if stock_data.empty:
41
  return None, None, None # If still empty after retries, return None
@@ -44,7 +52,6 @@ def get_stock_data(ticker, start_date, end_date):
44
 
45
  # Preprocess the data for the LSTM model
46
  def preprocess_data(stock_data):
47
- # Closing prices
48
  close_prices = stock_data['Close'].values
49
  close_prices = close_prices.reshape(-1, 1)
50
 
@@ -78,7 +85,6 @@ def load_trained_model(file_name):
78
 
79
  # Train and make predictions
80
  def predict_stock(stock_data, scaler, model):
81
- # Get the last 60 days of stock data for prediction
82
  last_60_days = stock_data[-60:]
83
  last_60_days_scaled = scaler.transform(last_60_days)
84
 
@@ -92,14 +98,11 @@ def predict_stock(stock_data, scaler, model):
92
  predicted_price = model.predict(X_test)
93
  predicted_price = scaler.inverse_transform(predicted_price)
94
 
95
- print(f"Predicted price: {predicted_price}") # Debugging output to verify predictions
96
  return predicted_price
97
 
98
  # Main app function
99
  def stock_predictor(ticker, start_date, end_date):
100
- # Validate if the ticker is available on Yahoo Finance
101
- if not ticker:
102
- return f"Invalid stock ticker: {ticker}"
103
 
104
  # Get stock data
105
  stock_data, adjusted_start, adjusted_end = get_stock_data(ticker, start_date, end_date)
@@ -116,7 +119,6 @@ def stock_predictor(ticker, start_date, end_date):
116
 
117
  if model is None:
118
  # Train the model if pre-trained model is not found
119
- print("Training the model...") # Debugging output for training
120
  model = build_model()
121
  X_train, y_train = [], []
122
  for i in range(60, len(scaled_data)):
@@ -134,17 +136,21 @@ def stock_predictor(ticker, start_date, end_date):
134
  # Predict stock price for tomorrow
135
  predicted_price = predict_stock(scaled_data, scaler, model)
136
 
137
- # Historical vs Predicted
 
 
 
138
  plt.figure(figsize=(14, 7))
139
- plt.plot(stock_data['Close'], color="blue", label="Historical Prices")
140
- plt.scatter(len(stock_data), predicted_price, color="red", label="Predicted Price for Tomorrow")
141
  plt.title(f"{ticker} Stock Price Prediction")
142
  plt.xlabel('Date')
143
- plt.ylabel('Price')
144
  plt.legend()
145
  plt.show()
146
 
147
- return f"Predicted Stock Price for {ticker} tomorrow: ${predicted_price[0][0]:.2f}"
 
148
 
149
  # Gradio UI
150
  def build_ui():
 
8
  import matplotlib.pyplot as plt
9
  import gradio as gr
10
  from datetime import datetime
11
+ import requests # To get the exchange rate
12
 
13
  # Disable GPU usage and oneDNN optimizations
14
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
15
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
16
 
17
+ # Helper function to get the current USD to INR exchange rate
18
+ def get_usd_to_inr_rate():
19
+ try:
20
+ response = requests.get('https://api.exchangerate-api.com/v4/latest/USD')
21
+ data = response.json()
22
+ return data['rates']['INR']
23
+ except Exception as e:
24
+ print(f"Error fetching exchange rate: {e}")
25
+ return 82.0 # Use a fallback conversion rate (adjust if necessary)
26
+
27
  # Helper function to handle date adjustments and retries if data not found
28
  def adjust_date_range_if_needed(stock_data, ticker, start_date, end_date):
29
  retries = 3 # Number of retries for fetching data
30
  while stock_data.empty and retries > 0:
31
  start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
32
  end_date = (datetime.strptime(end_date, '%Y-%m-%d') - timedelta(days=1)).strftime('%Y-%m-%d')
 
33
  stock_data = yf.download(ticker, start=start_date, end=end_date)
34
  retries -= 1
35
  return stock_data, start_date, end_date
 
38
  def get_stock_data(ticker, start_date, end_date):
39
  try:
40
  stock_data = yf.download(ticker, start=start_date, end=end_date)
 
41
  except Exception as e:
42
+ print(f"Error fetching data: {e}")
43
  return None, None, None
44
 
45
  # If stock data is empty, attempt to adjust the date range
46
  if stock_data.empty:
 
47
  stock_data, adjusted_start, adjusted_end = adjust_date_range_if_needed(stock_data, ticker, start_date, end_date)
48
  if stock_data.empty:
49
  return None, None, None # If still empty after retries, return None
 
52
 
53
  # Preprocess the data for the LSTM model
54
  def preprocess_data(stock_data):
 
55
  close_prices = stock_data['Close'].values
56
  close_prices = close_prices.reshape(-1, 1)
57
 
 
85
 
86
  # Train and make predictions
87
  def predict_stock(stock_data, scaler, model):
 
88
  last_60_days = stock_data[-60:]
89
  last_60_days_scaled = scaler.transform(last_60_days)
90
 
 
98
  predicted_price = model.predict(X_test)
99
  predicted_price = scaler.inverse_transform(predicted_price)
100
 
 
101
  return predicted_price
102
 
103
  # Main app function
104
  def stock_predictor(ticker, start_date, end_date):
105
+ usd_to_inr = get_usd_to_inr_rate() # Get the USD to INR conversion rate
 
 
106
 
107
  # Get stock data
108
  stock_data, adjusted_start, adjusted_end = get_stock_data(ticker, start_date, end_date)
 
119
 
120
  if model is None:
121
  # Train the model if pre-trained model is not found
 
122
  model = build_model()
123
  X_train, y_train = [], []
124
  for i in range(60, len(scaled_data)):
 
136
  # Predict stock price for tomorrow
137
  predicted_price = predict_stock(scaled_data, scaler, model)
138
 
139
+ # Convert predicted price to INR
140
+ predicted_price_inr = predicted_price[0][0] * usd_to_inr
141
+
142
+ # Historical vs Predicted Graph
143
  plt.figure(figsize=(14, 7))
144
+ plt.plot(stock_data['Close'], color="blue", label="Historical Prices (USD)")
145
+ plt.scatter(len(stock_data), predicted_price[0], color="red", label="Predicted Price (USD)")
146
  plt.title(f"{ticker} Stock Price Prediction")
147
  plt.xlabel('Date')
148
+ plt.ylabel('Price (USD)')
149
  plt.legend()
150
  plt.show()
151
 
152
+ # Return the predicted price in INR
153
+ return f"Predicted Stock Price for {ticker} tomorrow: ₹{predicted_price_inr:.2f} (INR)"
154
 
155
  # Gradio UI
156
  def build_ui():