from fastapi import FastAPI, File, Form, UploadFile
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import gradio as gr
import torch
import io

app = FastAPI()

# Load the pre-trained CLIP model and its processor
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
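
# NOTE: calculate_similarities below reads precomputed "image_features" and
# "text_features" columns from each row. The published dataset does not ship
# these, so they would have to be added beforehand. A minimal sketch, assuming
# the dataset's "productDisplayName" column as the text side (run once offline):
#
#   def add_clip_features(batch):
#       img_inputs = processor(images=batch["image"], return_tensors="pt")
#       txt_inputs = processor(text=batch["productDisplayName"], return_tensors="pt",
#                              padding=True, truncation=True)
#       with torch.no_grad():
#           batch["image_features"] = model.get_image_features(**img_inputs).tolist()
#           batch["text_features"] = model.get_text_features(**txt_inputs).tolist()
#       return batch
#
#   deepfashion_database = deepfashion_database.map(add_clip_features, batched=True, batch_size=64)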

def preprocess_image(image):
    # Manual preprocessing to a normalized tensor. Note that the CLIP
    # processor already performs its own resizing and normalization, so the
    # query paths below pass PIL images straight to encode_image instead.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(image.astype("uint8"), "RGB")
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(pil_image).unsqueeze(0)

def encode_text(text):
    # Returns a BatchFeature whose tensors are unpacked with ** below
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    # padding/truncation are tokenizer options and do not apply to images
    return processor(images=[image], return_tensors="pt")

def calculate_similarities(query_image, query_text):
    # Encode the query; the image is optional, so text-only queries work too
    with torch.no_grad():
        query_image_features = (
            model.get_image_features(**query_image) if query_image is not None else None
        )
        query_text_features = model.get_text_features(**query_text)
    cosine = torch.nn.CosineSimilarity(dim=-1)
    similarities = []
    for product in deepfashion_database:
        product_image_features = torch.Tensor(product["image_features"])
        product_text_features = torch.Tensor(product["text_features"])
        text_similarity = cosine(query_text_features, product_text_features)
        if query_image_features is not None:
            # Multiplying the two cosine similarities rewards products that
            # match both the query image and the query text
            image_similarity = cosine(query_image_features, product_image_features)
            similarities.append(image_similarity * text_similarity)
        else:
            similarities.append(text_similarity)
    return similarities

def initial_query(image, text):
    query_image = encode_image(image) if image is not None else None
    query_text = encode_text(text)
    similarities = calculate_similarities(query_image, query_text)
    sorted_indices = sorted(range(len(similarities)), key=lambda i: float(similarities[i]), reverse=True)
    top_3_indices = sorted_indices[:3]
    return [deepfashion_database[i] for i in top_3_indices]
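
# generate_output_html is called below but was missing from the original
# listing; this is a minimal stand-in, assuming each dataset row has a
# "productDisplayName" field (which ashraq/fashion-product-images-small provides).
def generate_output_html(products):
    items = "".join(f"<li>{p.get('productDisplayName', 'unknown')}</li>" for p in products)
    return f"<ul>{items}</ul>"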

def send_message(txt, btn):
    if btn is not None:
        # Pass the PIL image through as-is; encode_image's processor
        # handles resizing and normalization
        image = Image.open(btn)
    else:
        image = None
    top_3_products = initial_query(image, txt)
    output_html = generate_output_html(top_3_products)
    # gr.Chatbot has no append_message method; with gr.Interface the
    # function returns the chat history to display
    return [(txt, output_html)]

chatbot = gr.Chatbot(height=750)
txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
btn = gr.UploadButton("📁", file_types=["image"])
demo = gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot)
# Mount the UI on the FastAPI app rather than calling launch(), which would
# block at import time and keep the API route below from being served
app = gr.mount_gradio_app(app, demo, path="/gradio")

@app.post("/initial_query")  # route path is an assumption; the original listing had no decorator
async def api_initial_query(text: str = Form(...), image: UploadFile = File(None)):
    if image is not None:
        image_content = await image.read()
        # The CLIP processor in encode_image handles preprocessing
        image = Image.open(io.BytesIO(image_content))
    top_3_products = initial_query(image, text)
    # Dataset rows carry PIL images, which are not JSON-serializable;
    # return only the metadata fields
    results = [{k: v for k, v in p.items() if k != "image"} for p in top_3_products]
    return {"top_3_products": results}
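
# Example client call (hypothetical: assumes the service runs on
# localhost:8000 with the /initial_query route registered above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/initial_query",
#       data={"text": "red summer dress"},
#       files={"image": open("dress.jpg", "rb")},
#   )
#   print(resp.json()["top_3_products"])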