import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch

# Load the pre-trained CLIP model and its processor. The processor handles
# both image preprocessing (resize, crop, CLIP's own normalization) and text
# tokenization, so no separate torchvision transform pipeline is needed.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face.
dataset = load_dataset("ashraq/fashion-product-images-small")

# NOTE (assumption): the raw dataset does not ship precomputed CLIP
# embeddings, so the "image_features"/"text_features" columns the search
# loop relies on are built here for a small subset. Raise SUBSET_SIZE or
# precompute offline for real use.
SUBSET_SIZE = 500

def add_clip_features(batch):
    # Embed each product's image and display name with CLIP, storing the
    # vectors as plain lists so the datasets library can serialize them.
    with torch.no_grad():
        image_inputs = processor(images=batch["image"], return_tensors="pt")
        text_inputs = processor(text=batch["productDisplayName"],
                                return_tensors="pt", padding=True, truncation=True)
        batch["image_features"] = model.get_image_features(**image_inputs).tolist()
        batch["text_features"] = model.get_text_features(**text_inputs).tolist()
    return batch

deepfashion_database = dataset["train"].select(range(SUBSET_SIZE)).map(
    add_clip_features, batched=True, batch_size=32
)

def encode_text(text):
    # Tokenize the query text for CLIP's text encoder.
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    # Preprocess the query image for CLIP's vision encoder
    # (padding/truncation only apply to text, so they are omitted here).
    return processor(images=[image], return_tensors="pt")

def calculate_similarities(query_image_features, query_text_features):
    # Score every product by the product of image-image and text-text cosine
    # similarities; a missing modality simply drops out of the score.
    cosine = torch.nn.CosineSimilarity(dim=-1)
    similarities = []
    for product in deepfashion_database:
        score = torch.tensor(1.0)
        if query_image_features is not None:
            score = score * cosine(query_image_features,
                                   torch.tensor(product["image_features"]))
        if query_text_features is not None:
            score = score * cosine(query_text_features,
                                   torch.tensor(product["text_features"]))
        similarities.append(score.item())
    return similarities

def initial_query(image, text):
    # Embed the query. get_image_features/get_text_features expect the
    # processor outputs unpacked as keyword arguments.
    with torch.no_grad():
        query_image_features = (model.get_image_features(**encode_image(image))
                                if image is not None else None)
        query_text_features = (model.get_text_features(**encode_text(text))
                               if text else None)
    similarities = calculate_similarities(query_image_features, query_text_features)
    sorted_indices = sorted(range(len(similarities)),
                            key=lambda i: similarities[i], reverse=True)
    top_3_products = [deepfashion_database[i] for i in sorted_indices[:3]]
    return top_3_products

def generate_output_html(products):
    # Minimal HTML rendering of the retrieved products (this helper was
    # called but never defined in the original listing).
    items = "".join(f"<li>{p['productDisplayName']}</li>" for p in products)
    return f"<ul>{items}</ul>"

def send_message(history, txt, file):
    # PIL accepts both file paths and file-like objects, which covers the
    # values UploadButton can hand over; None means nothing was uploaded.
    image = Image.open(file).convert("RGB") if file is not None else None
    top_3_products = initial_query(image, txt)
    output_html = generate_output_html(top_3_products)
    # gr.Chatbot has no append_message method; the handler returns the
    # updated history instead.
    return history + [(txt, output_html)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=[], height=750)
    txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image",
                     show_label=False)
    # Only images are usable by CLIP, so video/audio uploads are not accepted.
    # Assumption: the button's most recent upload is passed through as its
    # value when used as an event input.
    btn = gr.UploadButton("📁", file_types=["image"])
    txt.submit(send_message, inputs=[chatbot, txt, btn], outputs=chatbot)

demo.launch()
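
# --- Optional smoke test (a sketch; not part of the original listing) -------
# The retrieval pipeline can be exercised without the chat UI. Because
# demo.launch() above blocks, run these lines in a separate session or move
# them before the launch() call. "white summer dress" is an illustrative
# query, and the first dataset row stands in for an uploaded image:
#
#     sample_image = deepfashion_database[0]["image"]
#     for product in initial_query(sample_image, "white summer dress"):
#         print(product["productDisplayName"])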