"""Gradio demo that scores the similarity of two images.

Each image is preprocessed, embedded with one of the pretrained models listed
in ``MODEL_MAP``, and the cosine similarity of the two embeddings is shown.
"""

import gradio as gr
import numpy as np
from PIL import Image
import torch
import torchvision.transforms.v2 as T
from transformers import TimmWrapperModel

# Preprocessing settings per model repository name.
#   mean / std  -> passed to T.Normalize (0 / 1 leaves the 0~1 values as-is)
#   image_size  -> longer-edge resize target and side of the final square crop
#   background  -> fill value used for the padding around the resized image
MODEL_MAP = {
    "p1atdev/style_250416.1.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    },
    "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    },
}


def config_to_processor(config: dict):
    """Build the preprocessing pipeline described by one ``MODEL_MAP`` entry.

    The pipeline resizes the image so that it fits ``image_size`` (aspect
    ratio preserved), pads it with ``background`` and center-crops an
    ``image_size`` x ``image_size`` square, then converts to float32 in 0~1
    and normalizes with ``mean`` / ``std``.

    Args:
        config: Mapping with the keys ``image_size``, ``background``,
            ``mean`` and ``std``.

    Returns:
        A ``T.Compose`` callable that maps a PIL image to a tensor.
    """
    image_size = config["image_size"]
    return T.Compose(
        [
            T.PILToTensor(),
            # size=None + max_size: only the longer edge is constrained.
            T.Resize(
                size=None,
                max_size=image_size,
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad generously on every side so the center crop below always
            # lands on image content surrounded by background.
            T.Pad(
                padding=image_size // 2,
                fill=config["background"],
            ),
            T.CenterCrop(
                size=(image_size, image_size),
            ),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )


def load_model(name: str):
    """Load a pretrained model by name, in eval mode with gradients disabled."""
    return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)


# Every model is loaded eagerly at import time, together with its processor.
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}


@torch.inference_mode()
def calculate_similarity(
    model_name: str, image_1: Image.Image, image_2: Image.Image
) -> float:
    """Return the cosine similarity between the embeddings of two images.

    Args:
        model_name: Key of ``MODELS`` selecting the model and processor.
        image_1: First PIL image (Gradio passes ``None`` when the slot is empty).
        image_2: Second PIL image (Gradio passes ``None`` when the slot is empty).

    Returns:
        The dot product of the two L2-normalized pooled embeddings.

    Raises:
        gr.Error: If either image is missing.
    """
    if image_1 is None or image_2 is None:
        # Surface a readable message in the UI instead of failing inside
        # the preprocessing transform.
        raise gr.Error("Please provide both Image 1 and Image 2.")

    entry = MODELS[model_name]
    model = entry["model"]
    processor = entry["processor"]

    # Guarantee three channels so the input matches the 3-entry mean/std
    # passed to T.Normalize, whatever mode the PIL image arrives in.
    images = [
        image if image.mode == "RGB" else image.convert("RGB")
        for image in (image_1, image_2)
    ]
    pixel_values = torch.stack([processor(image) for image in images])

    embeddings = model(pixel_values).pooler_output
    embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

    # Each row is a 1-D vector, so `@` is already the dot product; calling
    # `.T` on a 1-D tensor is deprecated in PyTorch.
    return (embeddings[0] @ embeddings[1]).item()


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                image_1 = gr.Image(label="Image 1", type="pil")
                image_2 = gr.Image(label="Image 2", type="pil")

            gr.Examples(
                examples=[
                    ["./examples/sample_01.jpg", "./examples/sample_02.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_05.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_06.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_04.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_07.jpg"],
                    ["./examples/sample_07.jpg", "./examples/sample_08.jpg"],
                ],
                inputs=[image_1, image_2],
            )

        with gr.Column():
            model_choices = list(MODELS)
            model_name = gr.Dropdown(
                label="Model",
                choices=model_choices,
                value=model_choices[0],
            )
            submit_btn = gr.Button("Submit", variant="primary")
            similarity = gr.Label(label="Similarity")

    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()