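"""Gradio demo: multimodal fashion product search with CLIP.

A text query (and optionally a reference image) is embedded with CLIP and
compared against precomputed embeddings of the
ashraq/fashion-product-images-small catalog; the top matches are shown in a
chat-style interface.
"""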
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset

# Load the pre-trained CLIP model and its processor (tokenizer + image transforms)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face.
# NOTE: the similarity search below reads "image_features" and
# "text_features" columns that the raw dataset does not ship with,
# so they are built once in the step sketched after this load.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
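
# --- Sketch (assumption): precompute per-product CLIP embeddings. ---
# The original script reads "image_features"/"text_features" per product but
# never builds them; one plausible offline step is a batched map over the
# catalog. "image" and "productDisplayName" are this dataset's image and
# product-title columns.
def add_clip_features(batch):
    with torch.no_grad():
        image_inputs = processor(images=batch["image"], return_tensors="pt")
        text_inputs = processor(text=batch["productDisplayName"],
                                return_tensors="pt", padding=True, truncation=True)
        batch["image_features"] = model.get_image_features(**image_inputs).tolist()
        batch["text_features"] = model.get_text_features(**text_inputs).tolist()
    return batch

# Slow on CPU; the datasets library caches the result after the first run.
deepfashion_database = deepfashion_database.map(add_clip_features, batched=True, batch_size=32)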

def preprocess_image(image):
    # CLIPProcessor handles resizing, cropping, and normalization itself
    # (the manual torchvision pipeline this replaces used ImageNet mean/std,
    # which are not CLIP's normalization constants), so only ensure a
    # 3-channel PIL image here.
    return image.convert("RGB")

def encode_text(text):
    # Tokenize the query text for CLIP's text encoder
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    return inputs

def encode_image(image):
    # padding/truncation are tokenizer options and do not apply to images
    inputs = processor(images=[image], return_tensors="pt")
    return inputs

def calculate_similarities(query_image, query_text):
    with torch.no_grad():
        # get_*_features expect the processor outputs as keyword arguments
        query_text_features = model.get_text_features(**query_text)
        query_image_features = (model.get_image_features(**query_image)
                                if query_image is not None else None)

    similarities = []
    # A linear scan over the catalog is simple but slow; a vectorized
    # variant is sketched after this function.
    for product in deepfashion_database:
        product_image_features = torch.tensor(product["image_features"])
        product_text_features = torch.tensor(product["text_features"])

        text_similarity = torch.nn.functional.cosine_similarity(
            query_text_features, product_text_features, dim=-1)

        if query_image_features is None:
            # Text-only query: fall back to the text score alone
            similarity_score = text_similarity
        else:
            image_similarity = torch.nn.functional.cosine_similarity(
                query_image_features, product_image_features, dim=-1)
            similarity_score = image_similarity * text_similarity

        similarities.append(similarity_score.item())

    return similarities
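
# --- Sketch (assumption): vectorized scoring. ---
# Stacking the precomputed embeddings into (N, D) tensors once and scoring
# with matrix products avoids the per-row Python loop above. For
# unit-normalized vectors, cosine similarity is a plain dot product.
def calculate_similarities_vectorized(query_image_features, query_text_features,
                                      image_bank, text_bank):
    q_img = torch.nn.functional.normalize(query_image_features, dim=-1).squeeze(0)
    q_txt = torch.nn.functional.normalize(query_text_features, dim=-1).squeeze(0)
    img_bank = torch.nn.functional.normalize(image_bank, dim=-1)
    txt_bank = torch.nn.functional.normalize(text_bank, dim=-1)
    return (img_bank @ q_img) * (txt_bank @ q_txt)  # shape (N,)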

def initial_query(image, text):
    # The image is optional; text-only queries are scored on text alone
    query_image = encode_image(image) if image is not None else None
    query_text = encode_text(text)

    similarities = calculate_similarities(query_image, query_text)
    sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
    top_3_indices = sorted_indices[:3]

    top_3_products = [deepfashion_database[i] for i in top_3_indices]
    return top_3_products

def generate_output_html(products):
    # Called but never defined in the original script; a minimal version that
    # lists the matched product titles ("productDisplayName" is the dataset's
    # product-title column).
    items = "".join(f"<li>{p['productDisplayName']}</li>" for p in products)
    return f"<ul>{items}</ul>"

def send_message(txt, image):
    # gr.Image(type="pil") hands us a PIL image, or None if nothing was uploaded
    image = preprocess_image(image) if image is not None else None
    top_3_products = initial_query(image, txt)
    output_html = generate_output_html(top_3_products)
    # A gr.Chatbot value is a list of (user, assistant) message pairs;
    # the chatbot.append_message calls in the original do not exist in Gradio
    return [(txt, output_html)]


# .style() was removed from Gradio, so height is set in the constructor.
# gr.UploadButton does not fit gr.Interface well, so an optional gr.Image
# input stands in for it here.
chatbot = gr.Chatbot(height=750)
txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
img = gr.Image(type="pil", label="Optional reference image")

gr.Interface(send_message, inputs=[txt, img], outputs=chatbot).launch()