from os import getenv
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import v2 as T

from dreamsim import DreamsimBackbone, DreamsimEnsemble, DreamsimModel

# Inference-only app: disable autograd globally.
_ = torch.set_grad_enabled(False)
torchdev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")

HF_TOKEN = getenv("HF_TOKEN", None)
MODEL_REPO = "neggles/dreamsim"

# UI label -> repo revision (branch) passed to `from_pretrained(revision=...)`.
MODEL_VARIANTS: dict[str, str] = {
    "Ensemble": "ensemble_vitb16",
    "CLIP ViT-B/32": "clip_vitb32",
    "OpenCLIP ViT-B/32": "open_clip_vitb32",
    "DINO ViT-B/16": "dino_vitb16",
}

# Lazily-populated model cache keyed by revision name; derived from
# MODEL_VARIANTS so the two tables cannot drift apart.
loaded_models: dict[str, Optional[DreamsimBackbone]] = {
    revision: None for revision in MODEL_VARIANTS.values()
}

# PIL image -> float32 tensor; `scale=True` rescales uint8 pixels to [0, 1].
_to_tensor = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)])


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any transparency onto white.

    Non-RGB/RGBA modes (palette, grayscale, ...) are converted first. Images
    that carry alpha -- either an alpha band (e.g. "LA", "PA") or keyed
    transparency recorded in ``image.info`` -- go through RGBA so transparent
    pixels become white instead of being dropped with whatever color sits
    underneath them.
    """
    if image.mode not in ("RGB", "RGBA"):
        has_alpha = "transparency" in image.info or "A" in image.getbands()
        image = image.convert("RGBA" if has_alpha else "RGB")

    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center *image* on a new square RGB canvas whose side is its longer edge.

    Args:
        image: Source image (callers here pass the output of ``pil_ensure_rgb``).
        fill: Color of the padding. Defaults to white.

    Returns:
        A new ``px x px`` RGB image where ``px = max(width, height)``.
    """
    w, h = image.size
    px = max(w, h)
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def load_model(variant: str) -> DreamsimBackbone:
    """Return the model for *variant*, loading and caching it on first use.

    Args:
        variant: A UI label from ``MODEL_VARIANTS`` or a revision name
            (one of its values).

    Returns:
        The cached ``DreamsimEnsemble`` (for "ensemble_vitb16") or
        ``DreamsimModel`` instance, with ``do_resize`` disabled because
        ``predict`` does its own resizing.

    Raises:
        ValueError: if *variant* is neither a known label nor a known revision.
    """
    revision = MODEL_VARIANTS.get(variant, variant)
    if revision not in loaded_models:
        raise ValueError(f"Unknown model variant: {variant}")

    model = loaded_models[revision]
    if model is None:
        model_cls = DreamsimEnsemble if revision == "ensemble_vitb16" else DreamsimModel
        model = model_cls.from_pretrained(MODEL_REPO, token=HF_TOKEN, revision=revision)
        model.do_resize = False
        loaded_models[revision] = model
    return model


def predict(
    variant: str,
    resize_to: Optional[int],
    image_a: Image.Image,
    image_b: Image.Image,
):
    """Compute the DreamSim similarity score between two images.

    Args:
        variant: Model variant label (see ``MODEL_VARIANTS``).
        resize_to: Side length both images are resized to after being padded
            to squares, or ``None`` to keep their padded size (the two images
            must then already have the same size).
        image_a: First image (the "Input" component).
        image_b: Second image (the "Target" component).

    Returns:
        ``(score, variant)`` where ``score = 1 - model_output``.

    Raises:
        gr.Error: if an image is missing, or the images differ in size and no
            ``resize_to`` was chosen.
    """
    if image_a is None or image_b is None:
        raise gr.Error("Please provide both an input and a target image.")

    model: DreamsimModel | DreamsimEnsemble = load_model(variant)
    model = model.eval().to(torchdev)

    # Flatten alpha onto white, then pad to square (preserves aspect ratio).
    image_a = pil_pad_square(pil_ensure_rgb(image_a))
    image_b = pil_pad_square(pil_ensure_rgb(image_b))

    if resize_to is not None:
        side = int(resize_to)
        # resize() instead of thumbnail(): thumbnail() never upscales, so an
        # image smaller than `side` kept its own size and torch.stack() below
        # failed on mismatched shapes. Both images are square here, so an
        # exact (side, side) resize does not distort them.
        image_a = image_a.resize((side, side), resample=Image.Resampling.BICUBIC)
        image_b = image_b.resize((side, side), resample=Image.Resampling.BICUBIC)
    elif image_a.size != image_b.size:
        raise gr.Error(
            f"Images have different sizes after padding ({image_a.size} vs {image_b.size}); "
            "choose a 'Resize To' value to compare them."
        )

    # Each image -> (1, 3, H, W); stacking gives (2, 1, 3, H, W).
    # NOTE(review): assumes the dreamsim forward takes both images stacked on
    # dim 0 and returns one distance per pair -- confirm against the package.
    batch = torch.stack([_to_tensor(image_a).unsqueeze(0), _to_tensor(image_b).unsqueeze(0)], dim=0)
    loss = model(batch.to(torchdev, model.dtype)).cpu().item()
    score = 1.0 - loss
    return score, variant


def main():
    """Build the Gradio Blocks UI, wire up the Submit button, and launch it."""
    with gr.Blocks(title="DreamSIM Perceptual Similarity") as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(label="Input", type="pil", image_mode="RGB", scale=1)
            with gr.Column():
                img_target = gr.Image(label="Target", type="pil", image_mode="RGB", scale=1)
        with gr.Row(equal_height=True):
            with gr.Column():
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="Ensemble"
                )
                resize_to = gr.Dropdown(label="Resize To", choices=[224, 384, 512, None], value=224)
            with gr.Column():
                score = gr.Number(label="Similarity Score", precision=8, minimum=0, maximum=1)
                variant_out = gr.Textbox(label="Variant", interactive=False)
        with gr.Row():
            # Clear the variant readout too, so it never labels a stale score.
            gr.ClearButton(
                components=[img_input, img_target, score, variant_out],
                variant="secondary",
                size="lg",
            )
            submit = gr.Button(value="Submit", variant="primary", size="lg")

        submit.click(
            predict,
            inputs=[variant, resize_to, img_input, img_target],
            outputs=[score, variant_out],
            api_name=False,
        )

        gr.Examples(
            [
                ["examples/img_a_1.png", "examples/ref_1.png", "Ensemble", 224],
                ["examples/img_b_1.png", "examples/ref_1.png", "Ensemble", 224],
            ],
            inputs=[img_input, img_target, variant, resize_to],
        )

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()