import gradio as gr
from PIL import Image
from transformers import TimmWrapperModel

import torch
import torchvision.transforms.v2 as T


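# Per-model preprocessing settings: identity mean/std (inputs stay in [0, 1]),
# the model's input resolution, and the fill value used for padding.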
MODEL_MAP = {
    "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}


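# Build the torchvision v2 preprocessing pipeline for a model config:
# resize, pad, and center-crop onto a square canvas, then scale to [0, 1]
# and normalize.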
def config_to_processor(config: dict):
    return T.Compose(
        [
            T.PILToTensor(),
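            # size=None with max_size keeps the aspect ratio and constrains
            # the longer edge to image_size.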
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
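            # Pad generously, then center-crop: this places the resized image
            # centered on a square background canvas of side image_size.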
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )


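# Load the timm checkpoint through transformers' TimmWrapperModel, frozen in
# eval mode for inference.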
def load_model(name: str):
    return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)


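# Instantiate every model/processor pair once at startup.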
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}


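# Embed both images in a single batch and return their cosine similarity.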
@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    # Guard against missing inputs with a user-visible error instead of a traceback.
    if image_1 is None or image_2 is None:
        raise gr.Error("Please provide both images.")

    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]

    pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])

    embeddings = model(pixel_values).pooler_output
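    # L2-normalize so the embeddings lie on the unit sphere.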
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)

    # Dot product of two unit vectors is their cosine similarity
    # (`.T` is a no-op on 1-D tensors and deprecated there).
    similarity = (embeddings[0] @ embeddings[1]).item()

    return similarity


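# UI layout: image inputs and example pairs on the left; model picker,
# submit button, and similarity readout on the right.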
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                image_1 = gr.Image(label="Image 1", type="pil")

                image_2 = gr.Image(label="Image 2", type="pil")

            gr.Examples(
                examples=[
                    ["./examples/sample_01.jpg", "./examples/sample_02.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_05.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_06.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_04.jpg", "./examples/sample_03.jpg"],
                    ["./examples/sample_01.jpg", "./examples/sample_07.jpg"],
                    ["./examples/sample_07.jpg", "./examples/sample_08.jpg"],
                ],
                inputs=[image_1, image_2],
            )

        with gr.Column():
            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
            )
            submit_btn = gr.Button("Submit", variant="primary")
            similarity = gr.Label(label="Similarity")

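    # Wire the submit button to the similarity computation.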
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()