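# Gradio demo that scores the style similarity of two images with a
# SigLIP-based style embedding model loaded through transformers' TimmWrapperModel.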
import gradio as gr
from PIL import Image
from transformers import TimmWrapperModel
import torch
import torchvision.transforms.v2 as T
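
# Per-model preprocessing settings: normalization stats, input resolution,
# and the fill value used when padding non-square images.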
MODEL_MAP = {
"p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
"mean": [0, 0, 0],
"std": [1.0, 1.0, 1.0],
"image_size": 384,
"background": 0,
}
}
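
# Build the torchvision v2 preprocessing pipeline for a model config:
# resize so the longer edge fits the target resolution, pad with the
# background fill, center-crop to a square, scale to [0, 1], then normalize.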
def config_to_processor(config: dict):
    return T.Compose(
        [
            T.PILToTensor(),
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )
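
# Load a frozen checkpoint from the Hub in eval mode; inference never needs gradients.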
def load_model(name: str):
    return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)
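
# Instantiate every configured model and its preprocessor once at startup
# so they are shared across requests.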
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}
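
# Embed both images, L2-normalize the pooled embeddings, and return their
# cosine similarity (the dot product of unit-length vectors).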
@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]
    pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])
    embeddings = model(pixel_values).pooler_output
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)
    similarity = (embeddings[0] @ embeddings[1]).item()
    return similarity
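
# UI: two image inputs plus example pairs on the left; model picker,
# submit button, and similarity readout on the right.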
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                image_1 = gr.Image(label="Image 1", type="pil")
                image_2 = gr.Image(label="Image 2", type="pil")
            gr.Examples(
                examples=[
                    ["./examples/sample_01.jpg", "./examples/sample_02.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_05.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_06.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_04.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_07.jpg"],
                    ["./examples/sample_07.jpg", "./examples/sample_08.jpg"],
                ],
                inputs=[image_1, image_2],
            )
        with gr.Column():
            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
            )
            submit_btn = gr.Button("Submit", variant="primary")
            similarity = gr.Label(label="Similarity")

    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()