Spaces:
Runtime error
Runtime error
File size: 2,946 Bytes
47d7fa1 d989c3a 47d7fa1 506e7cd 47d7fa1 506e7cd 47d7fa1 506e7cd 4be0554 41f1f7d 78fd827 41f1f7d a0c768e 41f1f7d 78fd827 41f1f7d a0c768e fbbda08 9426ca5 fcec08f 504b389 fcec08f a0c768e fcec08f 5d8493d 504b389 5d8493d 9426ca5 fcec08f 9426ca5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch
# Load the pre-trained CLIP model and its tokenizer
# (ViT-B/32 checkpoint; processor handles both image preprocessing and text tokenization)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
# Load the fashion product images dataset from Hugging Face
# NOTE(review): calculate_similarities below reads "image_features" and
# "text_features" columns from each record — confirm this dataset actually
# ships precomputed embeddings under those keys.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
def preprocess_image(image):
    """Convert an input image into a normalized CLIP-style tensor.

    Accepts either a numpy array (H, W, C, uint8-convertible) or a PIL
    Image — `send_message` passes a PIL Image, while the original code
    only handled arrays (`.astype` would raise AttributeError on a PIL
    Image).

    Returns a float tensor of shape (1, 3, 224, 224) with ImageNet
    mean/std normalization applied.
    """
    if isinstance(image, Image.Image):
        # Already a PIL image — just ensure a 3-channel RGB mode.
        pil_image = image.convert('RGB')
    else:
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet normalization constants
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    return preprocess(pil_image).unsqueeze(0)
def encode_text(text):
    """Tokenize a single text query with the CLIP processor.

    Returns the processor's BatchEncoding of PyTorch tensors, ready to
    be fed to the text tower of the model.
    """
    return processor(
        text=[text],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
def encode_image(image):
    """Run a single image through the CLIP processor.

    Returns the processor's batch of PyTorch tensors (pixel values),
    ready to be fed to the vision tower of the model.
    """
    encoded = processor(
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return encoded
def calculate_similarities(query_image, query_text):
    """Score every catalog product against an (image, text) query.

    Args:
        query_image: BatchEncoding/BatchFeature from `encode_image`.
        query_text: BatchEncoding from `encode_text`.

    Returns:
        A list with one combined similarity tensor per product, in
        dataset order (image similarity * text similarity).
    """
    # Unpack the processor outputs into keyword arguments. Passing the
    # mapping positionally (as the original did) binds the whole dict to
    # the first parameter instead of supplying pixel_values/input_ids.
    query_image_features = model.get_image_features(**query_image)
    query_text_features = model.get_text_features(**query_text)
    # Build the cosine-similarity module once instead of twice per product.
    cosine = torch.nn.CosineSimilarity(dim=-1)
    similarities = []
    for product in deepfashion_database:
        # NOTE(review): assumes each record carries precomputed
        # "image_features"/"text_features" embeddings — verify the dataset.
        product_image_features = torch.Tensor(product["image_features"])
        product_text_features = torch.Tensor(product["text_features"])
        image_similarity = cosine(query_image_features, product_image_features)
        text_similarity = cosine(query_text_features, product_text_features)
        # Combine the two modalities multiplicatively into one score.
        similarities.append(image_similarity * text_similarity)
    return similarities
def initial_query(image, text):
    """Retrieve the three catalog products most similar to the query.

    Encodes the image and text, scores every product, and returns the
    three highest-scoring dataset records.
    """
    scores = calculate_similarities(encode_image(image), encode_text(text))
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return [deepfashion_database[idx] for idx in ranked[:3]]
def send_message(txt, btn):
    """Gradio callback: run a combined image+text product search.

    Args:
        txt: the user's text query from the Textbox.
        btn: the uploaded file from the UploadButton, or None.
    """
    if btn is not None:
        # NOTE(review): `btn` is treated as an openable file/path here —
        # confirm against the installed gradio version's UploadButton payload.
        image = Image.open(btn)
        image = preprocess_image(image)
    else:
        # NOTE(review): a None image flows into initial_query ->
        # encode_image(None), which will fail; this branch looks unhandled.
        image = None
    top_3_products = initial_query(image, txt)
    # NOTE(review): generate_output_html is not defined anywhere in this
    # file — either it lives in another module or it is missing.
    output_html = generate_output_html(top_3_products)
    # NOTE(review): gr.Chatbot exposes no `append_message` method in any
    # public gradio API I can ground here — likely needs rework; the
    # function also returns nothing to the Interface output.
    chatbot.append_message("You", txt)
    chatbot.append_message("AI", output_html)
# Build the UI widgets and launch the app.
# NOTE(review): `.style(height=...)` is the legacy (pre-4.x) gradio styling
# API — confirm the pinned gradio version supports it.
chatbot = gr.Chatbot([]).style(height=750)
txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
btn = gr.UploadButton("π", file_types=["image", "video", "audio"])
gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()
|