File size: 3,638 Bytes
65cdc34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from rag import RAGProcessor
import os
from dotenv import load_dotenv
import tempfile

load_dotenv()  # pull variables from a .env file into the process environment

# The app cannot call the model without a key: when it is absent or empty,
# show why and stop the script here.
_api_key_missing = not os.environ.get('GOOGLE_API_KEY')
if _api_key_missing:
    st.error("Please set the GOOGLE_API_KEY in your .env file.")
    st.stop()

def initialize_session_state():
    """Seed ``st.session_state`` with this app's keys on the first run.

    Keys that already exist are left alone, so their values survive reruns.
    ``rag_processor`` receives a new ``RAGProcessor``; ``vector_store`` starts
    as ``None`` until documents are processed.
    """
    defaults = (
        ("rag_processor", RAGProcessor),  # constructed only when missing
        ("vector_store", lambda: None),   # nothing processed yet
    )
    for key, make_default in defaults:
        if key not in st.session_state:
            st.session_state[key] = make_default()

def _unique_upload_path(directory, filename, index):
    """Return a path inside *directory* for *filename* that neither escapes it nor clobbers a file.

    Only the final component of *filename* is kept. Empty or dot-only names
    fall back to ``upload_<index>.pdf``. If the name is already taken,
    ``_1``, ``_2``, ... is appended to the stem until it is free.
    """
    name = os.path.basename(filename)
    if name in ("", ".", ".."):
        name = f"upload_{index}.pdf"
    stem, ext = os.path.splitext(name)
    candidate = os.path.join(directory, name)
    suffix = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate

def save_uploaded_files(uploaded_files):
    """Save uploaded files to a fresh temporary directory and return their paths.

    Args:
        uploaded_files: Iterable of uploaded-file objects exposing ``name`` and
            ``getbuffer()`` (what this app's uploader provides).

    Returns:
        list[str]: Paths of the written files in upload order, or ``[]`` if
        anything failed. On failure the error is shown via ``st.error`` and
        the partially written directory is removed.
    """
    import shutil  # only needed to clean up after a failure

    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        file_paths = []
        for index, uploaded_file in enumerate(uploaded_files):
            # The client-supplied name is untrusted: keep only its basename so it
            # cannot point outside temp_dir, and de-duplicate so two uploads with
            # the same name do not overwrite each other.
            file_path = _unique_upload_path(temp_dir, uploaded_file.name, index)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            file_paths.append(file_path)
        # NOTE(review): on success temp_dir is left in place for the caller;
        # nothing visible here removes it afterwards.
        return file_paths
    except Exception as e:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        st.error(f"Error saving uploaded files: {e}")
        return []

def _render_header():
    """Render the centered page title inside the ``main-header`` wrapper."""
    # Emit the wrapper and the title in a single call: each st.markdown call is
    # rendered as its own element, so an opening <div> and a closing </div>
    # sent in separate calls never actually enclose the title.
    st.markdown(
        "<div class='main-header'>"
        "<h1 style='text-align: center;'>💰 Finance Buddy</h1>"
        "</div>",
        unsafe_allow_html=True
    )

def _render_sidebar():
    """Render the upload sidebar and process the uploaded PDFs on request.

    On success, the value returned by ``process_documents`` is stored in
    ``st.session_state.vector_store``. Failures are reported with ``st.error``.
    """
    with st.sidebar:
        st.image("PL_image-removebg-preview.png", use_column_width=True)
        st.title("📄 Document Analysis")
        uploaded_files = st.file_uploader(
            "Upload P&L Documents (PDF)",
            accept_multiple_files=True,
            type=['pdf']
        )

        # The button is only shown once at least one file has been uploaded.
        if uploaded_files and st.button("Process Documents", key="process_docs"):
            with st.spinner("Processing documents..."):
                try:
                    file_paths = save_uploaded_files(uploaded_files)
                    # An empty list means saving failed and was already reported.
                    if file_paths:
                        processor = st.session_state.rag_processor
                        st.session_state.vector_store = processor.process_documents(file_paths)
                        st.success("✅ Documents processed successfully!")
                except Exception as e:
                    st.error(f"Error processing documents: {e}")

def _as_blockquote(text):
    """Return *text* formatted as a Markdown blockquote, quoting every line.

    A single leading ``>`` only quotes the first paragraph, because a blank
    line ends the quote. Prefixing each line keeps multi-paragraph answers
    (lists, blank lines) inside the quote.
    """
    return "\n".join(f"> {line}" for line in str(text).splitlines())

def _render_query_section():
    """Show the question box and answer it from the processed documents."""
    st.markdown("""

    💡 **Ask questions about your P&L statements and financial data.**  

    """)

    query = st.text_input("🔍 Ask your question:", key="query")
    if not query:
        return

    if not st.session_state.vector_store:
        st.warning("Please upload and process documents first!")
        return

    with st.spinner("Analyzing..."):
        try:
            response = st.session_state.rag_processor.generate_response(
                query,
                st.session_state.vector_store
            )
        except Exception as e:
            st.error(f"Error generating response: {e}")
            return

    st.markdown("### 📋 Response:")
    st.markdown(_as_blockquote(response))

def _render_footer():
    """Render the attribution footer below a horizontal rule."""
    st.markdown("---")
    st.markdown(
        "<p style='text-align: center;'>💼 Built with Streamlit & Google Generative AI</p>",
        unsafe_allow_html=True
    )

def main():
    """Build the Finance Buddy page: header, upload sidebar, Q&A section, footer."""
    # Configure the page before rendering anything else.
    st.set_page_config(
        page_title="Finance Buddy",
        page_icon="💰",
        layout="wide"
    )

    initialize_session_state()
    _render_header()
    _render_sidebar()
    _render_query_section()
    _render_footer()

# Entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()