# Captured from a Hugging Face Space page ("Spaces:" header); the Space's
# status was shown as "Runtime error" at capture time.
import os | |
from pathlib import Path | |
import torch | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
from tokenizers import Tokenizer | |
from torch.utils.data import Dataset | |
import albumentations as A | |
from tqdm import tqdm | |
from huggingface_hub import hf_hub_download | |
from datasets import load_dataset | |
from fourm.vq.vqvae import VQVAE | |
from fourm.models.fm import FM | |
from fourm.models.generate import ( | |
GenerationSampler, | |
build_chained_generation_schedules, | |
init_empty_target_modality, | |
custom_text, | |
) | |
from fourm.utils.plotting_utils import decode_dict | |
from fourm.data.modality_info import MODALITY_INFO | |
from fourm.data.modality_transforms import RGBTransform | |
from torchvision.transforms.functional import center_crop | |
# Constants and configurations
# Directory containing this script. Local resources are resolved against it so
# the app does not depend on the process's current working directory.
_APP_DIR = Path(__file__).resolve().parent

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224  # passed to decode_dict as image_size

# Local resources (kept as str for callers that expect plain paths).
TOKENIZER_PATH = str(
    _APP_DIR / "fourm" / "utils" / "tokenizer" / "trained"
    / "text_tokenizer_4m_wordpiece_30k.json"
)
IMAGE_DATASET_PATH = str(_APP_DIR / "data")

# Hugging Face Hub repositories.
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
# Load models
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
# Put the VQ-VAE decoder in eval mode and on DEVICE at start-up, exactly like
# fm_model below, so inference never runs it with training-mode behaviour.
# (from_pretrained may already return an eval-mode module -- the explicit
# call makes it independent of that.)
vqvae = VQVAE.from_pretrained(VQVAE_PATH).eval().to(DEVICE)
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)
# Generation configurations:
# condition on the caption and metadata text, and predict 16 tokens of the
# "tok_dinov2_global" modality.
cond_domains = ["caption", "metadata"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]

# One entry per target domain in each list-valued setting.
generation_config = dict(
    autoregression_schemes=["roar"],
    decoding_steps=[1],
    token_decoding_schedules=["linear"],
    temps=[2.0],
    temp_schedules=["onex:0.5:0.5"],
    cfg_scales=[1.0],
    cfg_schedules=["constant"],
    cfg_grow_conditioning=True,
)

# Sampling filters forwarded to sampler.generate; top_k = 0.0 presumably
# disables top-k filtering in the 4M sampler -- confirm.
top_p = 0.8
top_k = 0.0
# Wrap the foundation model in a sampler and build the chained decoding
# schedule from the generation settings (domains, token counts and
# generation_config). Both are created once and reused by every request.
sampler = GenerationSampler(fm_model)
schedule = build_chained_generation_schedules(
    target_domains=target_domains,
    cond_domains=cond_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)
