File size: 2,761 Bytes
e0f8afd
47d7fa1
d989c3a
47d7fa1
506e7cd
47d7fa1
e0f8afd
 
 
47d7fa1
506e7cd
47d7fa1
 
 
 
506e7cd
 
 
 
4be0554
 
 
 
 
 
 
 
 
 
41f1f7d
78fd827
41f1f7d
a0c768e
41f1f7d
78fd827
41f1f7d
a0c768e
fbbda08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f8afd
 
 
 
 
 
 
 
 
eb2cd48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import heapq
import io

import torch
from datasets import load_dataset
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel

# FastAPI application; the search endpoint is registered on it via @app.post.
app = FastAPI()

# Load the pre-trained CLIP model and its processor (the processor bundles the
# text tokenizer and the image preprocessing pipeline).
# NOTE(review): these run at import time and may download weights on first use.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face; only the "train"
# split is used as the searchable product catalogue.
# NOTE(review): calculate_similarities reads per-row "image_features" and
# "text_features" columns — confirm this dataset provides them, otherwise they
# must be precomputed with the CLIP model above before serving queries.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]

# CLIP's per-channel normalisation statistics (the values CLIPImageProcessor
# uses for openai/clip-vit-base-patch32), not ImageNet's.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def preprocess_image(image):
    """Convert an image into a normalised CLIP input tensor.

    Args:
        image: A PIL image in any mode, or an array-like exposing ``astype``
            that holds RGB pixel values (presumably H x W x 3 — confirm with
            the caller).

    Returns:
        A float tensor of shape (1, 3, 224, 224): shorter side resized to 224
        (bicubic, as CLIP was trained), centre-cropped to 224 x 224, scaled to
        [0, 1] and normalised with CLIP's mean/std.
    """
    if isinstance(image, Image.Image):
        # PIL images have no .astype; just make sure there are 3 channels
        # (uploads may be RGBA, palette or greyscale).
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
    ])
    # Add the leading batch dimension expected by the model.
    return preprocess(pil_image).unsqueeze(0)

def encode_text(text):
    """Tokenize one query string with the CLIP processor.

    The string is wrapped in a single-element batch, tokenized with padding
    and truncation enabled, and returned as the processor's PyTorch-tensor
    encoding (``input_ids`` / ``attention_mask``).
    """
    batch = [text]
    return processor(
        text=batch,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

def encode_image(image):
    """Run the CLIP processor's image pipeline on a single image.

    Args:
        image: One image the CLIP processor accepts (e.g. an RGB PIL image).

    Returns:
        The processor's output for a one-image batch, holding ``pixel_values``
        as a PyTorch tensor.
    """
    # padding/truncation are tokenizer options; they have no meaning for the
    # image pipeline, so they are not forwarded here.
    return processor(images=[image], return_tensors="pt")

def calculate_similarities(query_image, query_text):
    """Score every catalogue product against an image and/or text query.

    A product's score is the cosine similarity between the query image
    embedding and the product's stored ``image_features``, multiplied by the
    cosine similarity between the query text embedding and the product's
    stored ``text_features``. When one query is ``None`` only the other
    modality contributes.

    Args:
        query_image: Output of ``encode_image`` (a mapping with
            ``pixel_values``), a ``pixel_values`` tensor, or ``None``.
        query_text: Output of ``encode_text`` (a mapping with ``input_ids`` /
            ``attention_mask``), an ``input_ids`` tensor, or ``None``.

    Returns:
        A list of float scores, one per row of ``deepfashion_database``, in
        row order.

    Raises:
        ValueError: If both queries are ``None``.
        KeyError: If the database lacks a required precomputed feature column.
    """
    if query_image is None and query_text is None:
        raise ValueError("At least one of query_image or query_text is required")

    num_products = len(deepfashion_database)
    if num_products == 0:
        return []

    scores = torch.ones(num_products)
    # Inference only: don't build an autograd graph for the query encoders.
    with torch.no_grad():
        if query_image is not None:
            if isinstance(query_image, torch.Tensor):
                image_features = model.get_image_features(pixel_values=query_image)
            else:
                # Processor outputs are mappings; the model wants keyword args.
                image_features = model.get_image_features(**query_image)
            scores = scores * _similarity_to_column(image_features, "image_features")

        if query_text is not None:
            if isinstance(query_text, torch.Tensor):
                text_features = model.get_text_features(input_ids=query_text)
            else:
                text_features = model.get_text_features(**query_text)
            scores = scores * _similarity_to_column(text_features, "text_features")

    return scores.tolist()


def _similarity_to_column(query_features, column):
    """Cosine similarity of ``query_features`` against every row of ``column``."""
    if column not in deepfashion_database.column_names:
        raise KeyError(
            f"deepfashion_database has no {column!r} column; precompute CLIP "
            f"embeddings for the catalogue before querying"
        )
    # Reading just this column avoids decoding every product image per query;
    # reshape tolerates rows stored either flat (D,) or nested (1, D).
    product_features = torch.tensor(
        list(deepfashion_database[column]), dtype=torch.float32
    ).reshape(len(deepfashion_database), -1)
    # query_features is (1, D); broadcasting yields one score per product.
    return torch.nn.functional.cosine_similarity(query_features, product_features, dim=-1)

def initial_query(image, text, top_k=3):
    """Return the catalogue products that best match an image/text query.

    Args:
        image: Query image accepted by ``encode_image`` (e.g. an RGB PIL
            image), or ``None`` to search by text alone.
        text: Query string, or ``None`` to search by image alone.
        top_k: Number of products to return; defaults to 3.

    Returns:
        Up to ``top_k`` rows of ``deepfashion_database``, best match first;
        products with equal scores keep their catalogue order.
    """
    # Only encode the modalities that were supplied; the processor cannot
    # encode a missing image.
    query_image = encode_image(image) if image is not None else None
    query_text = encode_text(text) if text is not None else None

    similarities = calculate_similarities(query_image, query_text)
    # Same result as sorted(..., reverse=True)[:top_k] without a full sort.
    best_indices = heapq.nlargest(
        top_k, range(len(similarities)), key=similarities.__getitem__
    )
    return [deepfashion_database[i] for i in best_indices]

def _json_safe_product(product):
    """Return a copy of a catalogue row with PIL image values removed.

    PIL images are not JSON serializable, so FastAPI cannot encode rows that
    still carry them.
    """
    return {
        key: value
        for key, value in product.items()
        if not isinstance(value, Image.Image)
    }


@app.post("/initial_query/")
async def api_initial_query(text: str, image: UploadFile = File(None)):
    """Search the catalogue with a text query and an optional uploaded image.

    Returns:
        ``{"top_3_products": [...]}`` — the best-matching catalogue rows, with
        image-valued fields omitted.

    Raises:
        HTTPException: 400 if the uploaded file cannot be decoded as an image.
    """
    query_image = None
    if image is not None:
        image_content = await image.read()
        try:
            # Give the CLIP processor a plain RGB PIL image; it applies CLIP's
            # own resize/normalisation, so no manual preprocessing is needed.
            query_image = Image.open(io.BytesIO(image_content)).convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a valid image"
            ) from err
    # Model inference is synchronous and heavy; keep it off the event loop.
    top_3_products = await run_in_threadpool(initial_query, query_image, text)
    return {"top_3_products": [_json_safe_product(p) for p in top_3_products]}