File size: 8,053 Bytes
1fd9ec2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838a59a
1fd9ec2
 
 
 
 
 
 
838a59a
1fd9ec2
 
 
838a59a
1fd9ec2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import streamlit as st

# Configure page
# This call sits ahead of the remaining imports, presumably because Streamlit
# expects set_page_config to be the first Streamlit command in a script run.
# The page icon is the shopping-bags emoji; the previous literal was the same
# emoji with its UTF-8 bytes mis-decoded (mojibake).
st.set_page_config(
    page_title="E-commerce Visual Assistant",
    page_icon="🛍️",
    layout="wide"
)

from streamlit_chat import message
import torch
from PIL import Image
import requests
from io import BytesIO
from model import initialize_models, load_data, chatbot, cleanup_resources

# Helper functions
def load_image_from_url(url):
    """Download an image from ``url`` and return it as a PIL image.

    Args:
        url: Address of the image, fetched with an HTTP GET.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` when the download or the
        decoding fails (the failure is reported to the page via ``st.error``).
    """
    try:
        # Bound the wait so an unresponsive host cannot stall the script run.
        response = requests.get(url, timeout=10)
        # Surface HTTP errors (404, 500, ...) directly instead of handing an
        # error page to the image decoder.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open only reads the header; force the full decode here so that
        # corrupt or truncated data is caught by this handler.
        img.load()
        return img
    except Exception as e:
        # Best-effort helper: any failure is shown to the user, never raised.
        st.error(f"Error loading image from URL: {e}")
        return None

def initialize_assistant():
    """Load the models and data once per session.

    Does nothing when ``st.session_state.models_loaded`` is already truthy.
    Otherwise it runs ``initialize_models()`` and ``load_data()`` under a
    spinner, sets the flag, and shows a success notice.
    """
    state = st.session_state
    if state.models_loaded:
        return

    with st.spinner("Loading models and data..."):
        initialize_models()
        load_data()
        # The flag is only set once both loaders returned without raising.
        state.models_loaded = True

    st.success("Assistant is ready!")
        
def display_chat_history():
    """Render every entry of ``st.session_state.messages`` as a chat bubble.

    Each entry is a dict with ``"role"`` and ``"content"``. An optional
    ``"image"`` is shown as the uploaded image, and for an optional
    ``"display_images"`` list only the first item is shown, captioned with
    its product name and price.
    """
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

            if "image" in entry:
                st.image(entry["image"], caption="Uploaded Image", width=200)

            if "display_images" not in entry:
                continue

            # Only the first product image is displayed.
            top_product = entry["display_images"][0]
            picture = top_product['image']
            product_caption = (
                f"{top_product['product_name']}\nPrice: ${top_product['price']:.2f}"
            )
            st.image(picture, caption=product_caption, width=350)

def handle_user_input(prompt, uploaded_image):
    """Record the user's turn, query the chatbot, and store its reply.

    Args:
        prompt: Text the user entered in the chat input.
        uploaded_image: Image supplied with the prompt, or ``None``.

    The user message and the assistant reply are appended to
    ``st.session_state.messages``; the script is then rerun so the history is
    redrawn with the new entries. Chatbot failures are turned into an
    assistant message rather than propagated.
    """
    # Add user message. The image is stored under "image" because that is the
    # key display_chat_history() looks for when redrawing a user turn.
    user_message = {"role": "user", "content": prompt}
    if uploaded_image is not None:
        user_message["image"] = uploaded_image
    st.session_state.messages.append(user_message)

    # Generate response
    with st.spinner("Processing your request..."):
        try:
            response = chatbot(prompt, image_input=uploaded_image)

            if isinstance(response, dict):
                # Structured reply: text plus an optional list of images.
                assistant_message = {
                    "role": "assistant",
                    "content": response['text']
                }
                if response.get('images'):
                    assistant_message["display_images"] = response['images']
                st.session_state.messages.append(assistant_message)
            else:
                # Anything else is stored as the reply content unchanged.
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response
                })

        except Exception as e:
            # UI boundary: report the failure and keep the conversation going.
            st.error(f"Error: {e}")
            st.session_state.messages.append({
                "role": "assistant",
                "content": f"I encountered an error: {e}"
            })

    st.rerun()

# Custom CSS for enhanced styling
# A single <style> element is injected through st.markdown with
# unsafe_allow_html=True. The rules cover the `.main` container, the
# `.stTitle` class carried by the page's <h1> heading, the sidebar, chat
# messages, text inputs, radio labels, buttons (including a hover effect) and
# the <footer> element.
# NOTE(review): `.css-1d391kg` looks like an auto-generated class name rather
# than a stable hook - confirm it still matches the sidebar in the Streamlit
# version this app is deployed with.
st.markdown("""
    <style>
        /* Main container styling */
        .main {
            background: linear-gradient(135deg, #f5f7fa 0%, #e8edf2 100%);
            padding: 20px;
            border-radius: 15px;
        }
        
        /* Header styling */
        .stTitle {
            color: #1e3d59;
            font-size: 2.5rem !important;
            text-align: center;
            padding: 20px;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
        }
        
        /* Sidebar styling */
        .css-1d391kg {
            background: linear-gradient(180deg, #1e3d59 0%, #2b5876 100%);
        }
        
        /* Chat container styling */
        .stChatMessage {
            background-color: white;
            border-radius: 15px;
            box-shadow: 0 4px 6px rgba(0,0,0,0.1);
            margin: 10px 0;
            padding: 15px;
        }
        
        /* Input box styling */
        .stTextInput > div > div > input {
            border-radius: 20px;
            border: 2px solid #1e3d59;
            padding: 10px 20px;
        }
        
        /* Radio button styling */
        .stRadio > label {
            background-color: white;
            padding: 10px 20px;
            border-radius: 10px;
            margin: 5px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        
        /* Button styling */
        .stButton > button {
            background: linear-gradient(90deg, #1e3d59 0%, #2b5876 100%);
            color: white;
            border-radius: 20px;
            padding: 10px 25px;
            border: none;
            box-shadow: 0 4px 6px rgba(0,0,0,0.1);
            transition: all 0.3s ease;
        }
        
        .stButton > button:hover {
            transform: translateY(-2px);
            box-shadow: 0 6px 8px rgba(0,0,0,0.2);
        }
        
        /* Footer styling */
        footer {
            background-color: white;
            border-radius: 10px;
            padding: 20px;
            margin-top: 30px;
            text-align: center;
            box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        }
    </style>
""", unsafe_allow_html=True)

