File size: 3,387 Bytes
eb2cd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import asyncio
import heapq
import io

import torch
from datasets import load_dataset
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor

# FastAPI application; HTTP routes are registered on it with @app.post(...).
app = FastAPI()

# Load the pre-trained CLIP model together with its processor. The processor is
# used both to tokenize query text (encode_text) and to prepare images (encode_image).
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face; its "train" split is
# the catalogue that queries are matched against.
# NOTE(review): despite the variable name, this is not the DeepFashion dataset.
# NOTE(review): calculate_similarities reads "image_features" / "text_features"
# from every row, and nothing here computes or adds those columns — confirm the
# dataset actually provides them.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]

def preprocess_image(image):
    """Convert an image into a normalized CLIP pixel tensor.

    Args:
        image: A ``PIL.Image.Image``, or an array-like object with ``.astype``
            (for example a numpy array) holding 0-255 pixel values in a shape PIL
            can map to an image mode, such as ``(H, W)``, ``(H, W, 3)`` or
            ``(H, W, 4)``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the shorter side resized to
        224 with bicubic interpolation, center-cropped to 224x224, scaled to
        [0, 1] by ``ToTensor`` and normalized with CLIP's per-channel mean/std.
    """
    if isinstance(image, Image.Image):
        # Both in-file callers hand over a PIL image from Image.open.
        pil_image = image
    else:
        # Let PIL infer the mode from the array shape; forcing "RGB" mis-reads
        # grayscale and RGBA arrays.
        pil_image = Image.fromarray(image.astype("uint8"))
    # Uploads may be palette, grayscale or RGBA; the vision model expects 3 channels.
    pil_image = pil_image.convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Mean/std of the openai/clip-vit-base-patch32 image processor (not the
        # ImageNet statistics), so pixels match what the CLIP model expects.
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
    return preprocess(pil_image).unsqueeze(0)

def encode_text(text):
    """Tokenize one query string for the CLIP text model.

    The string is wrapped in a single-item batch and tokenized with padding and
    truncation enabled; the processor's output mapping is returned as-is.
    """
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    """Build CLIP vision-model inputs for an image query.

    Args:
        image: ``None`` when the query has no image; a floating-point tensor that
            is already resized and normalized (e.g. the ``(1, 3, 224, 224)`` output
            of ``preprocess_image``, or a single ``(3, H, W)`` image); or anything
            ``processor`` accepts as an image, such as a PIL image.

    Returns:
        ``None`` when ``image`` is ``None``. Otherwise a mapping with a
        ``"pixel_values"`` batch tensor, ready for
        ``model.get_image_features(**inputs)``.
    """
    if image is None:
        # Callers pass None when nothing was uploaded (text-only query).
        return None
    if isinstance(image, torch.Tensor) and image.is_floating_point() and image.dim() in (3, 4):
        # Already preprocessed: sending it through the processor again would
        # rescale and normalize the pixels a second time.
        pixel_values = image.unsqueeze(0) if image.dim() == 3 else image
        return {"pixel_values": pixel_values}
    # padding/truncation are tokenizer options and do not apply to images.
    return processor(images=[image], return_tensors="pt")

def calculate_similarities(query_image, query_text):
    """Score every product in ``deepfashion_database`` against an image and/or text query.

    Args:
        query_image: Mapping of CLIP vision inputs (``pixel_values``) as returned by
            ``encode_image``, or ``None`` to score on text only.
        query_text: Mapping of CLIP text inputs as returned by ``encode_text``, or
            ``None`` to score on the image only.

    Returns:
        One score per product, in catalogue order. Each score is a one-element
        tensor: the product of the image and text cosine similarities when both
        queries are given, otherwise the single available cosine similarity.

    Raises:
        ValueError: If both ``query_image`` and ``query_text`` are ``None``.

    NOTE(review): every row is expected to carry precomputed ``"image_features"``
    and ``"text_features"`` embeddings in the same CLIP projection space; nothing
    in this file adds those columns to the loaded dataset — confirm they exist.
    """
    if query_image is None and query_text is None:
        raise ValueError("calculate_similarities needs an image query, a text query, or both")

    cosine = torch.nn.CosineSimilarity(dim=-1)
    # Inference only: without no_grad every query would build an autograd graph.
    with torch.no_grad():
        # The encoders return mappings of named tensors (pixel_values, input_ids,
        # attention_mask); they must be unpacked as keyword arguments.
        query_image_features = (
            model.get_image_features(**query_image) if query_image is not None else None
        )
        query_text_features = (
            model.get_text_features(**query_text) if query_text is not None else None
        )

    similarities = []
    for product in deepfashion_database:
        similarity_score = None
        if query_image_features is not None:
            product_image_features = torch.as_tensor(
                product["image_features"], dtype=query_image_features.dtype
            )
            similarity_score = cosine(query_image_features, product_image_features)
        if query_text_features is not None:
            product_text_features = torch.as_tensor(
                product["text_features"], dtype=query_text_features.dtype
            )
            text_similarity = cosine(query_text_features, product_text_features)
            # A missing modality contributes nothing, i.e. acts as a factor of 1.
            similarity_score = (
                text_similarity if similarity_score is None else similarity_score * text_similarity
            )
        similarities.append(similarity_score)

    return similarities

def initial_query(image, text, *, k=3):
    """Return the ``k`` catalogue rows that best match an image and/or text query.

    Args:
        image: Image input forwarded to ``encode_image`` (``None`` when the query
            has no image).
        text: Query string forwarded to ``encode_text``.
        k: Number of rows to return. Keyword-only; defaults to 3, the previously
            fixed count.

    Returns:
        Up to ``k`` rows of ``deepfashion_database``, best match first. Equal
        scores keep catalogue order.
    """
    query_image = encode_image(image)
    query_text = encode_text(text)

    similarities = calculate_similarities(query_image, query_text)
    # float() accepts plain numbers and single-element tensors alike, so ranking
    # compares Python floats instead of building a tensor per comparison.
    scores = [float(score) for score in similarities]
    # heapq.nlargest is equivalent to sorted(..., reverse=True)[:k] (same tie
    # order) without sorting the whole catalogue.
    top_indices = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])

    return [deepfashion_database[i] for i in top_indices]