class HuggingFaceImageDataset(Dataset):
    """Map-style dataset over the ``image`` column of a Hugging Face dataset.

    Items are returned as PIL RGB images whose shorter side is rescaled to
    ``img_sz`` (albumentations ``SmallestMaxSize``).

    Args:
        dataset_name: Hub id passed to ``datasets.load_dataset``.
        split: Dataset split to load.
        img_sz: Target length of the shorter image side.
    """

    def __init__(self, dataset_name, split="train", img_sz=224):
        self.dataset = load_dataset(dataset_name, split=split)
        self.tfms = A.Compose([
            A.SmallestMaxSize(img_sz)
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Callers may pass numpy/torch integer scalars (e.g. from a sort over
        # similarities); coerce to a plain int before indexing the HF dataset.
        img = self.dataset[int(idx)]["image"]
        # Assumes the column decodes to PIL images (HF `Image` feature).
        # Normalise to 3-channel RGB first: a palette ("P") image would
        # otherwise become a 2-D array of palette indices and come back as a
        # meaningless grayscale image. Note that this drops any alpha channel.
        img = np.array(img.convert("RGB"))
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)
# Image corpus returned to the user. Results are looked up by row index of
# `image_embeddings`, so the two are assumed to share the same ordering --
# confirm when either source changes.
dataset = HuggingFaceImageDataset(dataset_name="aroraaman/4m-21-demo")
def load_image_embeddings(
    repo_id="aroraaman/img-tensor", filename="image_emb.pt", map_location=None
):
    """Download the precomputed image-embedding tensor from the Hub and load it.

    Args:
        repo_id: Hub repository that holds the embedding file.
        filename: Name of the ``torch.save``-d file inside ``repo_id``.
        map_location: Device to load the tensor onto. ``None`` means the
            module-level ``DEVICE``, so the embeddings live on the same device
            as the model outputs they are compared against.

    Returns:
        The loaded tensor (presumably one row per dataset image -- confirm).
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # map_location: without it, a file saved on a GPU cannot be loaded on a
    # CPU-only machine, and a CPU-saved file stays off-device on a GPU box.
    # weights_only=True: the file comes from a remote repository, so restrict
    # unpickling to tensors/primitive containers instead of arbitrary objects.
    return torch.load(
        file_path,
        map_location=DEVICE if map_location is None else map_location,
        weights_only=True,
    )
# Precomputed image embeddings, loaded once at import time and ranked against
# on every request.
image_embeddings = load_image_embeddings()
print(image_embeddings.shape)
def _format_metadata(num_items, brightness):
    """Build the metadata conditioning string from the two slider values.

    Gradio sliders hand their values to Python as numbers that may be floats;
    they are rounded to int so ``5.0`` is rendered as ``v0=5`` rather than
    ``v0=5.0``.
    """
    # NOTE(review): v1=6 / v1=10 appear to select the metadata type for the
    # item count and brightness respectively (per the UI labels) -- confirm
    # against the 4M-21 metadata id table.
    return f"v1=6 v0={int(round(num_items))} v1=10 v0={int(round(brightness))}"


def _top_matches(query, embeddings, k=1):
    """Return indices (Python ints) of the ``k`` rows of ``embeddings`` most
    cosine-similar to ``query``, best first.

    ``k`` is clamped to ``[0, number of similarities]``.
    """
    # The embeddings may sit on a different device, or have a different dtype,
    # than the freshly decoded query; align both before comparing.
    dtype = torch.promote_types(query.dtype, embeddings.dtype)
    query = query.to(dtype=dtype)
    embeddings = embeddings.to(device=query.device, dtype=dtype)
    similarities = torch.nn.functional.cosine_similarity(query, embeddings)
    k = max(0, min(int(k), similarities.numel()))
    scores, indices = similarities.topk(k)
    print(indices, scores)
    return indices.tolist()


def get_similar_images(caption, brightness, num_items, num_results=1):
    """Retrieve the dataset images closest to a caption + metadata query.

    The caption and metadata text condition 4M-21, which generates the
    ``tok_dinov2_global`` target tokens. The VQ-VAE decodes those tokens into
    a feature vector, which is ranked against ``image_embeddings`` by cosine
    similarity.

    Args:
        caption: Free-text description used as the "caption" condition.
        brightness: Brightness slider value (rounded to int).
        num_items: Item-count slider value (rounded to int).
        num_results: Number of images to return. Defaults to 1, the single
            best match.

    Returns:
        List of PIL images from ``dataset``, most similar first.
    """
    batched_sample = {}
    for target_mod, ntoks in zip(target_domains, tokens_per_target):
        batched_sample = init_empty_target_modality(
            batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
        )

    metadata = _format_metadata(num_items, brightness)
    print(metadata)

    # Attach the two text conditions, caption first, then metadata.
    for key, text in (("caption", caption), ("metadata", metadata)):
        batched_sample = custom_text(
            batched_sample,
            input_text=text,
            eos_token="[EOS]",
            key=key,
            device=DEVICE,
            text_tokenizer=text_tokenizer,
        )

    # Pure inference: no autograd graph is needed for sampling or decoding.
    with torch.no_grad():
        out_dict = sampler.generate(
            batched_sample,
            schedule,
            text_tokenizer=text_tokenizer,
            verbose=True,
            seed=0,
            top_p=top_p,
            top_k=top_k,
        )
        dec_dict = decode_dict(
            out_dict,
            {"tok_dinov2_global": vqvae.to(DEVICE)},
            text_tokenizer,
            image_size=IMG_SIZE,
            patch_size=16,
            decoding_steps=1,
        )

    query = dec_dict["tok_dinov2_global"]
    return [dataset[i] for i in _top_matches(query, image_embeddings, num_results)]
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")

    with gr.Row():
        # Left column: the text and metadata query inputs.
        with gr.Column(scale=1):
            caption = gr.Textbox(
                placeholder="Enter image description...",
                label="Caption Description",
            )
            brightness = gr.Slider(
                label="Brightness",
                info="Adjust image brightness (0-255)",
                minimum=0,
                maximum=255,
                step=1,
                value=5,
            )
            num_items = gr.Slider(
                label="Number of Items",
                info="Number of COCO instances in image (0-50)",
                minimum=0,
                maximum=50,
                step=1,
                value=5,
            )

        # Right column: the retrieved results.
        with gr.Column(scale=1):
            output_images = gr.Gallery(
                label="Retrieved Images",
                elem_id="gallery",
                show_label=True,
                height=512,
                columns=2,
                rows=2,
            )

    submit_btn = gr.Button("Retrieve Most Similar Image")
    # Pressing the button runs retrieval on the three inputs and fills the gallery.
    submit_btn.click(
        inputs=[caption, brightness, num_items],
        outputs=output_images,
        fn=get_similar_images,
    )
def main():
    """Launch the Gradio app; share=True also requests a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()