import gradio as gr
from PIL import Image
from transformers import TimmWrapperModel

import torch
import torchvision.transforms.v2 as T


# Per-model preprocessing config: normalization stats, input resolution,
# and the fill value used when padding non-square images.
MODEL_MAP = {
    "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}
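# To compare against another checkpoint, an extra entry with its own stats and
# input size could be added here (hypothetical values, not from the source):
# MODEL_MAP["some/other-model"] = {
#     "mean": [0.5, 0.5, 0.5],
#     "std": [0.5, 0.5, 0.5],
#     "image_size": 224,
#     "background": 0,
# }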


def config_to_processor(config: dict):
    """Build a torchvision v2 pipeline that letterboxes an image to a square model input."""
    return T.Compose(
        [
            # PIL image -> uint8 tensor of shape (C, H, W)
            T.PILToTensor(),
            # With size=None, the longer edge is matched to max_size,
            # preserving the aspect ratio.
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad every side with the background value, then crop back to a
            # centered image_size x image_size square (letterboxing).
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )
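# Illustrative walk-through of the pipeline above (example numbers, not from the
# source): a 512x256 upload is resized to 384x192, padded by 192 px on each side
# to 768x576, and center-cropped back to 384x384 before scaling and normalization.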


def load_model(name: str):
    # Load the timm checkpoint through transformers' TimmWrapperModel,
    # frozen and in eval mode (inference only).
    return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)


# Instantiate every model and its matching preprocessing pipeline once at startup.
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}


@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]

    # Force RGB so grayscale or RGBA uploads still match the 3-channel
    # normalization stats, then preprocess both images into one batch.
    pixel_values = torch.stack(
        [processor(image.convert("RGB")) for image in [image_1, image_2]]
    )

    # Pooled embeddings, L2-normalized so the dot product below is a cosine similarity.
    embeddings = model(pixel_values).pooler_output
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)

    similarity = (embeddings[0] @ embeddings[1]).item()

    return similarity
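# Because the embeddings are unit-normalized, the returned value is a cosine
# similarity in [-1, 1]; values near 1 indicate very similar style embeddings.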


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")

            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
            )
            submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column():
            similarity = gr.Label(label="Similarity")

    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()
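# Local usage (sketch): run `python app.py`; Gradio serves the demo on
# http://127.0.0.1:7860 by default. Upload two images, pick a model, and press
# Submit to see the similarity of their style embeddings.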