"""Gradio demo that uses CLIP image/text features for garment captioning,
brand prediction, price and reuse estimation, and image/text retrieval."""

import os
import clip
import faiss
import torch
import tarfile
import gradio as gr
import pandas as pd
from PIL import Image
from braceexpand import braceexpand
from torchvision import transforms

# Load model
checkpoint_path = "ViT-B/16"
device = "cpu"
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)


def generate_caption(img):
    # Load caption bank
    df = pd.read_parquet("files/captions.parquet")
    caption_list = df["caption"].tolist()

    # Load index
    index = faiss.read_index("files/caption_bank.index")

    # Encode the image and query the caption bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get nearest captions
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    # Start with a description of the image
    caption = caption_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return caption


def predict_brand(img):
    # Load brand bank
    df = pd.read_parquet("files/brands.parquet")
    brand_list = df["brands"].tolist()

    # Load index
    index = faiss.read_index("files/brand_bank.index")

    # Encode the image and query the brand bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get nearest brands
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    brand = brand_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand
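
# NOTE: generate_caption() and predict_brand() read prebuilt artifacts
# (files/captions.parquet + files/caption_bank.index, and the brand
# equivalents). How those banks were built is not part of this script; the
# block below is only a sketch of one plausible build step, assuming an
# inner-product FAISS index over L2-normalised CLIP text embeddings, which is
# consistent with how queries are normalised above. Batching is omitted.
"""
df = pd.read_parquet("files/captions.parquet")
captions = df["caption"].tolist()

with torch.no_grad():
    tokens = clip.tokenize(captions, truncate=True).to(device)
    text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

index = faiss.IndexFlatIP(text_features.shape[-1])  # inner product == cosine on unit vectors
index.add(text_features.cpu().numpy().astype("float32"))
faiss.write_index(index, "files/caption_bank.index")
"""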
def estimate_price_and_usage(img):
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))

    # Estimate usage
    num_classes = 2
    probe = torch.nn.Linear(
        query_features.shape[-1], num_classes, dtype=torch.float16, bias=False
    )

    # Load weights for the linear layer as a tensor
    linear_data = torch.load("files/reuse_linear.pth", map_location="cpu")
    probe.weight.data = linear_data["weight"]

    # Do inference
    with torch.autocast("cpu"):
        probe.eval()
        probe = probe.to(device)
        output = probe(query_features)
        output = torch.softmax(output, dim=-1)
        # output = output.cpu().detach().numpy().astype("float32")
        reuse = output.argmax(axis=-1)[0]
    reuse_classes = ["Reuse", "Export"]

    # Estimate price
    num_classes = 4
    probe = torch.nn.Linear(
        query_features.shape[-1], num_classes, dtype=torch.float16, bias=False
    )

    # Load weights for the linear layer as a tensor
    linear_data = torch.load("files/price_linear.pth", map_location="cpu")
    probe.weight.data = linear_data["weight"]

    # Do inference
    with torch.autocast("cpu"):
        probe.eval()
        probe = probe.to(device)
        output = probe(query_features)
        output = torch.softmax(output, dim=-1)
        # output = output.cpu().detach().numpy().astype("float32")
        price = output.argmax(axis=-1)[0]
    price_classes = ["<50", "50-100", "100-150", ">150"]

    return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
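
# NOTE: estimate_price_and_usage() expects files/reuse_linear.pth and
# files/price_linear.pth to be dicts holding a "weight" tensor for a bias-free
# linear head over CLIP image features. The training of those probes is not
# shown in this file; the block below only sketches a checkpoint layout that
# the loading code above would accept (shapes and dtype mirror the probe
# definitions; the actual fitting step is omitted).
"""
embed_dim = model.visual.output_dim  # 512 for ViT-B/16

reuse_head = torch.nn.Linear(embed_dim, 2, bias=False, dtype=torch.float16)
# ... fit reuse_head on (CLIP image feature, reuse/export label) pairs ...
torch.save({"weight": reuse_head.weight.data}, "files/reuse_linear.pth")

price_head = torch.nn.Linear(embed_dim, 4, bias=False, dtype=torch.float16)
# ... fit price_head on (CLIP image feature, price bucket) pairs ...
torch.save({"weight": price_head.weight.data}, "files/price_linear.pth")
"""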
def retrieve(query):
    index_folder = "files/index"
    num_results = 3

    # Read image metadata
    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
    key_list = metadata_df["key"].tolist()

    # Load the index
    index = faiss.read_index(os.path.join(index_folder, "image.index"))

    # Encode the query
    if isinstance(query, str):
        print("Query is a string")
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    else:
        print("Query is an image")
        query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))

    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    d, i = index.search(query_features, num_results)
    print(f"Found {num_results} items with query '{query}'")
    indices = i[0]
    similarities = d[0]
    min_d = min(similarities)
    max_d = max(similarities)
    print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")

    # Uncomment to generate combined.tar, combining the image tars into a single tarfile
    """
    dataset_dir = "/fs/sefs1/circularfashion/wargon_webdataset/front_only"
    image_tars = [
        os.path.join(dataset_dir, file)
        for file in sorted(braceexpand("{0000..0028}.tar"))
    ]
    with tarfile.open("files/combined.tar", "w") as combined_tar:
        for tar in image_tars:
            with tarfile.open(tar, "r") as tar_file:
                for member in tar_file.getmembers():
                    combined_tar.addfile(member, tar_file.extractfile(member))
    """

    images = []
    for idx in indices:
        image_name = key_list[idx]
        with tarfile.open("files/combined.tar", "r") as tar_file:
            image = tar_file.extractfile(f"{image_name}.jpg")
            image = Image.open(image).copy()

        # Center crop the image
        width, height = image.size
        new_size = min(width, height)
        image = transforms.CenterCrop(new_size)(image)

        # Resize the image
        image = transforms.Resize((600, 600))(image)
        images.append(image)

    return images


theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")

with gr.Blocks(
    theme=theme,
    css="footer {visibility: hidden}",
) as demo:
    with gr.Tab("Captioning and Prediction"):
        with gr.Row(variant="compact"):
            input_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width=80):
                btn_generate_caption = gr.Button("Create Description", size="sm")
                generated_caption = gr.Textbox(label="Description", show_label=False)
        gr.Examples(
            examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
            fn=generate_caption,
            inputs=input_img,
            outputs=generated_caption,
        )

        with gr.Row(variant="compact"):
            brand_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width=80):
                btn_predict_brand = gr.Button("Predict Brand", size="sm")
                predicted_brand = gr.Textbox(label="Brand", show_label=False)
        gr.Examples(
            examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
            fn=predict_brand,
            inputs=brand_img,
            outputs=predicted_brand,
        )

        with gr.Column(variant="compact"):
            btn_estimate = gr.Button("Estimate Price and Reuse", size="sm")
            text_box = gr.Textbox(label="Estimates:", show_label=False)

    with gr.Tab("Image Retrieval"):
        with gr.Row(variant="compact"):
            with gr.Column():
                query_img = gr.Image(type="pil", label="Image Query")
                btn_image_query = gr.Button("Retrieve Garments", size="sm")
            img_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
            fn=retrieve,
            inputs=query_img,
            outputs=img_query_gallery,
        )

        with gr.Row(variant="compact"):
            with gr.Column():
                query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
                btn_text_query = gr.Button("Retrieve Garments", size="sm")
            text_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["A purple sweater", "A dress with a floral pattern"],
            fn=retrieve,
            inputs=query_text,
            outputs=text_query_gallery,
        )

    # Listeners
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)


if __name__ == "__main__":
    demo.launch(
        # inline=True
    )
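
# NOTE: the retrieval tab reads a prebuilt image index (files/index/image.index)
# and its metadata (files/index/metadata.parquet with a "key" column), which are
# not generated by this script. The block below sketches one plausible way to
# build both from files/combined.tar, assuming an inner-product index over
# L2-normalised CLIP image embeddings and tar members named "<key>.jpg",
# matching what retrieve() expects. Batching and error handling are omitted.
"""
keys, features = [], []
with tarfile.open("files/combined.tar", "r") as tar_file:
    for member in tar_file.getmembers():
        if not member.name.endswith(".jpg"):
            continue
        image = Image.open(tar_file.extractfile(member)).convert("RGB")
        with torch.no_grad():
            feat = model.encode_image(preprocess(image).unsqueeze(0).to(device))
        features.append(feat / feat.norm(dim=-1, keepdim=True))
        keys.append(os.path.splitext(member.name)[0])

feature_matrix = torch.cat(features).cpu().numpy().astype("float32")
index = faiss.IndexFlatIP(feature_matrix.shape[-1])
index.add(feature_matrix)
faiss.write_index(index, "files/index/image.index")
pd.DataFrame({"key": keys}).to_parquet("files/index/metadata.parquet")
"""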