import os

import clip
import faiss
import torch
import tarfile
import gradio as gr
import pandas as pd
from PIL import Image
from braceexpand import braceexpand

# Load the CLIP model and its preprocessing transform
checkpoint_path = "ViT-B/16"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)


def generate_caption(img):
    # Load caption bank
    df = pd.read_parquet("files/captions.parquet")
    caption_list = df["caption"].tolist()

    # Load index
    index = faiss.read_index("files/caption_bank.index")

    # Encode the image and query the caption bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get the nearest caption
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    # Start with a description of the image
    caption = caption_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return caption


def predict_brand(img):
    # Load brand bank
    df = pd.read_parquet("files/brands.parquet")
    brand_list = df["brands"].tolist()

    # Load index
    index = faiss.read_index("files/brand_bank.index")

    # Encode the image and query the brand bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get the nearest brand
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    brand = brand_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand


def estimate_price_and_usage(img):
    return "Estimated price: 50-100 SEK - Usage: Reuse - Saved CO2: 4 kg"


def retrieve(query):
    index_folder = "files/index"
    num_results = 3

    # Read image metadata
    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
    key_list = metadata_df["key"].tolist()

    # Load the index
    index = faiss.read_index(os.path.join(index_folder, "image.index"))

    # Encode the query: text goes through the text encoder, images through the image encoder
    if isinstance(query, str):
        print("Query is a string")
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    else:
        print("Query is an image")
        query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))

    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    d, i = index.search(query_features, num_results)
    print(f"Found {num_results} items with query '{query}'")
    indices = i[0]
    similarities = d[0]

    min_d = min(similarities)
    max_d = max(similarities)
    print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")

    # Uncomment to generate combined.tar, combining the image tars into a single tarfile
    """
    dataset_dir = "/fs/sefs1/circularfashion/sellpy/front_balanced"
    image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{00000..00010}.tar"))]
    with tarfile.open("files/combined.tar", "w") as combined_tar:
        for tar in image_tars:
            with tarfile.open(tar, "r") as tar_file:
                for member in tar_file.getmembers():
                    combined_tar.addfile(member, tar_file.extractfile(member))
    """

    # Open the tar once and copy each retrieved image out so it stays readable after the tar closes
    images = []
    with tarfile.open("files/combined.tar", "r") as tar_file:
        for idx in indices:
            image_name = key_list[idx]
            image = tar_file.extractfile(f"{image_name}.jpg")
            image = Image.open(image).copy()
            images.append(image)
    return images
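
# Hedged sketch (not called anywhere in this app): one plausible way the caption
# and brand banks queried above could be built. It assumes the banks store CLIP
# text features, L2-normalized to match the normalized query features, in an
# inner-product index (faiss.IndexFlatIP). The function name, batch size, and
# index path are illustrative, not from the original code.
def build_text_bank_index(texts, index_path, batch_size=256):
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            tokens = clip.tokenize(texts[start:start + batch_size], truncate=True).to(device)
            features = model.encode_text(tokens)
            features = features / features.norm(dim=-1, keepdim=True)
            chunks.append(features.cpu())
    features = torch.cat(chunks).numpy().astype("float32")
    # Inner product on unit vectors equals cosine similarity
    index = faiss.IndexFlatIP(features.shape[1])
    index.add(features)
    faiss.write_index(index, index_path)
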
theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")

with gr.Blocks(
    theme=theme,
    css="footer {visibility: hidden}",
) as demo:
    with gr.Tab("Captioning and Prediction"):
        with gr.Row(variant="compact"):
            input_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width=80):
                btn_generate_caption = gr.Button("Create Description").style(size="sm")
                generated_caption = gr.Textbox(label="Description", show_label=False)
        with gr.Row(variant="compact"):
            brand_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width=80):
                btn_predict_brand = gr.Button("Predict Brand").style(size="sm")
                predicted_brand = gr.Textbox(label="Brand", show_label=False)
        with gr.Column(variant="compact"):
            btn_estimate = gr.Button("Estimate Price, Reuse, and Saved CO2").style(size="sm")
            text_box = gr.Textbox(label="Estimates:", show_label=False)

    with gr.Tab("Image Retrieval"):
        with gr.Row(variant="compact"):
            with gr.Column():
                query_img = gr.Image(type="pil", label="Image Query")
                btn_image_query = gr.Button("Retrieve Garments").style(size="sm")
            img_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
        with gr.Row(variant="compact"):
            with gr.Column():
                query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
                btn_text_query = gr.Button("Retrieve Garments").style(size="sm")
            text_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)

    # Listeners
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)


if __name__ == "__main__":
    demo.launch(
        # inline=True
    )
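
# Hedged sketch, kept commented out like the combined.tar block in retrieve():
# one way the image index and metadata consumed by retrieve() might be rebuilt
# from files/combined.tar. The IndexFlatIP choice and the helper name are
# assumptions; only the file paths come from the code above.
"""
def build_image_index(tar_path="files/combined.tar", index_folder="files/index"):
    keys, chunks = [], []
    with tarfile.open(tar_path, "r") as tar_file, torch.no_grad():
        for member in tar_file.getmembers():
            if not member.name.endswith(".jpg"):
                continue
            image = Image.open(tar_file.extractfile(member)).convert("RGB")
            features = model.encode_image(preprocess(image).unsqueeze(0).to(device))
            features = features / features.norm(dim=-1, keepdim=True)
            chunks.append(features.cpu())
            keys.append(os.path.splitext(member.name)[0])
    features = torch.cat(chunks).numpy().astype("float32")
    index = faiss.IndexFlatIP(features.shape[1])
    index.add(features)
    faiss.write_index(index, os.path.join(index_folder, "image.index"))
    pd.DataFrame({"key": keys}).to_parquet(os.path.join(index_folder, "metadata.parquet"))
"""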