|
import functools
import os
import tarfile

import clip
import faiss
import gradio as gr
import pandas as pd
import torch
from braceexpand import braceexpand
from PIL import Image
from torchvision import transforms
|
|
|
|
|
|
|
# CLIP ViT-B/16 backbone shared by every handler in this app, run on CPU.
checkpoint_path = "ViT-B/16"
device = "cpu"

# jit=False loads the eager (non-TorchScript) model; `preprocess` is the image
# transform returned alongside it and is applied to every PIL input below.
model, preprocess = clip.load(
    checkpoint_path,
    device=device,
    jit=False,
)
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_caption_bank():
    """Load the caption texts and their FAISS index once; later calls reuse them."""
    captions = pd.read_parquet("files/captions.parquet")["caption"].tolist()
    index = faiss.read_index("files/caption_bank.index")
    return captions, index


def generate_caption(img):
    """Return the stored caption whose embedding best matches ``img``.

    The image is embedded with the CLIP model, L2-normalised and searched
    (top-1) against ``files/caption_bank.index``. The caption at the matching
    row of ``files/captions.parquet`` is returned.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicked the button without uploading an image.

    Returns:
        The best-matching caption string.

    Raises:
        gr.Error: if no image was given or the index produced no match.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    caption_list, index = _load_caption_bank()

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy().astype("float32")

    distances, indices = index.search(query_features, 1)
    idx = int(indices[0][0])
    distance = float(distances[0][0])

    # FAISS reports label -1 when it has no neighbour; without this check it
    # would silently index the last caption of the list.
    if idx < 0:
        raise gr.Error("No matching caption found.")

    print(f"Index: {idx} - Distance: {distance:.2f}")
    return caption_list[idx]
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_brand_bank():
    """Load the brand names and their FAISS index once; later calls reuse them."""
    brands = pd.read_parquet("files/brands.parquet")["brands"].tolist()
    index = faiss.read_index("files/brand_bank.index")
    return brands, index


def predict_brand(img):
    """Return the brand whose stored embedding best matches ``img``.

    The image is embedded with the CLIP model, L2-normalised and searched
    (top-1) against ``files/brand_bank.index``. The entry at the matching row
    of the ``brands`` column of ``files/brands.parquet`` is returned.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicked the button without uploading an image.

    Returns:
        The best-matching brand name.

    Raises:
        gr.Error: if no image was given or the index produced no match.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    brand_list, index = _load_brand_bank()

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy().astype("float32")

    distances, indices = index.search(query_features, 1)
    idx = int(indices[0][0])
    distance = float(distances[0][0])

    # FAISS reports label -1 when it has no neighbour; without this check it
    # would silently index the last brand of the list.
    if idx < 0:
        raise gr.Error("No matching brand found.")

    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand_list[idx]
|
|
|
|
|
@functools.lru_cache(maxsize=None)  # only the two fixed probe paths below are ever passed
def _load_probe_weight(path):
    """Load, once per ``path``, the ``"weight"`` tensor of a saved bias-free linear probe."""
    return torch.load(path, map_location="cpu")["weight"]


def _probe_argmax(features, path, num_classes):
    """Apply the linear probe stored at ``path`` to ``features`` and return the top class index.

    The probe has no bias, so its logits are ``features @ weight.T``. The
    matmul is done in float32 on the features' device. This avoids relying on
    autocast to reconcile half-precision weights with float32 features.
    """
    weight = _load_probe_weight(path)
    expected_shape = (num_classes, features.shape[-1])
    if tuple(weight.shape) != expected_shape:
        raise ValueError(
            f"Probe {path!r} has weight shape {tuple(weight.shape)}, expected {expected_shape}"
        )
    logits = torch.nn.functional.linear(
        features.float(), weight.to(device=features.device, dtype=torch.float32)
    )
    # argmax of the logits equals argmax of their softmax, so softmax is skipped.
    return int(logits.argmax(dim=-1)[0])


