# Hugging Face Spaces demo app.
# (The original "Spaces: Running" lines here were web-page scrape residue,
# not code — replaced with this comment so the file parses.)
import gradio as gr
import numpy as np
import torch
import torchvision.transforms.v2 as T  # fixed: module is `transforms`, not `transform`
from PIL import Image  # needed for the Image.Image type hints in calculate_similarity
from transformers import TimmWrapper  # NOTE(review): transformers names this `TimmWrapperModel` — confirm import
# Per-model preprocessing configuration, keyed by the timm hub model id.
# mean/std: channel-wise normalization; image_size: square input side in px;
# background: fill value used when padding the image to a square.
MODEL_MAP = {
    "hf_hub:p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}
def config_to_processor(config: dict) -> T.Compose:
    """Build the torchvision v2 preprocessing pipeline for one model config.

    Args:
        config: One value from MODEL_MAP with keys ``image_size``, ``background``,
            ``mean`` and ``std``.

    Returns:
        A ``T.Compose`` mapping a PIL image to a normalized float32 (C, H, W) tensor.
    """
    return T.Compose(
        [
            # Shrink so the longer edge is at most image_size (aspect ratio kept).
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad generously, then center-crop down to an exact square canvas.
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],  # fixed: quote was inside the key name
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.PILToTensor(),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )
def load_model(name: str):
    """Load a pretrained model from the hub, frozen and in eval mode.

    Args:
        name: Hub model id (``hf_hub:...``).

    Returns:
        The model with gradients disabled, ready for inference.
    """
    # fixed: original read `.requires_grad_False)` — a syntax error for
    # the intended in-place call `.requires_grad_(False)`.
    return TimmWrapper.from_pretrained(name).eval().requires_grad_(False)
# Eagerly instantiate every supported model together with its matching
# preprocessing pipeline so requests only do a dict lookup.
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    # fixed: the config dict defined above is MODEL_MAP; MODEL_NAMES never existed.
    for name, config in MODEL_MAP.items()
}
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image) -> float:
    """Return the cosine similarity between the embeddings of two images.

    Args:
        model_name: Key into MODELS selecting the model/processor pair.
        image_1: First PIL image.
        image_2: Second PIL image.

    Returns:
        Cosine similarity in [-1, 1] as a Python float.
    """
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]
    # fixed: processor yields (C, H, W) tensors, so `torch.cat` would stack
    # channels into (2C, H, W); `torch.stack` builds the (2, C, H, W) batch.
    pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])
    with torch.inference_mode():
        # NOTE(review): if the model is a transformers wrapper this may return
        # a ModelOutput needing `.pooler_output` — confirm against load_model.
        embeddings = model(pixel_values)
    embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
    # fixed: the element-wise product is a vector; sum it to get the dot
    # product (= cosine similarity of the L2-normalized embeddings).
    similarity = (embeddings[0] * embeddings[1]).sum().item()
    return similarity
# Gradio UI: two image inputs + model picker on the left, similarity readout
# on the right; clicking Submit runs calculate_similarity.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # fixed: the first positional argument of gr.Image/gr.Text is the
            # *value*, so labels must be passed via label=.
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")
            # fixed: `Dropdwon` typo and a missing closing parenthesis.
            model_name = gr.Dropdown(choices=list(MODELS.keys()), label="Model")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            similarity = gr.Text(label="Similarity")
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        # fixed: the result is the similarity score; writing it into the
        # image_2 component would both clobber the input and fail to render.
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()