Plat committed
Commit 3e0e479 · 1 Parent(s): 5bffca4
Files changed (6)
  1. .gitignore +2 -0
  2. .python-version +1 -0
  3. app.py +30 -22
  4. pyproject.toml +20 -0
  5. requirements.txt +1 -0
  6. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .venv
+ .mypy_cache
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
app.py CHANGED
@@ -1,13 +1,14 @@
  import gradio as gr
  import numpy as np
- from transformers import TimmWrapper
+ from PIL import Image
+ from transformers import TimmWrapperModel
  
  import torch
- import torchvision.transform.v2 as T
+ import torchvision.transforms.v2 as T
  
  
  MODEL_MAP = {
-     "hf_hub:p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
+     "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
          "mean": [0, 0, 0],
          "std": [1.0, 1.0, 1.0],
          "image_size": 384,
@@ -15,9 +16,11 @@ MODEL_MAP = {
      }
  }
  
+ 
  def config_to_processor(config: dict):
      return T.Compose(
          [
+             T.PILToTensor(),
              T.Resize(
                  size=None,
                  max_size=config["image_size"],
@@ -25,55 +28,60 @@ def config_to_processor(config: dict):
              ),
              T.Pad(
                  padding=config["image_size"] // 2,
-                 fill=config["background]", # black
+                 fill=config["background"],
              ),
              T.CenterCrop(
                  size=(config["image_size"], config["image_size"]),
              ),
-             T.PILToTensor(),
-             T.ToDtype(dtype=torch.float32, scale=True), # 0~255 -> 0~1
+             T.ToDtype(dtype=torch.float32, scale=True), # 0~255 -> 0~1
              T.Normalize(mean=config["mean"], std=config["std"]),
          ]
      )
  
+ 
  def load_model(name: str):
-     return TimmWrapper.from_pretrained(name).eval().requires_grad_False)
+     return TimmWrapperModel.from_pretrained(name).eval().requires_grad_(False)
+ 
  
  MODELS = {
      name: {
          "model": load_model(name),
          "processor": config_to_processor(config),
      }
-     for name, config in MODEL_NAMES.items()
+     for name, config in MODEL_MAP.items()
  }
  
  
  @torch.inference_mode()
- def calculate_similarity(model:_name str, image_1: Image.Image, image_2: Image.Image):
+ def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
      model = MODELS[model_name]["model"]
      processor = MODELS[model_name]["processor"]
- 
-     pixel_values = torch.cat([
-         processor(image) for image in [image_1, image_2]
-     ])
-     embeddings = model(pixel_values)
+ 
+     pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])
+ 
+     embeddings = model(pixel_values).pooler_output
      embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)
  
-     similarity = (embeddings[0] * embeddings[1]).item()
+     similarity = (embeddings[0] @ embeddings[1].T).item()
+ 
      return similarity
- 
+ 
  
  with gr.Blocks() as demo:
      with gr.Row():
          with gr.Column():
-             image_1 = gr.Image("Image 1", type="pil")
-             image_2 = gr.Image("Image 2", type="pil")
+             image_1 = gr.Image(label="Image 1", type="pil")
+             image_2 = gr.Image(label="Image 2", type="pil")
  
-             model_name = gr.Dropdwon("Model", choices=list(MODELS.keys())
+             model_name = gr.Dropdown(
+                 label="Model",
+                 choices=list(MODELS.keys()),
+                 value=list(MODELS.keys())[0],
+             )
              submit_btn = gr.Button("Submit", variant="primary")
+ 
          with gr.Column():
-             similarity = gr.Text("Similarity")
+             similarity = gr.Label(label="Similarity")
  
      gr.on(
          triggers=[submit_btn.click],
@@ -83,7 +91,7 @@ with gr.Blocks() as demo:
              image_1,
              image_2,
          ],
-         outputs=[image_2],
+         outputs=[similarity],
      )
  
  if __name__ == "__main__":
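
The substantive fix is in calculate_similarity: the two processed images are now batched with torch.stack (torch.cat on (C, H, W) tensors would concatenate channels into a single (2C, H, W) image rather than build a batch), the embedding is taken from the wrapper output's pooler_output, and since both embeddings are L2-normalized, their dot product is the cosine similarity. A minimal standalone sketch of that computation, assuming the model repo above is reachable and substituting a plain square resize for the app's resize/pad/center-crop pipeline (file paths are hypothetical):

# sketch.py -- not part of the commit; illustrates the batching + cosine step
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from transformers import TimmWrapperModel

model = (
    TimmWrapperModel.from_pretrained(
        "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli"
    )
    .eval()
    .requires_grad_(False)
)

processor = T.Compose(
    [
        T.PILToTensor(),                             # PIL -> uint8 (C, H, W)
        T.Resize(size=(384, 384)),                   # square resize (simplified)
        T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
        T.Normalize(mean=[0, 0, 0], std=[1.0, 1.0, 1.0]),
    ]
)

@torch.inference_mode()
def similarity(image_1: Image.Image, image_2: Image.Image) -> float:
    # stack -> (2, C, H, W) batch; cat would have yielded (2C, H, W)
    pixel_values = torch.stack([processor(im) for im in (image_1, image_2)])
    embeddings = model(pixel_values).pooler_output
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)
    return (embeddings[0] @ embeddings[1]).item()  # cosine of unit vectors

if __name__ == "__main__":
    a = Image.open("image_a.png").convert("RGB")  # hypothetical inputs
    b = Image.open("image_b.png").convert("RGB")
    print(similarity(a, b))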
pyproject.toml ADDED
@@ -0,0 +1,20 @@
+ [project]
+ name = "style-demo"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "hf-xet>=1.0.3",
+     "safetensors>=0.5.3",
+     "timm>=1.0.15",
+     "torch>=2.6.0",
+     "torchvision>=0.21.0",
+     "transformers>=4.51.2",
+ ]
+ 
+ [dependency-groups]
+ dev = [
+     "gradio>=5.25.0",
+     "ruff>=0.11.5",
+ ]
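
Since the commit also checks in uv.lock, the environment this pyproject.toml describes can presumably be reproduced with uv: "uv sync" resolves against the lockfile (installing the dev group, gradio and ruff, by default) and "uv run python app.py" launches the demo. This workflow is inferred from the files, not stated anywhere in the commit.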
requirements.txt CHANGED
@@ -1,4 +1,5 @@
  torch
+ torchvision
  transformers
  timm
  safetensors
uv.lock ADDED
The diff for this file is too large to render. See raw diff