from fastapi import FastAPI, File, UploadFile
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch
import io

app = FastAPI()

# Load the pre-trained CLIP model and its processor (tokenizer + image preprocessor).
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
model.eval()

# Load the fashion product images dataset from Hugging Face. It ships raw images and
# product names (not precomputed CLIP features), so we embed a subset at startup;
# enlarge the subset or precompute offline for real use.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"].select(range(500))

def encode_text(text):
    # Tokenize the text query for CLIP's text encoder.
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    # CLIPProcessor applies CLIP's own resize/crop/normalization, so a separate
    # torchvision pipeline (which would use ImageNet statistics) is both redundant
    # and incorrectly normalized for this model.
    return processor(images=[image], return_tensors="pt")

@torch.no_grad()
def embed_database():
    # Precompute CLIP embeddings for every catalog item from its image and
    # display name, so each request only embeds the query.
    image_features, text_features = [], []
    for product in deepfashion_database:
        img_inputs = encode_image(product["image"].convert("RGB"))
        txt_inputs = encode_text(product["productDisplayName"])
        image_features.append(model.get_image_features(**img_inputs))
        text_features.append(model.get_text_features(**txt_inputs))
    return torch.cat(image_features), torch.cat(text_features)

product_image_features, product_text_features = embed_database()

@torch.no_grad()
def calculate_similarities(query_image, query_text):
    # get_image_features / get_text_features expect the tensors inside the
    # processor's output dict, hence the ** unpacking.
    text_similarity = torch.nn.functional.cosine_similarity(
        model.get_text_features(**query_text), product_text_features, dim=-1
    )
    if query_image is None:
        return text_similarity
    image_similarity = torch.nn.functional.cosine_similarity(
        model.get_image_features(**query_image), product_image_features, dim=-1
    )
    # Combine the two modalities by multiplying their similarity scores.
    return image_similarity * text_similarity

def initial_query(image, text):
    query_image = encode_image(image) if image is not None else None
    query_text = encode_text(text)
    similarities = calculate_similarities(query_image, query_text)
    top_3_indices = torch.topk(similarities, k=3).indices.tolist()
    # Return only JSON-serializable fields; the raw PIL image cannot go in the response.
    return [
        {
            "id": deepfashion_database[i]["id"],
            "productDisplayName": deepfashion_database[i]["productDisplayName"],
            "score": similarities[i].item(),
        }
        for i in top_3_indices
    ]

@app.post("/initial_query/")
async def api_initial_query(text: str, image: UploadFile = File(None)):
    # `text` arrives as a query parameter; the optional image as multipart form data.
    if image is not None:
        image_content = await image.read()
        pil_image = Image.open(io.BytesIO(image_content)).convert("RGB")
    else:
        pil_image = None
    return {"top_3_products": initial_query(pil_image, text)}
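
# A minimal sketch of how to exercise the endpoint, assuming this file is saved as
# main.py, the server runs locally on uvicorn's default port, and "query.jpg" is a
# hypothetical image on disk (both the filename and the query text are placeholders):
#
#   uvicorn main:app --reload
#
#   import requests
#   with open("query.jpg", "rb") as f:
#       resp = requests.post(
#           "http://127.0.0.1:8000/initial_query/",
#           params={"text": "red summer dress"},
#           files={"image": f},
#       )
#   print(resp.json())
#
# Omitting the "files" argument sends a text-only query, which the endpoint also accepts.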