Spaces:
Runtime error
Runtime error
import streamlit as st | |
import open_clip | |
import torch | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
import time | |
import json | |
import numpy as np | |
import cv2 | |
import chromadb | |
from transformers import YolosImageProcessor, YolosForObjectDetection | |
# Load CLIP model and tokenizer
@st.cache_resource
def load_clip_model():
    """Load the Marqo FashionSigLIP model, its eval-time transform and tokenizer.

    Decorated with ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction; without caching, the model would be
    re-created from the hub/disk on each rerun.

    Returns:
        Tuple ``(model, preprocess_val, tokenizer, device)``. The model is moved
        to CUDA when available (otherwise CPU) and switched to eval mode.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    model, _preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # open_clip documents that created models start in train mode; switch to
    # eval so dropout/stochastic layers (if any) are disabled for inference.
    model.eval()
    return model, preprocess_val, tokenizer, device


clip_model, preprocess_val, tokenizer, device = load_clip_model()
# Load YOLOS model
@st.cache_resource
def load_yolos_model():
    """Load the YOLOS Fashionpedia image processor and object-detection model.

    Cached with ``st.cache_resource`` so every Streamlit rerun reuses the already
    loaded model instead of reloading it from the hub/disk.

    Returns:
        Tuple ``(processor, model)``.
    """
    checkpoint = "valentinafeve/yolos-fashionpedia"
    processor = YolosImageProcessor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    return processor, model


yolos_processor, yolos_model = load_yolos_model()
# Category names accepted from the detector; detect_clothing drops any
# detection whose label is not in this list.
CATS = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan',
    'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress',
    'jumpsuit', 'cape', 'glasses', 'hat',
    'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
    'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet',
    'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve',
    'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel',
]
# Helper functions
def load_image_from_url(url, max_retries=3):
    """Download an image and decode it as an RGB PIL image.

    Up to ``max_retries`` attempts are made, sleeping one second between
    attempts. Network/HTTP errors as well as undecodable or truncated image
    data all count as a failed attempt.

    Args:
        url: Address of the image to fetch.
        max_retries: Total number of attempts to make.

    Returns:
        A PIL image in RGB mode, or ``None`` if every attempt failed (or
        ``max_retries`` is not positive).
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            # Image.open is lazy: convert() forces the actual decode, which is
            # where truncated/corrupt data fails with a plain OSError.
            # UnidentifiedImageError (raised by Image.open) is an OSError too.
            return Image.open(BytesIO(response.content)).convert('RGB')
        except (requests.RequestException, OSError):
            if attempt < max_retries - 1:
                time.sleep(1)
    return None
# Load ChromaDB: open the on-disk store and fetch the existing item collection
# (it is only looked up here, never created).
_CHROMA_PATH = "./clothesDB"
_COLLECTION_NAME = "fashion_items_ver2"
client = chromadb.PersistentClient(path=_CHROMA_PATH)
collection = client.get_collection(name=_COLLECTION_NAME)
def get_image_embedding(image):
    """Encode ``image`` with the CLIP model and return an L2-normalised embedding.

    The image is transformed with ``preprocess_val``, wrapped in a batch of one
    and moved to ``device``. The result is normalised along its last dimension
    and returned as a CPU NumPy array (presumably shape ``(1, D)``).
    """
    pixels = preprocess_val(image)[None].to(device)
    with torch.no_grad():
        embedding = clip_model.encode_image(pixels)
        norms = embedding.norm(dim=-1, keepdim=True)
        embedding = embedding / norms
    return embedding.cpu().numpy()
def get_text_embedding(text):
    """Encode ``text`` with the CLIP tokenizer/model and return an L2-normalised embedding.

    The text is tokenised as a batch of one on ``device``; the embedding is
    normalised along its last dimension and returned as a CPU NumPy array.
    """
    tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_text(tokens)
        norms = embedding.norm(dim=-1, keepdim=True)
        embedding = embedding / norms
    return embedding.cpu().numpy()
