File size: 6,969 Bytes
5db0821
 
 
 
 
 
 
 
 
44014d4
7d37333
44014d4
7d37333
 
 
 
 
5db0821
 
 
 
44014d4
5db0821
 
 
 
 
 
44014d4
 
 
 
 
 
5db0821
44014d4
 
 
 
 
 
 
 
 
 
70e228b
cbc4322
 
 
70e228b
cbc4322
 
 
7d37333
70e228b
 
 
 
cbc4322
 
 
7d37333
 
cbc4322
 
70e228b
 
 
 
 
 
 
 
 
 
cbc4322
70e228b
 
 
 
 
 
 
 
cbc4322
 
44014d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ddea50
 
 
44014d4
 
 
 
 
 
 
5db0821
 
3ddea50
5db0821
 
 
 
 
 
 
44014d4
5db0821
 
44014d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5db0821
44014d4
5db0821
 
 
 
 
 
 
 
 
 
 
 
44014d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import base64
import json
import os
import tempfile
import time
from io import BytesIO

import streamlit as st
import open_clip
import torch
import requests
from PIL import Image
import numpy as np
import cv2
from inference_sdk import InferenceHTTPClient
import matplotlib.pyplot as plt

# Exception class for error handling (not raised by the code shown in this file).
class APIError(Exception):
    """Exception type intended for errors reported by external API calls."""

