style-demo / app.py
Plat
chore: update examples
89bb1fd
raw
history blame
3.25 kB
import gradio as gr
import numpy as np
from PIL import Image
from transformers import TimmWrapperModel
import torch
import torchvision.transforms.v2 as T
# Hub id of the checkpoint served by this demo.
_STYLE_SIGLIP_384 = "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli"

# Per-model preprocessing settings, consumed by ``config_to_processor``:
#   mean / std  -> arguments of the final ``T.Normalize`` step
#   image_size  -> resize limit, pad amount (half of it) and crop size
#   background  -> fill value used when padding
MODEL_MAP = {
    _STYLE_SIGLIP_384: dict(
        mean=[0, 0, 0],
        std=[1.0, 1.0, 1.0],
        image_size=384,
        background=0,
    ),
}
def config_to_processor(config: dict):
    """Build the torchvision v2 preprocessing pipeline described by ``config``.

    The returned pipeline applies, in order: PIL image -> tensor, a
    nearest-neighbour resize limited by ``max_size=image_size``, padding of
    ``image_size // 2`` on every side filled with ``background``, a centre
    crop to an ``image_size`` square, conversion to scaled float32
    (0~255 -> 0~1) and normalisation with ``mean`` / ``std``.

    Args:
        config: Mapping providing the keys ``image_size``, ``background``,
            ``mean`` and ``std``.

    Returns:
        A ``T.Compose`` chaining the steps above.
    """
    side = config["image_size"]

    resize = T.Resize(
        size=None,
        max_size=side,
        interpolation=T.InterpolationMode.NEAREST,
    )
    pad = T.Pad(padding=side // 2, fill=config["background"])
    crop = T.CenterCrop(size=(side, side))
    to_unit_float = T.ToDtype(dtype=torch.float32, scale=True)  # 0~255 -> 0~1
    normalize = T.Normalize(mean=config["mean"], std=config["std"])

    steps = [T.PILToTensor(), resize, pad, crop, to_unit_float, normalize]
    return T.Compose(steps)
def load_model(name: str):
    """Load the checkpoint ``name`` as a frozen, eval-mode ``TimmWrapperModel``.

    Args:
        name: Identifier handed to ``TimmWrapperModel.from_pretrained``.

    Returns:
        The model after ``eval()`` and ``requires_grad_(False)`` were applied.
    """
    model = TimmWrapperModel.from_pretrained(name)
    model = model.eval()
    model = model.requires_grad_(False)
    return model
# Registry built once at import time: for every entry of MODEL_MAP, the loaded
# model together with the preprocessing pipeline derived from its config.
MODELS = {
    model_id: dict(
        model=load_model(model_id),
        processor=config_to_processor(settings),
    )
    for model_id, settings in MODEL_MAP.items()
}
@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    """Return the cosine similarity between the embeddings of two images.

    Both images are run through the processor registered for ``model_name``,
    stacked into one batch and embedded in a single forward pass. The model's
    ``pooler_output`` is L2-normalised along the last dimension and the two
    normalised embeddings are combined with a dot product.

    Args:
        model_name: Key into ``MODELS`` selecting the model and processor.
        image_1: First image.
        image_2: Second image.

    Returns:
        The similarity as a Python ``float``.

    Raises:
        gr.Error: If either image is missing (e.g. Submit was clicked before
            both image slots were filled).
        KeyError: If ``model_name`` is not registered in ``MODELS``.
    """
    # Fail with a readable message instead of an opaque error from the
    # preprocessing pipeline when an image slot is empty.
    if image_1 is None or image_2 is None:
        raise gr.Error("Please provide both images before submitting.")

    entry = MODELS[model_name]
    model = entry["model"]
    processor = entry["processor"]

    pixel_values = torch.stack([processor(image) for image in (image_1, image_2)])
    embeddings = model(pixel_values).pooler_output
    embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

    # Dot product of the two unit vectors. Written as multiply-and-sum rather
    # than ``a @ b.T`` because ``.T`` on a tensor that is not 2-D is deprecated
    # in PyTorch.
    similarity = (embeddings[0] * embeddings[1]).sum().item()
    return similarity
with gr.Blocks() as demo:
    # Image pairs offered as one-click examples; each row fills
    # (image_1, image_2) in that order.
    example_pairs = [
        ["./examples/sample_01.jpg", "./examples/sample_02.jpg"],
        ["./examples/sample_01.jpg", "./examples/sample_05.jpg"],
        ["./examples/sample_01.jpg", "./examples/sample_06.jpg"],
        ["./examples/sample_01.jpg", "./examples/sample_03.jpg"],
        ["./examples/sample_04.jpg", "./examples/sample_03.jpg"],
        ["./examples/sample_01.jpg", "./examples/sample_07.jpg"],
        ["./examples/sample_07.jpg", "./examples/sample_08.jpg"],
    ]
    # Every registered model is selectable; the first one is preselected.
    model_choices = list(MODELS)

    with gr.Row():
        # Wider column: the two image inputs side by side, examples underneath.
        with gr.Column(scale=2):
            with gr.Row():
                image_1 = gr.Image(label="Image 1", type="pil")
                image_2 = gr.Image(label="Image 2", type="pil")
            gr.Examples(examples=example_pairs, inputs=[image_1, image_2])

        # Narrower column: model picker, submit button and the result.
        with gr.Column():
            model_name = gr.Dropdown(
                label="Model",
                choices=model_choices,
                value=model_choices[0],
            )
            submit_btn = gr.Button("Submit", variant="primary")
            similarity = gr.Label(label="Similarity")

    # Clicking Submit computes the similarity and shows it in the label.
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[model_name, image_1, image_2],
        outputs=[similarity],
    )
def main():
    """Start the Gradio app defined by ``demo``."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()