import os
import clip
import faiss
import torch
import tarfile
import gradio as gr
import pandas as pd
from PIL import Image
from braceexpand import braceexpand
from torchvision import transforms


# Load CLIP once at startup; the demo runs entirely on CPU.
checkpoint_path = "ViT-B/16"
device = "cpu"
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)


def generate_caption(img):
    # Nearest-neighbour captioning: embed the image with CLIP and return the
    # closest caption from the pre-computed caption bank.
    df = pd.read_parquet("files/captions.parquet")
    caption_list = df["caption"].tolist()

    index = faiss.read_index("files/caption_bank.index")

    # Encode and L2-normalize the query image, then hand it to FAISS as float32.
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Retrieve the single closest caption.
    d, i = index.search(query_features, 1)
    idx = i[0][0]
    distance = d[0][0]
    caption = caption_list[idx]

    print(f"Index: {idx} - Distance: {distance:.2f}")
    return caption

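
# --- Hedged sketch (not part of the original app) ---------------------------
# caption_bank.index is assumed here to be a FAISS inner-product index over
# L2-normalized CLIP text embeddings of the captions in captions.parquet.
# IndexFlatIP and the single-batch tokenization below are assumptions for
# illustration; a real caption bank would likely be encoded in batches.
def build_text_bank(texts, index_path):
    with torch.no_grad():
        tokens = clip.tokenize(texts, truncate=True).to(device)
        feats = model.encode_text(tokens)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    feats = feats.cpu().numpy().astype("float32")
    bank = faiss.IndexFlatIP(feats.shape[-1])  # inner product == cosine on unit vectors
    bank.add(feats)
    faiss.write_index(bank, index_path)

# Example under the assumed file layout:
# build_text_bank(pd.read_parquet("files/captions.parquet")["caption"].tolist(),
#                 "files/caption_bank.index")

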
def predict_brand(img):
    # Same nearest-neighbour lookup as generate_caption, but against the
    # brand bank.
    df = pd.read_parquet("files/brands.parquet")
    brand_list = df["brands"].tolist()

    index = faiss.read_index("files/brand_bank.index")

    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    d, i = index.search(query_features, 1)
    idx = i[0][0]
    distance = d[0][0]
    brand = brand_list[idx]

    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand

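
# --- Hedged sketch (not part of the original app) ---------------------------
# brand_bank.index is assumed to be built like the caption bank, from the
# brand names themselves via build_text_bank above; it could equally be an
# index of image embeddings of garments with known brands. Under the
# text-bank assumption:
# build_text_bank(pd.read_parquet("files/brands.parquet")["brands"].tolist(),
#                 "files/brand_bank.index")

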
def estimate_price_and_usage(img):
    # Two small linear probes on top of the CLIP image features: a binary
    # reuse/export head and a four-way price head. Note the features are not
    # normalized here, unlike in the retrieval functions.
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))

    def run_probe(weight_path, num_classes):
        # Bias-free float16 linear head whose weight is loaded from disk.
        probe = torch.nn.Linear(
            query_features.shape[-1],
            num_classes,
            dtype=torch.float16,
            bias=False,
        )
        linear_data = torch.load(weight_path, map_location="cpu")
        probe.weight.data = linear_data["weight"]

        with torch.autocast("cpu"):
            probe.eval()
            probe = probe.to(device)
            output = probe(query_features)
            output = torch.softmax(output, dim=-1)
        return output.argmax(dim=-1)[0]

    reuse = run_probe("files/reuse_linear.pth", num_classes=2)
    reuse_classes = ["Reuse", "Export"]

    price = run_probe("files/price_linear.pth", num_classes=4)
    price_classes = ["<50", "50-100", "100-150", ">150"]

    return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"

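
# --- Hedged sketch (not part of the original app) ---------------------------
# reuse_linear.pth and price_linear.pth are assumed to hold the weight of a
# bias-free linear probe trained on CLIP image features. A minimal training
# sketch under that assumption; `features` and `labels` are hypothetical
# placeholders for pre-computed embeddings and integer class labels.
def train_probe(features, labels, num_classes, out_path, steps=100, lr=1e-3):
    probe = torch.nn.Linear(features.shape[-1], num_classes, bias=False)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss_fn(probe(features.float()), labels).backward()
        optimizer.step()
    # Stored as float16 to match the probe dtype expected at inference time.
    torch.save({"weight": probe.weight.data.half()}, out_path)

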
def retrieve(query):
    # Text-to-image or image-to-image retrieval against the pre-built FAISS
    # index of garment photos.
    index_folder = "files/index"
    num_results = 3

    # The metadata parquet maps FAISS row ids back to image keys in the tar.
    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
    key_list = metadata_df["key"].tolist()

    index = faiss.read_index(os.path.join(index_folder, "image.index"))

    if isinstance(query, str):
        print("Query is a string")
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    else:
        print("Query is an image")
        query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))

    # Normalize either kind of query embedding before searching.
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    d, i = index.search(query_features, num_results)
    print(f"Found {num_results} items with query '{query}'")
    indices = i[0]
    similarities = d[0]

    min_d = min(similarities)
    max_d = max(similarities)
    print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")

    # One-time helper that produced files/combined.tar from the webdataset
    # shards; kept here for reference.
    #
    # dataset_dir = "/fs/sefs1/circularfashion/wargon_webdataset/front_only"
    # image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{0000..0028}.tar"))]
    # with tarfile.open("files/combined.tar", "w") as combined_tar:
    #     for tar in image_tars:
    #         with tarfile.open(tar, "r") as tar_file:
    #             for member in tar_file.getmembers():
    #                 combined_tar.addfile(member, tar_file.extractfile(member))

    images = []
    for idx in indices:
        image_name = key_list[idx]
        with tarfile.open("files/combined.tar", "r") as tar_file:
            image = tar_file.extractfile(f"{image_name}.jpg")
            # Copy so the image survives the tar file being closed.
            image = Image.open(image).copy()

        # Center-crop to a square, then resize for display in the gallery.
        width, height = image.size
        new_size = min(width, height)
        image = transforms.CenterCrop(new_size)(image)
        image = transforms.Resize((600, 600))(image)
        images.append(image)

    return images

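
# --- Hedged sketch (not part of the original app) ---------------------------
# files/index/image.index and metadata.parquet are assumed to pair one
# L2-normalized CLIP image embedding with one tar "key" per garment photo.
# A minimal indexing sketch over the combined tar; the paths and the use of
# IndexFlatIP are assumptions, not confirmed by this script.
def build_image_index(tar_path="files/combined.tar", index_folder="files/index"):
    keys, feats = [], []
    with tarfile.open(tar_path, "r") as tar_file:
        for member in tar_file.getmembers():
            if not member.name.endswith(".jpg"):
                continue
            image = Image.open(tar_file.extractfile(member)).convert("RGB")
            with torch.no_grad():
                f = model.encode_image(preprocess(image).unsqueeze(0).to(device))
            feats.append(f / f.norm(dim=-1, keepdim=True))
            keys.append(os.path.splitext(member.name)[0])
    feats = torch.cat(feats).cpu().numpy().astype("float32")
    bank = faiss.IndexFlatIP(feats.shape[-1])
    bank.add(feats)
    faiss.write_index(bank, os.path.join(index_folder, "image.index"))
    pd.DataFrame({"key": keys}).to_parquet(os.path.join(index_folder, "metadata.parquet"))

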
theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")

with gr.Blocks(
    theme=theme,
    css="footer {visibility: hidden}",
) as demo:
with gr.Tab("Captioning and Prediction"):
|
|
with gr.Row(variant="compact"):
|
|
input_img = gr.Image(type="pil", show_label=False)
|
|
with gr.Column(min_width="80"):
|
|
btn_generate_caption = gr.Button("Create Description").style(size="sm")
|
|
generated_caption = gr.Textbox(label="Description", show_label=False)
|
|
gr.Examples(
|
|
examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
|
|
fn=generate_caption,
|
|
inputs=input_img,
|
|
outputs=generated_caption
|
|
)
|
|
|
|
with gr.Row(variant="compact"):
|
|
brand_img = gr.Image(type="pil", show_label=False)
|
|
with gr.Column(min_width="80"):
|
|
btn_predict_brand = gr.Button("Predict Brand").style(size="sm")
|
|
predicted_brand = gr.Textbox(label="Brand", show_label=False)
|
|
gr.Examples(
|
|
examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
|
|
fn=predict_brand,
|
|
inputs=brand_img,
|
|
outputs=predicted_brand
|
|
)
|
|
|
|
with gr.Column(variant="compact"):
|
|
btn_estimate = gr.Button("Estimate Price and Reuse").style(size="sm")
|
|
text_box = gr.Textbox(label="Estimates:", show_label=False)
|
|
|
|
with gr.Tab("Image Retrieval"):
|
|
with gr.Row(variant="compact"):
|
|
with gr.Column():
|
|
query_img = gr.Image(type="pil", label="Image Query")
|
|
btn_image_query = gr.Button("Retrieve Garments").style(size="sm")
|
|
img_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
|
|
gr.Examples(
|
|
examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
|
|
fn=retrieve,
|
|
inputs=query_img,
|
|
outputs=img_query_gallery
|
|
)
|
|
|
|
with gr.Row(variant="compact"):
|
|
with gr.Column():
|
|
query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
|
|
btn_text_query = gr.Button("Retrieve Garments").style(size="sm")
|
|
text_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
|
|
gr.Examples(
|
|
examples=["A purple sweater", "A dress with a floral pattern"],
|
|
fn=retrieve,
|
|
inputs=query_text,
|
|
outputs=text_query_gallery
|
|
)
|
|
|
|
|
|
    # Wire each button to its handler.
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)


if __name__ == "__main__":
    demo.launch()