def get_all_embeddings_from_collection(collection):
    """Fetch every stored embedding from ``collection`` as a NumPy array."""
    records = collection.get(include=['embeddings'])
    embeddings = records['embeddings']
    return np.array(embeddings)
def get_metadata_from_ids(collection, ids):
    """Return the metadata records stored in ``collection`` for the given ``ids``."""
    return collection.get(ids=ids)['metadatas']
def find_similar_images(query_embedding, collection, top_k=5):
    """Return the ``top_k`` collection items most similar to ``query_embedding``.

    Similarity is the dot product between the query and each stored embedding.
    This equals cosine similarity when both sides are L2-normalised, as
    get_image_embedding/get_text_embedding produce. Embeddings and metadata are
    read in a single ``collection.get`` call, so their positions stay aligned.

    Args:
        query_embedding: Array-like holding one embedding vector, either
            ``(1, D)`` (as returned by the embedding helpers) or ``(D,)``.
        collection: Collection exposing ``get(include=[...])`` that returns
            ``'embeddings'`` and ``'metadatas'``.
        top_k: Maximum number of results to return.

    Returns:
        List of ``{'info': metadata, 'similarity': score}`` dicts ordered from
        most to least similar; empty if the collection has no embeddings.
    """
    data = collection.get(include=['embeddings', 'metadatas'])
    embeddings = data['embeddings']
    # The store may return a list or an ndarray, so test length explicitly
    # instead of relying on truthiness.
    if embeddings is None or len(embeddings) == 0:
        return []
    database_embeddings = np.array(embeddings)
    # A 1-D query keeps the similarity vector 1-D (length N) even when N == 1.
    # The previous .squeeze() collapsed that case to a 0-d scalar, which then
    # broke argsort(...)[::-1].
    query = np.asarray(query_embedding).reshape(-1)
    similarities = np.dot(database_embeddings, query)
    top_indices = np.argsort(similarities)[::-1][:top_k]
    metadatas = data['metadatas']
    return [
        {'info': metadatas[idx], 'similarity': similarities[idx]}
        for idx in top_indices
    ]
def detect_clothing(image, threshold=0.1):
    """Run the YOLOS Fashionpedia detector on ``image`` and keep known categories.

    Args:
        image: PIL image to analyse; ``image.size`` is read as ``(width, height)``.
        threshold: Minimum score passed to the processor's post-processing.
            Defaults to 0.1, the value previously hard-coded here.

    Returns:
        List of dicts with ``'category'`` (a label present in ``CATS``),
        ``'bbox'`` (box coordinates from post-processing, cast to int) and
        ``'confidence'`` (float score).
    """
    inputs = yolos_processor(images=image, return_tensors="pt")
    # Pure inference: no_grad avoids building an autograd graph (memory/time).
    with torch.no_grad():
        outputs = yolos_model(**inputs)
    # Post-processing takes (height, width); PIL's size is (width, height).
    target_sizes = torch.tensor([image.size[::-1]])
    results = yolos_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]
    categories = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        category = yolos_model.config.id2label[label.item()]
        if category in CATS:
            categories.append({
                'category': category,
                'bbox': [int(coord) for coord in box.tolist()],
                'confidence': score.item(),
            })
    return categories
def crop_image(image, bbox):
    """Crop ``image`` to the box given by the first four values of ``bbox``.

    The four values are passed to ``image.crop`` as a tuple, in PIL's
    ``(left, upper, right, lower)`` order.
    """
    box = tuple(bbox[i] for i in range(4))
    return image.crop(box)
# Streamlit app
st.title("Advanced Fashion Search App")


def _rerun():
    """Restart the script immediately so a just-changed ``step`` is rendered.

    Without this, Streamlit finishes the current run in the old branch, and the
    new step only appears after the user's next interaction (the "button needs
    two clicks" symptom). ``st.rerun`` is the current API; older Streamlit
    releases only provide ``st.experimental_rerun``.
    """
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


def _detection_label(detection):
    """Text used to present a detection (category + confidence) in the UI."""
    return f"{detection['category']} (Confidence: {detection['confidence']:.2f})"