# Load model, evaluation transform and tokenizer
@st.cache_resource
def load_model():
    """Load the Marqo fashion SigLIP model once and move it to the best device.

    ``open_clip.create_model_and_transforms`` returns
    ``(model, preprocess_train, preprocess_val)``. The training transform
    applies random augmentation, so only the evaluation transform gives
    deterministic embeddings. The tokenizer is not part of that tuple; it is
    loaded separately with ``open_clip.get_tokenizer``.

    Returns:
        ``(model, preprocess_val, tokenizer, device)``. ``device`` is CUDA when
        available, otherwise CPU.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    model, _preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # The model is only used for inference; switch off training-mode layers.
    model.eval()
    return model, preprocess_val, tokenizer, device

model, preprocess_val, tokenizer, device = load_model()

# Factory for the Roboflow hosted-inference client.
def setup_roboflow_client(api_key):
    """Build an ``InferenceHTTPClient`` pointed at Roboflow's outline endpoint.

    Args:
        api_key: Roboflow API key passed through to the client.

    Returns:
        The configured ``InferenceHTTPClient``.
    """
    endpoint = "https://outline.roboflow.com"
    client = InferenceHTTPClient(api_url=endpoint, api_key=api_key)
    return client

# Streamlit app: page header and credential prompt.
st.title(body="Fashion Search App with Segmentation")

# Masked input; the rest of the app only runs once a non-empty key is entered.
api_key = st.text_input(label="Enter your Roboflow API Key", type="password")

if api_key:
    # Roboflow client used by segment_image for every inference call.
    CLIENT = setup_roboflow_client(api_key=api_key)

    def segment_image(image_path):
        """Keep only the clothing regions Roboflow detects in an image file.

        The file is sent base64-encoded to the ``closet/1`` model through
        ``CLIENT``. Every returned polygon is rescaled onto an 800x600 copy of
        the image and filled into a white mask. The copy is then ANDed with the
        mask, so pixels outside all polygons become black.

        Args:
            image_path: Path of an image file on disk.

        Returns:
            An RGB ``PIL.Image``:
            * the masked 800x600 image when at least one polygon was drawn;
            * the unmasked 800x600 image when the API call fails or returns no
              usable polygon (an empty mask would otherwise black out the
              whole picture);
            * the file re-read with PIL at its original size if any other
              error occurs.
        """
        try:
            # Read the raw file and base64-encode it for the inference API.
            with open(image_path, "rb") as image_file:
                encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

            # Load the original image (BGR) and resize it to the working size.
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread reports unreadable files by returning None, not raising.
                raise ValueError(f"OpenCV could not decode {image_path}")
            image = cv2.resize(image, (800, 600))

            try:
                # Call the Roboflow API.
                results = CLIENT.infer(encoded_image, model_id="closet/1")
            except Exception as api_error:
                st.error(f"API Error: {str(api_error)}")
                return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

            mask = np.zeros(image.shape, dtype=np.uint8)
            predictions = results.get('predictions') or []
            if predictions:
                # Polygon coordinates are relative to the image size reported by
                # the API; map them onto the resized working copy.
                scale = np.array([
                    image.shape[1] / results['image']['width'],
                    image.shape[0] / results['image']['height'],
                ])
                for prediction in predictions:
                    points = prediction.get('points') or []
                    if len(points) < 3:
                        continue  # fewer than 3 vertices cannot enclose an area
                    # Scale in floating point and round once, instead of
                    # truncating to int before scaling.
                    pts = np.array([[p['x'], p['y']] for p in points], dtype=np.float64)
                    pts = np.rint(pts * scale).astype(np.int32).reshape((-1, 1, 2))
                    cv2.fillPoly(mask, [pts], color=(255, 255, 255))  # White mask

            if mask.any():
                segmented_image = cv2.bitwise_and(image, mask)
            else:
                st.warning("No predictions found in the image. Returning original image.")
                segmented_image = image

            return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
        except Exception as e:
            st.error(f"Error in segmentation: {str(e)}")
            # Fall back to the original file. convert() forces a full load, so the
            # file handle is closed, and gives the same RGB mode as the paths above.
            with Image.open(image_path) as fallback:
                return fallback.convert("RGB")

    def get_image_embedding(image):
        """Encode one PIL image with the CLIP model and L2-normalise the result.

        The preprocessed tensor is given a leading batch dimension of size 1.
        The features are divided by their norm along the last dimension and
        returned as a NumPy array on the CPU.
        """
        batch = preprocess_val(image)[None, ...].to(device)
        with torch.no_grad():
            features = model.encode_image(batch)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()

    # Product catalogue loading (cached by Streamlit).
    @st.cache_data
    def load_data():
        """Parse and return the product catalogue stored in ``musinsa-final.json``.

        ``st.cache_data`` memoises the parsed result, so the file is not
        re-read on every call.
        """
        with open('musinsa-final.json', encoding='utf-8') as catalogue_file:
            catalogue = json.load(catalogue_file)
        return catalogue

    # Parsed catalogue consumed by process_database.
    data = load_data()

    # Process database with segmentation
    @st.cache_data
    def process_database():
        """Build the search index: download, segment and embed each catalogue image.

        For each item in ``data``:
        * the first entry of its image-link list is downloaded to a private
          temporary ``.jpg`` file;
        * the file is segmented with ``segment_image`` and embedded with
          ``get_image_embedding``, then always deleted;
        * the embedding is recorded together with the item's display metadata.

        Items with no image link, a failed download, or an undecodable image
        are skipped with a Streamlit warning instead of aborting the whole build.

        Returns:
            ``(embeddings, info)``: ``embeddings`` is ``np.vstack`` of the
            per-item embeddings, and ``info[i]`` describes row ``i``.

        Raises:
            ValueError: if no catalogue item could be processed.
        """
        # NOTE(review): the non-ASCII dictionary keys below look mis-encoded in
        # this view of the file; confirm they match the JSON's field names.
        database_embeddings = []
        database_info = []
        for item in data:
            image_links = item['์ด๋ฏธ์ง€ ๋งํฌ']
            if not image_links:
                st.warning("Skipping catalogue item without an image link.")
                continue
            image_url = image_links[0]
            # The product-ID column name may carry a leading UTF-8 BOM, so accept
            # the key with or without it.
            product_id = item.get('\ufeff์ƒํ’ˆ ID') or item.get('์ƒํ’ˆ ID')

            try:
                # A timeout keeps one unresponsive host from hanging the build;
                # raise_for_status stops error pages from being saved as images.
                response = requests.get(image_url, timeout=30)
                response.raise_for_status()
            except requests.RequestException as download_error:
                st.warning(f"Skipping product {product_id}: {download_error}")
                continue

            # A unique temp file per download avoids name collisions between
            # items and between concurrent builds.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                tmp_file.write(response.content)
                image_path = tmp_file.name
            try:
                segmented_image = segment_image(image_path)
                embedding = get_image_embedding(segmented_image)
            except OSError as image_error:
                # e.g. PIL cannot identify the downloaded payload as an image.
                st.warning(f"Skipping product {product_id}: {image_error}")
                continue
            finally:
                # Best-effort cleanup; a leftover temp file must not abort the build.
                try:
                    os.remove(image_path)
                except OSError:
                    pass

            database_embeddings.append(embedding)
            database_info.append({
                'id': product_id,
                'category': item['์นดํ…Œ๊ณ ๋ฆฌ'],
                'brand': item['๋ธŒ๋žœ๋“œ๋ช…'],
                'name': item['์ œํ’ˆ๋ช…'],
                'price': item['์ •๊ฐ€'],
                'discount': item['ํ• ์ธ์œจ'],
                'image_url': image_url
            })

        if not database_embeddings:
            raise ValueError("No catalogue images could be processed.")
        return np.vstack(database_embeddings), database_info

    database_embeddings, database_info = process_database()

    def find_similar_images(query_embedding, top_k=5, embeddings=None, info=None):
        """Return the ``top_k`` catalogue items most similar to a query embedding.

        Similarity is the dot product between each database row and the query.
        This equals cosine similarity when both sides are L2-normalised, as
        ``get_image_embedding`` produces.

        Args:
            query_embedding: Query vector, shape ``(D,)`` or ``(1, D)``.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
            embeddings: Optional ``(N, D)`` matrix to search. Defaults to the
                module-level ``database_embeddings``.
            info: Optional metadata list aligned row-for-row with
                ``embeddings``. Defaults to the module-level ``database_info``.

        Returns:
            A list of ``{'info': ..., 'similarity': float}`` dicts, highest
            similarity first.
        """
        if embeddings is None:
            embeddings = database_embeddings
        if info is None:
            info = database_info
        if top_k <= 0:
            return []

        # Flatten to 1-D with ravel rather than squeeze. squeeze() collapses a
        # one-row database to a 0-d array, which cannot be sorted or reversed.
        similarities = np.ravel(np.dot(embeddings, np.ravel(query_embedding)))
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [
            {'info': info[idx], 'similarity': float(similarities[idx])}
            for idx in top_indices
        ]

    uploaded_file = st.file_uploader("Choose an image...", type="jpg")
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)
        
        if st.button('Find Similar Items'):
            with st.spinner('Processing...'):
                # Save uploaded image temporarily
                temp_path = "temp_upload.jpg"
                image.save(temp_path)
                
                # Segment the uploaded image
                segmented_image = segment_image(temp_path)
                st.image(segmented_image, caption='Segmented Image', use_column_width=True)
                
                # Get embedding for segmented image
                query_embedding = get_image_embedding(segmented_image)
                similar_images = find_similar_images(query_embedding)
                
                st.subheader("Similar Items:")
                for img in similar_images:
                    col1, col2 = st.columns(2)
                    with col1:
                        st.image(img['info']['image_url'], use_column_width=True)
                    with col2:
                        st.write(f"Name: {img['info']['name']}")
                        st.write(f"Brand: {img['info']['brand']}")
                        st.write(f"Category: {img['info']['category']}")
                        st.write(f"Price: {img['info']['price']}")
                        st.write(f"Discount: {img['info']['discount']}%")
                        st.write(f"Similarity: {img['similarity']:.2f}")
else:
    st.warning("Please enter your Roboflow API Key to use the app.")