# style-demo / app.py
import gradio as gr
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from transformers import TimmWrapperModel
# Per-model preprocessing settings: normalization stats, input size, and pad color.
MODEL_MAP = {
    "hf_hub:p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0.0, 0.0, 0.0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}
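
# Hypothetical sketch: the values above mirror the checkpoint's timm data
# config; for other models they could be read programmatically instead of
# hard-coded, e.g.:
#
#   import timm
#   cfg = timm.data.resolve_data_config({}, model=timm.create_model(name))
#   # cfg["mean"], cfg["std"], cfg["input_size"]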
def config_to_processor(config: dict):
    return T.Compose(
        [
            # Resize so the longer edge is at most `image_size`.
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad to a square canvas with the background color, then crop it.
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],  # black
            ),
            T.CenterCrop(
                size=(config["image_size"], config["image_size"]),
            ),
            T.PILToTensor(),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0-255 -> 0-1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )
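
# Quick shape check (hypothetical example image): a 512x256 RGB input is
# resized to 384x192, padded, and center-cropped, so the processor yields a
# float tensor of shape (3, 384, 384) with values in [0, 1].
#
#   proc = config_to_processor(next(iter(MODEL_MAP.values())))
#   t = proc(Image.new("RGB", (512, 256)))
#   assert t.shape == (3, 384, 384)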
def load_model(name: str):
    # Load the timm checkpoint through transformers' TimmWrapperModel and
    # freeze it for inference.
    return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)
MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}
@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]
    # Stack the two preprocessed images into a (2, 3, H, W) batch.
    pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])
    embeddings = model(pixel_values).pooler_output
    # L2-normalize so the dot product below is the cosine similarity.
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)
    similarity = (embeddings[0] * embeddings[1]).sum().item()
    return similarity
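
# Hypothetical usage outside the UI (file paths are placeholders):
#
#   score = calculate_similarity(
#       list(MODELS.keys())[0],
#       Image.open("a.png").convert("RGB"),
#       Image.open("b.png").convert("RGB"),
#   )
#   print(score)  # cosine similarity in [-1, 1]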
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")
            model_name = gr.Dropdown(label="Model", choices=list(MODELS.keys()))
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            similarity = gr.Text(label="Similarity")

    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[
            model_name,
            image_1,
            image_2,
        ],
        outputs=[similarity],  # write the score to the text box, not the image
    )
if __name__ == "__main__":
    demo.launch()