# product-catalog / appv2.py
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import gradio as gr
import numpy as np
import torch
import io

app = FastAPI()
# Load the pre-trained CLIP model and its processor (tokenizer + image preprocessing)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
model.eval()  # inference only

# Load the fashion product images dataset from Hugging Face
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
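# NOTE: calculate_similarities below reads precomputed "image_features" and
# "text_features" columns, which the raw dataset does not ship with. A minimal
# sketch of building them once at startup (slow for the full split; in practice
# you would cache the result to disk). "productDisplayName" is assumed to be
# the dataset's product title column.
def add_clip_features(batch):
    with torch.no_grad():
        image_inputs = processor(images=batch["image"], return_tensors="pt")
        text_inputs = processor(
            text=batch["productDisplayName"], return_tensors="pt",
            padding=True, truncation=True,
        )
        batch["image_features"] = model.get_image_features(**image_inputs).tolist()
        batch["text_features"] = model.get_text_features(**text_inputs).tolist()
    return batch

# Uncomment to build the feature columns (expensive):
# deepfashion_database = deepfashion_database.map(add_clip_features, batched=True, batch_size=64)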
def preprocess_image(image):
    # Manual ImageNet-style preprocessing. Note that CLIPProcessor already
    # resizes, crops, and normalizes with CLIP's own statistics, so this path
    # is only needed when bypassing the processor.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image).astype("uint8"), "RGB")
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image.convert("RGB")).unsqueeze(0)
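# Example (hypothetical file): preprocess_image(Image.open("shirt.jpg")) returns
# a tensor of shape [1, 3, 224, 224].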
def encode_text(text):
    # Tokenize the query text for CLIP's text encoder
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    # Resize, crop, and normalize the query image for CLIP's vision encoder
    return processor(images=[image], return_tensors="pt")
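# encode_text(...) yields a dict with "input_ids" and "attention_mask";
# encode_image(...) yields one with "pixel_values". Both are unpacked with **
# into the corresponding CLIPModel feature methods below.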
def calculate_similarities(query_image, query_text):
    with torch.no_grad():
        query_image_features = (
            model.get_image_features(**query_image) if query_image is not None else None
        )
        query_text_features = model.get_text_features(**query_text)
    cosine = torch.nn.CosineSimilarity(dim=-1)
    similarities = []
    for product in deepfashion_database:
        product_image_features = torch.Tensor(product["image_features"])
        product_text_features = torch.Tensor(product["text_features"])
        # Text-only queries fall back to a neutral image score of 1.0
        image_similarity = (
            cosine(query_image_features, product_image_features)
            if query_image_features is not None
            else torch.tensor(1.0)
        )
        text_similarity = cosine(query_text_features, product_text_features)
        # Combined score: product of the image and text cosine similarities
        similarities.append((image_similarity * text_similarity).item())
    return similarities
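# A vectorized alternative (sketch): stack all catalog features once at startup
# and score the whole catalog with two matrix operations instead of a Python loop:
#   all_image_feats = torch.tensor(deepfashion_database["image_features"])  # (N, D)
#   all_text_feats = torch.tensor(deepfashion_database["text_features"])    # (N, D)
#   scores = (torch.nn.functional.cosine_similarity(query_image_features, all_image_feats)
#             * torch.nn.functional.cosine_similarity(query_text_features, all_text_feats))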
def initial_query(image, text):
    # The image is optional; text-only queries are scored on text similarity alone
    query_image = encode_image(image) if image is not None else None
    query_text = encode_text(text)
    similarities = calculate_similarities(query_image, query_text)
    sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
    top_3_indices = sorted_indices[:3]
    top_3_products = [deepfashion_database[i] for i in top_3_indices]
    return top_3_products
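# generate_output_html is called below but never defined in this file; a minimal
# placeholder that lists the matched product names (assumes the dataset's
# "productDisplayName" column):
def generate_output_html(products):
    items = "".join(f"<li>{product['productDisplayName']}</li>" for product in products)
    return f"<ul>{items}</ul>"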
chat_history = []

def send_message(txt, btn):
    if btn is not None:
        # CLIPProcessor handles all image preprocessing, so the uploaded file
        # is opened as-is rather than run through preprocess_image
        image = Image.open(btn)
    else:
        image = None
    top_3_products = initial_query(image, txt)
    output_html = generate_output_html(top_3_products)
    # gr.Chatbot renders a list of (user message, bot message) pairs
    chat_history.append((txt, output_html))
    return chat_history
if __name__ == "__main__":
    # Launch the Gradio demo only when run directly, so the FastAPI endpoint
    # below can still be served on its own via uvicorn
    chatbot = gr.Chatbot([]).style(height=750)
    txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
    btn = gr.UploadButton("📁", file_types=["image"])
    gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()
@app.post("/initial_query/")
async def api_initial_query(text: str, image: UploadFile = File(None)):
    if image is not None:
        image_content = await image.read()
        query_image = Image.open(io.BytesIO(image_content))
    else:
        query_image = None
    top_3_products = initial_query(query_image, text)
    # PIL images and large embedding lists are not useful in a JSON response,
    # so return only the product metadata
    return {
        "top_3_products": [
            {k: v for k, v in product.items() if k not in ("image", "image_features", "text_features")}
            for product in top_3_products
        ]
    }