File size: 2,526 Bytes
8cfefb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import datetime

import plotly.graph_objs as go
import requests
import streamlit as st
import yfinance as yf

# FastAPI endpoint that serves LSTM price predictions for a ticker symbol.
API_URL = "http://127.0.0.1:8000/LSTM_Predict"

# Inclusive bounds offered by the start/end date pickers (also their defaults).
MIN_DATE = datetime.date(year=2020, month=1, day=1)
MAX_DATE = datetime.date(year=2022, month=12, day=31)

# Seconds to wait for the prediction service. Without a timeout,
# requests.post can block the Streamlit script indefinitely.
_PREDICT_TIMEOUT_S = 60


def _fetch_history(stock_name, start_date, end_date):
    """Download daily price history for *stock_name*, both dates inclusive.

    Returns the DataFrame from ``yf.download`` with single-level columns (so
    ``df["Close"]`` is a Series) and the dates kept on the index.
    """
    # yfinance treats ``end`` as exclusive. Add a day so the end date shown
    # to the user is actually part of the downloaded range.
    data = yf.download(
        stock_name, start=start_date, end=end_date + datetime.timedelta(days=1)
    )
    # Some yfinance versions return (Price, Ticker) MultiIndex columns even for
    # a single ticker. Keep only the price level so column lookups and the CSV
    # header stay flat.
    if data.columns.nlevels > 1:
        data.columns = data.columns.get_level_values(0)
    return data


def _error_detail(response):
    """Return the JSON body of *response* when it parses, else its raw text."""
    try:
        return response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return response.text


def _show_prediction(stock_name, stock_data):
    """POST *stock_name* to the prediction API and chart the result.

    The predicted prices are plotted over the actual closes in *stock_data*.
    Errors are reported in the UI instead of being raised.
    """
    payload = {"stock_name": stock_name}
    st.write("Sending payload:", payload)  # Debugging line
    try:
        response = requests.post(API_URL, json=payload, timeout=_PREDICT_TIMEOUT_S)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        st.error(f"HTTP error occurred: {e}")
        if e.response is not None:
            # Print the detailed error message from FastAPI. The body may not
            # be JSON, so fall back to the raw text.
            st.write(_error_detail(e.response))
        return
    except requests.exceptions.RequestException as e:
        # Connection refused, timeout, etc.: the service was never reached.
        st.error(f"Could not reach the prediction service: {e}")
        return

    try:
        predicted_prices = response.json()["prediction"]
    except (ValueError, KeyError, TypeError) as e:
        st.error(f"Unexpected response from the prediction service: {e!r}")
        return
    # An empty list would turn the tail slice below into index[-0:], which is
    # the whole history.
    if not isinstance(predicted_prices, list) or not predicted_prices:
        st.error(
            f"Expected a non-empty list of predicted prices, got: {predicted_prices!r}"
        )
        return

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=stock_data.index, y=stock_data["Close"].tolist(), name="Actual")
    )
    # Align the predictions with the most recent dates of the history.
    fig.add_trace(
        go.Scatter(
            x=stock_data.index[-len(predicted_prices):],
            y=predicted_prices,
            name="Predicted",
        )
    )
    fig.update_layout(title=f"{stock_name} Stock Price Prediction")
    st.plotly_chart(fig)


def main():
    """Render the app: choose a ticker and date range, chart it, optionally predict."""
    stock_name = st.selectbox(
        "Please choose stock name", ("AAPL", "TSLA", "AMZN", "MSFT")
    )

    start_date = st.date_input(
        "Start date", min_value=MIN_DATE, max_value=MAX_DATE, value=MIN_DATE
    )
    end_date = st.date_input(
        "End date", min_value=MIN_DATE, max_value=MAX_DATE, value=MAX_DATE
    )

    if start_date > end_date:
        st.error("Error: End date must be after start date.")
        # Nothing below is meaningful for an inverted range.
        return
    st.success(
        f"Selected start date: `{start_date}`\n\nSelected end date:`{end_date}`"
    )

    try:
        stock_data = _fetch_history(stock_name, start_date, end_date)
    except Exception as e:  # network/yfinance failures surface as varied types
        st.error(f"An error occurred while downloading stock data: {e}")
        return
    if stock_data.empty:
        st.warning(f"No price data was returned for {stock_name} in this range.")
        return

    # The index holds the dates, so the x-axis shows real dates rather than
    # row numbers.
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=stock_data.index, y=stock_data["Close"], name="Close"))
    fig.update_layout(title=f"{stock_name} Stock Price")
    st.plotly_chart(fig)

    # Save with the dates as an ordinary first column. Presumably the
    # prediction service reads this file; TODO confirm.
    try:
        stock_data.reset_index().to_csv(f"{stock_name}_data.csv", index=False)
    except OSError as e:
        st.error(f"Could not save {stock_name}_data.csv: {e}")

    if st.button("Predict"):
        _show_prediction(stock_name, stock_data)

# Launch the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()