wisdom196473 committed
Commit 44f740c · 1 Parent(s): 43acb67

Update README, model.py, and requirements.txt
.ipynb_checkpoints/README-checkpoint.md CHANGED
@@ -41,6 +41,8 @@ streamlit run amazon_app.py
 - `model.py`: Core AI model implementations
 - `requirements.txt`: Project dependencies
 
-## License
+## Future Directions
 
-MIT License
+- [ ] Fine-tune the FashionCLIP embedding model on domain-specific data
+- [ ] Fine-tune the large language model to improve its generalization capabilities
+- [ ] Develop feedback loops for continuous improvement
.ipynb_checkpoints/amazon_app-checkpoint.py ADDED
@@ -0,0 +1,269 @@
+import streamlit as st
+
+# Configure page
+st.set_page_config(
+    page_title="E-commerce Visual Assistant",
+    page_icon="🛍️",
+    layout="wide"
+)
+
+from streamlit_chat import message
+import torch
+from PIL import Image
+import requests
+from io import BytesIO
+from model import initialize_models, load_data, chatbot, cleanup_resources
+
+# Helper functions
+def load_image_from_url(url):
+    try:
+        response = requests.get(url)
+        img = Image.open(BytesIO(response.content))
+        return img
+    except Exception as e:
+        st.error(f"Error loading image from URL: {str(e)}")
+        return None
+
+def initialize_assistant():
+    if not st.session_state.models_loaded:
+        with st.spinner("Loading models and data..."):
+            initialize_models()
+            load_data()
+            st.session_state.models_loaded = True
+            st.success("Assistant is ready!")
+
+def display_chat_history():
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+            if "image" in message:
+                st.image(message["image"], caption="Uploaded Image", width=200)
+            if "display_images" in message:
+                # Since we only have one image, we don't need multiple columns
+                img_data = message["display_images"][0]  # Get the first (and only) image
+                st.image(
+                    img_data['image'],
+                    caption=f"{img_data['product_name']}\nPrice: ${img_data['price']:.2f}",
+                    width=350  # Adjusted width for single image display
+                )
+
+def handle_user_input(prompt, uploaded_image):
+    # Add user message
+    st.session_state.messages.append({"role": "user", "content": prompt})
+
+    # Generate response
+    with st.spinner("Processing your request..."):
+        try:
+            response = chatbot(prompt, image_input=uploaded_image)
+
+            if isinstance(response, dict):
+                assistant_message = {
+                    "role": "assistant",
+                    "content": response['text']
+                }
+                if 'images' in response and response['images']:
+                    assistant_message["display_images"] = response['images']
+                st.session_state.messages.append(assistant_message)
+            else:
+                st.session_state.messages.append({
+                    "role": "assistant",
+                    "content": response
+                })
+
+        except Exception as e:
+            st.error(f"Error: {str(e)}")
+            st.session_state.messages.append({
+                "role": "assistant",
+                "content": f"I encountered an error: {str(e)}"
+            })
+
+    st.rerun()
+
+# Custom CSS for enhanced styling
+st.markdown("""
+    <style>
+    /* Main container styling */
+    .main {
+        background: linear-gradient(135deg, #f5f7fa 0%, #e8edf2 100%);
+        padding: 20px;
+        border-radius: 15px;
+    }
+
+    /* Header styling */
+    .stTitle {
+        color: #1e3d59;
+        font-size: 2.5rem !important;
+        text-align: center;
+        padding: 20px;
+        text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
+    }
+
+    /* Sidebar styling */
+    .css-1d391kg {
+        background: linear-gradient(180deg, #1e3d59 0%, #2b5876 100%);
+    }
+
+    /* Chat container styling */
+    .stChatMessage {
+        background-color: white;
+        border-radius: 15px;
+        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
+        margin: 10px 0;
+        padding: 15px;
+    }
+
+    /* Input box styling */
+    .stTextInput > div > div > input {
+        border-radius: 20px;
+        border: 2px solid #1e3d59;
+        padding: 10px 20px;
+    }
+
+    /* Radio button styling */
+    .stRadio > label {
+        background-color: white;
+        padding: 10px 20px;
+        border-radius: 10px;
+        margin: 5px;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+
+    /* Button styling */
+    .stButton > button {
+        background: linear-gradient(90deg, #1e3d59 0%, #2b5876 100%);
+        color: white;
+        border-radius: 20px;
+        padding: 10px 25px;
+        border: none;
+        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
+        transition: all 0.3s ease;
+    }
+
+    .stButton > button:hover {
+        transform: translateY(-2px);
+        box-shadow: 0 6px 8px rgba(0,0,0,0.2);
+    }
+
+    /* Footer styling */
+    footer {
+        background-color: white;
+        border-radius: 10px;
+        padding: 20px;
+        margin-top: 30px;
+        text-align: center;
+        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
+    }
+    </style>
+""", unsafe_allow_html=True)
+
+# Initialize session state
+if 'messages' not in st.session_state:
+    st.session_state.messages = []
+if 'models_loaded' not in st.session_state:
+    st.session_state.models_loaded = False
+
+# Main title with enhanced styling
+st.markdown("<h1 class='stTitle'>🛍️ Amazon E-commerce Visual Assistant</h1>", unsafe_allow_html=True)
+
+# Sidebar configuration with enhanced styling
+with st.sidebar:
+    st.title("Assistant Features")
+
+    st.markdown("### 🤖 How It Works")
+    st.markdown("""
+    This AI-powered shopping assistant combines:
+
+    **🧠 Advanced Technologies**
+    - FashionCLIP Visual AI
+    - Mistral-7B Language Model
+    - Multimodal Understanding
+
+    **💫 Capabilities**
+    - Product Search & Recognition
+    - Visual Analysis
+    - Detailed Comparisons
+    - Price Analysis
+    """)
+
+    st.markdown("---")
+
+    st.markdown("### 👥 Development Team")
+    # A list (not a set) keeps the members in a stable display order
+    team_members = [
+        "Yu-Chih (Wisdom) Chen",
+        "Feier Xu",
+        "Yanchen Dong",
+        "Kitae Kim"
+    ]
+
+    for name in team_members:
+        st.markdown(f"**{name}**")
+
+    st.markdown("---")
+
+    if st.button("🔄 Reset Chat"):
+        st.session_state.messages = []
+        st.rerun()
+
+# Main chat interface
+def main():
+    # Initialize assistant
+    initialize_assistant()
+
+    # Chat container
+    chat_container = st.container()
+
+    # User input section at the bottom
+    input_container = st.container()
+
+    with input_container:
+        # Chat input
+        prompt = st.chat_input("What would you like to know?")
+
+        # Input options below chat input
+        col1, col2, col3 = st.columns([1, 1, 1])
+        with col1:
+            input_option = st.radio(
+                "Input Method:",
+                ("Text Only", "Upload Image", "Image URL"),
+                key="input_method"
+            )
+
+        # Handle different input methods
+        uploaded_image = None
+        if input_option == "Upload Image":
+            with col2:
+                uploaded_file = st.file_uploader("Choose image", type=["jpg", "jpeg", "png"])
+                if uploaded_file:
+                    uploaded_image = Image.open(uploaded_file)
+                    st.image(uploaded_image, caption="Uploaded Image", width=200)
+
+        elif input_option == "Image URL":
+            with col2:
+                image_url = st.text_input("Enter image URL")
+                if image_url:
+                    uploaded_image = load_image_from_url(image_url)
+                    if uploaded_image:
+                        st.image(uploaded_image, caption="Image from URL", width=200)
+
+    # Display chat history
+    with chat_container:
+        display_chat_history()
+
+    # Handle user input and generate response
+    if prompt:
+        handle_user_input(prompt, uploaded_image)
+
+    # Footer
+    st.markdown("""
+    <footer>
+        <h3>💡 Tips for Best Results</h3>
+        <p>Be specific in your questions for more accurate responses!</p>
+        <p>Try asking about product features, comparisons, or prices.</p>
+    </footer>
+    """, unsafe_allow_html=True)
+
+if __name__ == "__main__":
+    try:
+        main()
+    finally:
+        cleanup_resources()
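
For reference, `handle_user_input()` above accepts either a plain string or a dict from `model.chatbot()`. A minimal sketch of the dict contract, inferred only from the keys the app reads (`text`, `images`, `image`, `product_name`, `price`); the concrete values are hypothetical:

```python
from PIL import Image

# Hypothetical example of the dict reply that handle_user_input() handles;
# inferred from the keys amazon_app.py reads, not from model.py itself.
placeholder = Image.new("RGB", (64, 64))  # stand-in for a retrieved product image

response = {
    "text": "Here are similar jackets I found.",    # rendered via st.markdown
    "images": [
        {
            "image": placeholder,                   # shown with st.image
            "product_name": "Example Rain Jacket",  # hypothetical product
            "price": 49.99,                         # formatted as ${price:.2f}
        }
    ],
}
```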
.ipynb_checkpoints/model-checkpoint.py CHANGED
@@ -47,12 +47,6 @@ text_faiss: Optional[object] = None
 image_faiss: Optional[object] = None
 
 def initialize_models() -> bool:
-    """
-    Initialize CLIP and LLM models with proper error handling and GPU optimization.
-
-    Returns:
-        bool: True if initialization successful, raises RuntimeError otherwise
-    """
     global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
 
     try:
@@ -80,10 +74,14 @@ def initialize_models() -> bool:
             bnb_4bit_quant_type="nf4"
         )
 
+        # Get token from Streamlit secrets
+        hf_token = st.secrets["HUGGINGFACE_TOKEN"]
+
         llm_tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             padding_side="left",
-            truncation_side="left"
+            truncation_side="left",
+            token=hf_token  # Add token here
         )
         llm_tokenizer.pad_token = llm_tokenizer.eos_token
 
@@ -91,7 +89,8 @@ def initialize_models() -> bool:
             model_name,
             quantization_config=quantization_config,
             device_map="auto",
-            torch_dtype=torch.float16
+            torch_dtype=torch.float16,
+            token=hf_token  # Add token here
         )
         llm_model.eval()
         print("LLM initialized successfully")
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,14 @@
+streamlit==1.28.2
+streamlit-chat==0.1.1
+torch>=2.0.0
+transformers==4.35.2
+open_clip_torch==2.23.0
+pillow==10.1.0
+pandas==2.1.3
+numpy==1.26.2
+faiss-cpu>=1.7.4
+huggingface_hub==0.19.4
+langchain==0.0.339
+requests==2.31.0
+bitsandbytes>=0.41.1
+matplotlib==3.7.1
README.md CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 # Amazon E-commerce Visual Assistant
 
 A multimodal AI assistant that helps users search and explore Amazon products through natural language and image-based interactions.
@@ -42,7 +41,8 @@ streamlit run amazon_app.py
 - `model.py`: Core AI model implementations
 - `requirements.txt`: Project dependencies
 
-## License
+## Future Directions
 
-MIT License
-=======
+- [ ] Fine-tune the FashionCLIP embedding model on domain-specific data
+- [ ] Fine-tune the large language model to improve its generalization capabilities
+- [ ] Develop feedback loops for continuous improvement
model.py CHANGED
@@ -47,12 +47,6 @@ text_faiss: Optional[object] = None
 image_faiss: Optional[object] = None
 
 def initialize_models() -> bool:
-    """
-    Initialize CLIP and LLM models with proper error handling and GPU optimization.
-
-    Returns:
-        bool: True if initialization successful, raises RuntimeError otherwise
-    """
     global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
 
    try:
@@ -80,10 +74,14 @@ def initialize_models() -> bool:
             bnb_4bit_quant_type="nf4"
         )
 
+        # Get token from Streamlit secrets
+        hf_token = st.secrets["HUGGINGFACE_TOKEN"]
+
         llm_tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             padding_side="left",
-            truncation_side="left"
+            truncation_side="left",
+            token=hf_token  # Add token here
         )
         llm_tokenizer.pad_token = llm_tokenizer.eos_token
 
@@ -91,7 +89,8 @@ def initialize_models() -> bool:
             model_name,
             quantization_config=quantization_config,
             device_map="auto",
-            torch_dtype=torch.float16
+            torch_dtype=torch.float16,
+            token=hf_token  # Add token here
         )
         llm_model.eval()
         print("LLM initialized successfully")
requirements.txt CHANGED
@@ -1,14 +1,14 @@
 streamlit==1.28.2
 streamlit-chat==0.1.1
-torch==2.1.1
+torch>=2.0.0
 transformers==4.35.2
 open_clip_torch==2.23.0
 pillow==10.1.0
 pandas==2.1.3
 numpy==1.26.2
-faiss-cpu==1.7.4
+faiss-cpu>=1.7.4
 huggingface_hub==0.19.4
 langchain==0.0.339
 requests==2.31.0
-pyngrok==7.0.3
-bitsandbytes==0.41.1
+bitsandbytes>=0.41.1
+matplotlib==3.7.1
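
Since `torch`, `faiss-cpu`, and `bitsandbytes` are now floor-pinned with `>=` rather than locked to exact versions, the resolved versions can drift between environments. A small stdlib-only check to run after `pip install -r requirements.txt`, using the package names from the file above:

```python
# Print the versions the loosened pins actually resolved to, so two
# environments can be compared line by line.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("streamlit", "torch", "transformers", "faiss-cpu", "bitsandbytes"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```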