# Hugging Face Spaces page banner captured with the source: "Spaces: Running"
import gradio as gr
import numpy as np
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from transformers import TimmWrapperModel
# Per-model preprocessing settings, keyed by the model's repository name.
#   mean / std  -> per-channel values handed to T.Normalize
#   image_size  -> side length of the square input built by the processor
#   background  -> fill value used when padding the image to a square
MODEL_MAP = {
    "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    },
}
def config_to_processor(config: dict) -> T.Compose:
    """Build the image preprocessing pipeline described by ``config``.

    Args:
        config: Mapping that provides ``"image_size"``, ``"background"``,
            ``"mean"`` and ``"std"``.

    Returns:
        A ``T.Compose`` that converts a PIL image to a tensor, resizes it,
        pads and centre-crops it to ``image_size`` x ``image_size``, converts
        it to scaled ``float32`` and normalises it with ``mean`` / ``std``.
    """
    image_size = config["image_size"]

    return T.Compose(
        [
            T.PILToTensor(),
            # size=None + max_size: the resize is driven by max_size alone
            # (see the torchvision v2 Resize documentation).
            T.Resize(
                size=None,
                max_size=image_size,
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad every side by half the target size with the background
            # value, then take the central square, so the output is always
            # image_size x image_size.
            T.Pad(
                padding=image_size // 2,
                fill=config["background"],
            ),
            T.CenterCrop(
                size=(image_size, image_size),
            ),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )
def load_model(name: str):
    """Load the pretrained model ``name`` and prepare it for inference.

    Args:
        name: Identifier passed to ``TimmWrapperModel.from_pretrained``.

    Returns:
        The loaded model, switched to eval mode and with gradients disabled
        on its parameters.
    """
    model = TimmWrapperModel.from_pretrained(name)
    return model.eval().requires_grad_(False)
# Built once at import time: every configured model is loaded together with
# its matching preprocessing pipeline, so request handlers only do a lookup.
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}
def calculate_similarity(
    model_name: str, image_1: Image.Image, image_2: Image.Image
) -> float:
    """Return the cosine similarity between the embeddings of two images.

    Args:
        model_name: Key into ``MODELS`` selecting the model/processor pair.
        image_1: First image to compare.
        image_2: Second image to compare.

    Returns:
        The dot product of the two L2-normalised pooled embeddings, as a
        Python float.

    Raises:
        gr.Error: If either image is ``None`` (for example, Submit was
            pressed before both image slots were filled).
    """
    if image_1 is None or image_2 is None:
        # Surface a readable message in the UI instead of failing inside the
        # preprocessing pipeline.
        raise gr.Error("Please provide both images before submitting.")

    entry = MODELS[model_name]
    model = entry["model"]
    processor = entry["processor"]

    # Both images are preprocessed to the same shape, so they can be stacked
    # into one batch and embedded in a single forward pass.
    pixel_values = torch.stack([processor(image) for image in (image_1, image_2)])
    embeddings = model(pixel_values).pooler_output

    # Out-of-place division: leaves the model's output tensor untouched.
    embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

    # Each row is a 1-D vector, so `@` is already the dot product; the
    # original `.T` on a 1-D tensor was a deprecated no-op.
    return (embeddings[0] @ embeddings[1]).item()
# UI definition: two image inputs with example pairs on the left, model
# selection, submit button and the similarity read-out on the right.
# NOTE(review): the nesting below is reconstructed from the order of the
# layout calls because the original indentation was lost — confirm the
# placement of gr.Examples and gr.on against the deployed app.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                image_1 = gr.Image(label="Image 1", type="pil")
                image_2 = gr.Image(label="Image 2", type="pil")

            # Clickable example pairs that populate both image inputs.
            gr.Examples(
                examples=[
                    ["./examples/sample_01.jpg", "./examples/sample_02.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_05.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_06.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_04.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_07.jpg"],
                    ["./examples/sample_07.jpg", "./examples/sample_08.jpg"],
                ],
                inputs=[image_1, image_2],
            )

        with gr.Column():
            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS),
                # First configured model is the default; None if none exist.
                value=next(iter(MODELS), None),
            )
            submit_btn = gr.Button("Submit", variant="primary")
            similarity = gr.Label(label="Similarity")

    # Clicking Submit runs the comparison and shows the score in the label.
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()