def estimate_price_and_usage(img):
    """Estimate a price bracket and a usage class (reuse vs. export) for ``img``.

    The image is embedded with the CLIP model. Two bias-free linear probes are
    then applied: ``files/reuse_linear.pth`` (2 classes) and
    ``files/price_linear.pth`` (4 classes). The features are not normalised;
    since the probes have no bias, scaling the features would not change the
    argmax anyway.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicked the button without uploading an image.

    Returns:
        A human-readable summary string of the two predictions.

    Raises:
        gr.Error: if no image was given.
        ValueError: if a probe's weight shape does not match the features/classes.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    reuse_classes = ["Reuse", "Export"]
    price_classes = ["<50", "50-100", "100-150", ">150"]

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        reuse = _probe_argmax(query_features, "files/reuse_linear.pth", len(reuse_classes))
        price = _probe_argmax(query_features, "files/price_linear.pth", len(price_classes))

    return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_retrieval_index():
    """Load the image keys and the FAISS image index once; later calls reuse them."""
    index_folder = "files/index"
    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
    index = faiss.read_index(os.path.join(index_folder, "image.index"))
    return metadata_df["key"].tolist(), index


def _square_thumbnail(image, size=600):
    """Center-crop ``image`` to a square and resize it to ``size`` x ``size`` pixels."""
    side = min(image.size)
    image = transforms.CenterCrop(side)(image)
    return transforms.Resize((size, size))(image)


def retrieve(query):
    """Return the garment images that best match a text or image query.

    A ``str`` query is embedded with CLIP's text encoder; anything else is
    treated as a PIL image and embedded with the image encoder. The normalised
    embedding is searched (top-3) against ``files/index/image.index``. The
    matching keys are looked up in ``files/index/metadata.parquet``, and each
    ``<key>.jpg`` is read from ``files/combined.tar``.

    Args:
        query: A text description, or a PIL image from ``gr.Image(type="pil")``.
            ``None`` or a blank string means the user gave no query.

    Returns:
        A list of up to 3 PIL images, each center-cropped to a square and
        resized to 600x600. It is empty if the index returned no hits.

    Raises:
        gr.Error: if no query was given.
    """
    num_results = 3

    is_text = isinstance(query, str)
    if query is None or (is_text and not query.strip()):
        raise gr.Error("Please provide an image or a text description.")

    key_list, index = _load_retrieval_index()

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        if is_text:
            print("Query is a string")
            text = clip.tokenize([query]).to(device)
            query_features = model.encode_text(text)
        else:
            print("Query is an image")
            query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy().astype("float32")

    d, i = index.search(query_features, num_results)
    # FAISS pads with label -1 when it has fewer than num_results vectors; drop
    # those so they are not mistaken for the last key in key_list.
    hits = [(int(idx), float(sim)) for idx, sim in zip(i[0], d[0]) if idx >= 0]

    print(f"Found {len(hits)} items with query '{query}'")
    if not hits:
        return []

    similarities = [sim for _, sim in hits]
    print(f"The minimum similarity is {min(similarities):.2f} and the maximum is {max(similarities):.2f}")

    # NOTE: files/combined.tar was built offline by appending every member of
    # the webdataset shards {0000..0028}.tar (from
    # /fs/sefs1/circularfashion/wargon_webdataset/front_only) into a single
    # archive with tarfile + braceexpand.
    # Open it once per request rather than once per hit, so its member listing
    # is read only once.
    images = []
    with tarfile.open("files/combined.tar", "r") as tar_file:
        for idx, _ in hits:
            member = tar_file.extractfile(f"{key_list[idx]}.jpg")
            # copy() loads the pixel data, so the image stays usable after the
            # archive is closed.
            image = Image.open(member).copy()
            images.append(_square_thumbnail(image))

    return images
|
|
|
|
|
# The hub theme is purely cosmetic: fall back to Gradio's default theme rather
# than failing to start when the Hugging Face Hub cannot be reached.
try:
    theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
except Exception as err:  # top-level boundary; failure modes vary (network, hub, parsing)
    print(f"Could not load hub theme, using the default theme instead: {err}")
    theme = gr.themes.Default()

# Two tabs: captioning / brand / price-and-usage prediction, and garment retrieval.
with gr.Blocks(
    theme=theme,
    css="footer {visibility: hidden}",  # hide Gradio's footer
) as demo:
    with gr.Tab("Captioning and Prediction"):
        with gr.Row(variant="compact"):
            input_img = gr.Image(type="pil", show_label=False)
            # min_width is an integer number of pixels, not a string.
            with gr.Column(min_width=80):
                btn_generate_caption = gr.Button("Create Description", size="sm")
                generated_caption = gr.Textbox(label="Description", show_label=False)
        gr.Examples(
            examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
            fn=generate_caption,
            inputs=input_img,
            outputs=generated_caption,
        )

        with gr.Row(variant="compact"):
            brand_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width=80):
                btn_predict_brand = gr.Button("Predict Brand", size="sm")
                predicted_brand = gr.Textbox(label="Brand", show_label=False)
        gr.Examples(
            examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
            fn=predict_brand,
            inputs=brand_img,
            outputs=predicted_brand,
        )

        # Price/usage estimation reads its image from the captioning input above.
        with gr.Column(variant="compact"):
            btn_estimate = gr.Button("Estimate Price and Reuse", size="sm")
            text_box = gr.Textbox(label="Estimates:", show_label=False)

    with gr.Tab("Image Retrieval"):
        with gr.Row(variant="compact"):
            with gr.Column():
                query_img = gr.Image(type="pil", label="Image Query")
                btn_image_query = gr.Button("Retrieve Garments", size="sm")
            img_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
            fn=retrieve,
            inputs=query_img,
            outputs=img_query_gallery,
        )

        with gr.Row(variant="compact"):
            with gr.Column():
                query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
                btn_text_query = gr.Button("Retrieve Garments", size="sm")
            text_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["A purple sweater", "A dress with a floral pattern"],
            fn=retrieve,
            inputs=query_text,
            outputs=text_query_gallery,
        )

    # Wire each button to its handler (registered inside the Blocks context).
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)
|
|
|
|
|
if __name__ == "__main__":
    # Serve the Gradio app with its default launch settings.
    demo.launch()
|
|