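# NOTE: the next three lines are web-page residue captured together with the
# file; they are kept as comments so the module remains valid Python.
# harshpatel080503's picture
# Upload 6 files
# 8cfefb8 verified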
import datetime
import plotly.graph_objs as go
import requests
import streamlit as st
import yfinance as yf
# Endpoint the app POSTs to when the user asks for a prediction.
API_URL = "http://127.0.0.1:8000/LSTM_Predict"

# Earliest and latest dates selectable in the two date pickers.
MIN_DATE = datetime.date(year=2020, month=1, day=1)
MAX_DATE = datetime.date(year=2022, month=12, day=31)
# (connect, read) timeouts in seconds for the prediction request, so that an
# unreachable or stalled backend cannot block the app indefinitely.
_REQUEST_TIMEOUT = (5, 120)


def _load_stock_data(stock_name, start_date, end_date):
    """Download price history and return it with the index moved to a column.

    Returns None, after reporting the problem in the UI, when the download
    raises or yields no rows.
    """
    try:
        stock_data = yf.download(stock_name, start=start_date, end=end_date)
    except Exception as e:  # UI boundary: show the failure instead of a traceback
        st.error(f"An error occurred while downloading stock data: {e}")
        return None
    if stock_data is None or stock_data.empty:
        st.warning(
            f"No price data returned for {stock_name} between "
            f"{start_date} and {end_date}."
        )
        return None
    # Turn the index into a regular column so it is kept when the frame is
    # written with ``index=False``.
    stock_data.reset_index(inplace=True)
    return stock_data


def _column_values(stock_data, column):
    """Return the values of *column* as a plain list."""
    selected = stock_data[column]
    # NOTE(review): some yfinance releases may return MultiIndex columns even
    # for a single ticker, in which case the selection is a one-column
    # DataFrame rather than a Series; reduce it to that single column.
    if getattr(selected, "ndim", 1) > 1:
        selected = selected.iloc[:, 0]
    return selected.tolist()


def _x_values(stock_data):
    """Return the x-axis values for the charts as a list."""
    # After reset_index() the frame's index is just 0..n-1; prefer the "Date"
    # column when it exists so the charts are drawn against real dates.
    if "Date" in stock_data.columns:
        return _column_values(stock_data, "Date")
    return stock_data.index.tolist()


def _request_predictions(stock_name):
    """POST the stock name to the prediction API and return its predictions.

    Returns the list stored under the ``"prediction"`` key of the JSON
    response, or None after reporting the problem in the UI.
    """
    payload = {"stock_name": stock_name}
    st.write("Sending payload:", payload)  # Debugging line
    try:
        response = requests.post(API_URL, json=payload, timeout=_REQUEST_TIMEOUT)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        st.error(f"HTTP error occurred: {e}")
        if e.response is not None:
            # Show the server's error detail; the body is not guaranteed to
            # be JSON, so fall back to the raw text.
            try:
                st.write(e.response.json())
            except ValueError:
                st.write(e.response.text)
        return None
    except requests.exceptions.RequestException as e:
        # Connection refused, timeout, etc.
        st.error(f"Could not reach the prediction service: {e}")
        return None

    try:
        predictions = response.json()
    except ValueError as e:
        st.error(f"Prediction service returned invalid JSON: {e}")
        return None

    predicted = (
        predictions.get("prediction") if isinstance(predictions, dict) else None
    )
    if not isinstance(predicted, list):
        st.error("Prediction service response has no 'prediction' list.")
        st.write(predictions)
        return None
    return predicted


def main():
    """Render the stock viewer and, on request, the prediction chart.

    Lets the user pick a ticker and a date range, downloads the matching
    price history with yfinance, plots the closing prices and writes the
    data to ``<ticker>_data.csv``. When the "Predict" button is pressed, the
    ticker is POSTed to ``API_URL`` and the returned predictions are plotted
    against the most recent dates of the downloaded data.

    All failures are reported in the Streamlit UI; nothing is returned.
    """
    stock_name = st.selectbox(
        "Please choose stock name", ("AAPL", "TSLA", "AMZN", "MSFT")
    )
    start_date = st.date_input(
        "Start date", min_value=MIN_DATE, max_value=MAX_DATE, value=MIN_DATE
    )
    end_date = st.date_input(
        "End date", min_value=MIN_DATE, max_value=MAX_DATE, value=MAX_DATE
    )

    if start_date > end_date:
        st.error("Error: End date must be after start date.")
        return  # do not download or plot with an invalid range
    st.success(
        f"Selected start date: `{start_date}`\n\nSelected end date:`{end_date}`"
    )

    stock_data = _load_stock_data(stock_name, start_date, end_date)
    if stock_data is None:
        return

    dates = _x_values(stock_data)
    close_prices = _column_values(stock_data, "Close")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=dates, y=close_prices, name="Close"))
    fig.update_layout(title=f"{stock_name} Stock Price")
    st.plotly_chart(fig)

    try:
        stock_data.to_csv(f"{stock_name}_data.csv", index=False)
    except OSError as e:
        st.error(f"Could not save stock data to CSV: {e}")
        return

    if not st.button("Predict"):
        return

    predicted_prices = _request_predictions(stock_name)
    if predicted_prices is None:
        return
    if not predicted_prices:
        # An empty list would make the ``[-0:]`` slice below select every
        # date, which would not line up with an empty y series.
        st.warning("The prediction service returned no predictions.")
        return

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=dates, y=close_prices, name="Actual"))
    # Draw the predictions against the last len(predicted_prices) dates.
    fig.add_trace(
        go.Scatter(
            x=dates[-len(predicted_prices):],
            y=predicted_prices,
            name="Predicted",
        )
    )
    fig.update_layout(title=f"{stock_name} Stock Price Prediction")
    st.plotly_chart(fig)
if __name__ == "__main__":
    # Build the page only when this file is run as a script, not on import.
    main()