# Initialize session state
if 'step' not in st.session_state:
    st.session_state.step = 'input'
if 'query_image_url' not in st.session_state:
    st.session_state.query_image_url = ''
if 'detections' not in st.session_state:
    st.session_state.detections = []
if 'selected_category' not in st.session_state:
    st.session_state.selected_category = None
if 'selected_index' not in st.session_state:
    st.session_state.selected_index = None

# Step-by-step processing: 'input' -> 'select_category' -> 'show_results'
if st.session_state.step == 'input':
    st.session_state.query_image_url = st.text_input(
        "Enter image URL:", st.session_state.query_image_url
    )
    if st.button("Detect Clothing"):
        if not st.session_state.query_image_url:
            st.warning("Please enter an image URL.")
        else:
            query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is None:
                st.error("Failed to load the image. Please try another URL.")
            else:
                st.session_state.query_image = query_image
                st.session_state.detections = detect_clothing(query_image)
                if st.session_state.detections:
                    st.session_state.step = 'select_category'
                    _rerun()
                else:
                    st.warning("No clothing items detected in the image.")

elif st.session_state.step == 'select_category':
    st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
    st.subheader("Detected Clothing Items:")
    for detection in st.session_state.detections:
        col1, col2 = st.columns([1, 3])
        with col1:
            st.write(_detection_label(detection))
        with col2:
            cropped_image = crop_image(st.session_state.query_image, detection['bbox'])
            st.image(cropped_image, caption=detection['category'], use_column_width=True)
    options = [_detection_label(d) for d in st.session_state.detections]
    # Select by position rather than by label: two detections can share the
    # same category and rounded confidence, which made a label lookup ambiguous.
    selected_index = st.selectbox(
        "Select a category to search:",
        range(len(options)),
        format_func=lambda i: options[i],
    )
    if st.button("Search Similar Items"):
        st.session_state.selected_index = selected_index
        st.session_state.selected_category = options[selected_index]
        st.session_state.step = 'show_results'
        _rerun()

elif st.session_state.step == 'show_results':
    st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
    selected_detection = st.session_state.detections[st.session_state.selected_index]
    cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
    st.image(cropped_image, caption="Cropped Image", use_column_width=True)
    query_embedding = get_image_embedding(cropped_image)
    similar_images = find_similar_images(query_embedding, collection)
    st.subheader("Similar Items:")
    for img in similar_images:
        col1, col2 = st.columns(2)
        with col1:
            st.image(img['info']['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {img['info']['name']}")
            st.write(f"Brand: {img['info']['brand']}")
            st.write(f"Category: {img['info']['category']}")
            st.write(f"Price: {img['info']['price']}")
            st.write(f"Discount: {img['info']['discount']}%")
            st.write(f"Similarity: {img['similarity']:.2f}")
    if st.button("Start New Search"):
        st.session_state.step = 'input'
        st.session_state.query_image_url = ''
        st.session_state.detections = []
        st.session_state.selected_category = None
        st.session_state.selected_index = None
        _rerun()
# Text search: sidebar form that embeds a free-text query with CLIP and lists
# the closest catalogue items.
sidebar = st.sidebar
sidebar.title("Text Search")
query_text = sidebar.text_input("Enter search text:")
if sidebar.button("Search by Text"):
    if not query_text:
        sidebar.warning("Please enter a search text.")
    else:
        text_embedding = get_text_embedding(query_text)
        similar_images = find_similar_images(text_embedding, collection)
        sidebar.subheader("Similar Items:")
        for img in similar_images:
            info = img['info']
            sidebar.image(info['image_url'], use_column_width=True)
            sidebar.write(f"Name: {info['name']}")
            sidebar.write(f"Brand: {info['brand']}")
            sidebar.write(f"Category: {info['category']}")
            sidebar.write(f"Price: {info['price']}")
            sidebar.write(f"Discount: {info['discount']}%")
            sidebar.write(f"Similarity: {img['similarity']:.2f}")
            sidebar.write("---")