import gradio as gr
import timm  # the model names below use timm's "hf_hub:" checkpoint format

import torch
import torchvision.transforms.v2 as T
from PIL import Image


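# Per-model preprocessing settings: normalization stats, input resolution,
# and the fill value used when padding non-square images.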
MODEL_MAP = {
    "hf_hub:p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0, 0, 0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}

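# Build a torchvision v2 pipeline: shrink the longer edge to `image_size`,
# pad with the background fill, center-crop to a square, then convert to a
# normalized float tensor.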
def config_to_processor(config: dict):
    return T.Compose(
        [
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],  # black
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.PILToTensor(),
            T.ToDtype(dtype=torch.float32, scale=True), # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )

def load_model(name: str):
    # num_classes=0 drops any classification head so the model returns pooled embeddings
    return timm.create_model(name, pretrained=True, num_classes=0).eval().requires_grad_(False)

MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}


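# Cosine similarity between the two images: both embeddings are L2-normalized,
# so their dot product is the cosine of the angle between them.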
@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]
    
    pixel_values = torch.stack([
        processor(image) for image in [image_1, image_2]
    ])
    embeddings = model(pixel_values)
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)

    similarity = (embeddings[0] * embeddings[1]).sum().item()
    return similarity
    
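# Quick sanity check outside the UI (hypothetical file names):
#   score = calculate_similarity(list(MODELS.keys())[0], Image.open("a.png"), Image.open("b.png"))
#   print(score)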

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")

            model_name = gr.Dropdown(label="Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            submit_btn = gr.Button("Submit", variant="primary")
        
        with gr.Column():
            similarity = gr.Text(label="Similarity")

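    # Run the similarity computation when the Submit button is clicked.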
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()