# Initialize session state
# Seed each key only when it is absent, so values already stored for this
# session (chat history, the models-loaded flag) are left untouched.
for _state_key, _initial_value in (("messages", []), ("models_loaded", False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Main title with enhanced styling
# Rendered as raw HTML so the <h1> can carry the `stTitle` class styled by the
# custom CSS. The leading emoji is the shopping-bags glyph; the previous
# literal was the same emoji with its UTF-8 bytes mis-decoded (mojibake).
st.markdown("<h1 class='stTitle'>🛍️ Amazon E-commerce Visual Assistant</h1>", unsafe_allow_html=True)

# Sidebar configuration with enhanced styling
# Static description of the assistant, the team credits, and a button that
# clears the chat history. Emoji literals in this block were mojibake
# (UTF-8 bytes mis-decoded) and have been restored to the intended glyphs.
with st.sidebar:
    st.title("Assistant Features")
    
    st.markdown("### 🤖 How It Works")
    st.markdown("""
    This AI-powered shopping assistant combines:
    
    **🧠 Advanced Technologies**
    - FashionCLIP Visual AI
    - Mistral-7B Language Model
    - Multimodal Understanding
    
    **💫 Capabilities**
    - Product Search & Recognition
    - Visual Analysis
    - Detailed Comparisons
    - Price Analysis
    """)
    
    st.markdown("---")
    
    st.markdown("### 👥 Development Team")
    # A tuple rather than a set literal: set iteration order for strings is
    # arbitrary and can differ between processes, which would shuffle the
    # credits. A tuple keeps the names in the order written here.
    team_members = (
        "Yu-Chih (Wisdom) Chen",
        "Feier Xu",
        "Yanchen Dong",
        "Kitae Kim",
    )
    
    for name in team_members:
        st.markdown(f"**{name}**")
    
    st.markdown("---")
    
    if st.button("🔄 Reset Chat"):
        # Drop the stored conversation and rerun so the page redraws empty.
        st.session_state.messages = []
        st.rerun()

# Main chat interface
def main():
    """Build the chat page for one script run.

    Ensures the assistant is initialized, draws the input controls (chat box,
    input-method selector and, depending on the selection, a file uploader or
    an image-URL field), renders the stored chat history, forwards a newly
    submitted prompt to ``handle_user_input``, and finally emits the footer.
    """
    # Initialize assistant
    initialize_assistant()

    # The history container is created first and the input container second;
    # the history is filled in further down, after the inputs have been read.
    chat_container = st.container()
    input_container = st.container()

    with input_container:
        # Chat input
        prompt = st.chat_input("What would you like to know?")

        # Input options below chat input. Three equal columns are requested
        # but only the first two are populated; the third is left empty.
        col1, col2, _ = st.columns([1, 1, 1])
        with col1:
            input_option = st.radio(
                "Input Method:",
                ("Text Only", "Upload Image", "Image URL"),
                key="input_method"
            )

        # Handle different input methods; stays None for "Text Only" or when
        # no image was provided / could be loaded.
        uploaded_image = None
        if input_option == "Upload Image":
            with col2:
                uploaded_file = st.file_uploader("Choose image", type=["jpg", "jpeg", "png"])
                if uploaded_file:
                    uploaded_image = Image.open(uploaded_file)
                    st.image(uploaded_image, caption="Uploaded Image", width=200)

        elif input_option == "Image URL":
            with col2:
                image_url = st.text_input("Enter image URL")
                if image_url:
                    # Returns None (after reporting the error) on failure.
                    uploaded_image = load_image_from_url(image_url)
                    if uploaded_image:
                        st.image(uploaded_image, caption="Image from URL", width=200)

    # Display chat history
    with chat_container:
        display_chat_history()

    # Handle user input and generate response
    if prompt:
        handle_user_input(prompt, uploaded_image)

    # Footer. The heading emoji is the light-bulb glyph; the previous literal
    # was the same emoji with its UTF-8 bytes mis-decoded (mojibake).
    st.markdown("""
    <footer>
        <h3>💡 Tips for Best Results</h3>
        <p>Be specific in your questions for more accurate responses!</p>
        <p>Try asking about product features, comparisons, or prices.</p>
    </footer>
    """, unsafe_allow_html=True)

# Script entry point: run the page, then always call cleanup_resources(),
# whether main() returned normally or raised.
# NOTE(review): under `streamlit run` this script is presumably re-executed
# from the top on every interaction, so cleanup_resources() would run after
# each pass (including passes ended early by st.rerun()), not only at
# shutdown. Confirm it does not release the models that
# st.session_state.models_loaded still reports as loaded.
if __name__ == "__main__":
    try:
        main()
    finally:
        cleanup_resources()