def send_message(txt, btn):
    """Gradio callback: search the catalogue with the typed text and optional upload.

    Args:
        txt: Text from the textbox.
        btn: Value of the upload button, or ``None`` when nothing was uploaded.
            NOTE(review): depending on the Gradio version this is a temp-file
            wrapper exposing ``.name`` or a plain file path — confirm for the
            pinned version.

    Returns:
        The new Chatbot value: a one-item list holding the ``(user, bot)`` pair.
    """
    if btn is not None:
        # Open by path so this works for a temp-file wrapper (`.name`) and for a
        # plain path string alike.
        image = preprocess_image(Image.open(getattr(btn, "name", btn)))
    else:
        image = None
    top_3_products = initial_query(image, txt)
    # NOTE(review): generate_output_html is not defined anywhere in this file —
    # confirm where it is supposed to come from.
    output_html = generate_output_html(top_3_products)
    # A Gradio callback updates its output component by returning the new value;
    # gr.Chatbot has no append_message() method. Earlier turns are not kept
    # because this Interface carries no chat state.
    return [(txt, output_html)]


# Gradio chat UI wired to send_message.
# NOTE(review): `gr` is never imported in this file (presumably `import gradio as gr`),
# so these lines raise NameError as written.
# NOTE(review): Chatbot.style() exists in Gradio 3.x but was removed in Gradio 4.
chatbot = gr.Chatbot([]).style(height=750)
txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
# send_message opens every upload with PIL's Image.open, so only image files are accepted.
btn = gr.UploadButton("📁", file_types=["image"])

# NOTE(review): by default launch() blocks the calling thread when run as a script,
# so the FastAPI route declared after this line is not registered until Gradio stops.
gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()
def _json_safe_product(product):
    """Return a copy of one catalogue row keeping only JSON-encodable fields.

    Keeps values that are JSON scalars (str, int, float, bool, None) or flat
    lists/tuples of them, and drops everything else so FastAPI can encode the
    response. NOTE(review): the intent is to drop decoded image objects, assuming
    the dataset has an image column — confirm against the dataset schema.
    """
    scalar_types = (str, int, float, bool, type(None))
    return {
        key: value
        for key, value in product.items()
        if isinstance(value, scalar_types)
        or (isinstance(value, (list, tuple)) and all(isinstance(item, scalar_types) for item in value))
    }


@app.post("/initial_query/")
async def api_initial_query(text: str, image: UploadFile = File(None)):
    """Return the three catalogue rows that best match a text query and optional image.

    Args:
        text: Query text (a query-string parameter, since it is a plain ``str``).
        image: Optional uploaded image file. An upload with no bytes is treated
            the same as no upload.

    Returns:
        ``{"top_3_products": [...]}`` with each row reduced to its JSON-encodable
        fields.

    Raises:
        HTTPException: 400 when the uploaded bytes cannot be read as an image.
    """
    query_image = None
    if image is not None:
        image_content = await image.read()
        # Some clients send an empty file part when no file was chosen.
        if image_content:
            try:
                query_image = preprocess_image(Image.open(io.BytesIO(image_content)))
            except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
                raise HTTPException(
                    status_code=400, detail="Uploaded file is not a readable image"
                ) from err

    # initial_query runs CLIP inference and scans the whole catalogue synchronously;
    # run it in the default thread pool so the event loop keeps serving requests.
    loop = asyncio.get_running_loop()
    top_3_products = await loop.run_in_executor(None, initial_query, query_image, text)
    return {"top_3_products": [_json_safe_product(product) for product in top_3_products]}