ura23 committed
Commit dc7b29f · verified · 1 Parent(s): ce8d28d

Upload 3 files

tagger/tagger___init__.py ADDED
File without changes
tagger/tagger_common.py ADDED
@@ -0,0 +1,180 @@
+ import math
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import HfHubHTTPError
+ from PIL import Image
+ from torch import Tensor, nn
+
+
+ @dataclass
+ class Heatmap:
+     label: str
+     score: float
+     image: Image.Image
+
+
+ @dataclass
+ class LabelData:
+     names: list[str]
+     rating: list[np.int64]
+     general: list[np.int64]
+     character: list[np.int64]
+
+
+ @dataclass
+ class ImageLabels:
+     caption: str
+     booru: str
+     rating: dict[str, float]
+     general: dict[str, float]
+     character: dict[str, float]
+
+
+ @lru_cache(maxsize=5)
+ def load_labels_hf(
+     repo_id: str,
+     revision: Optional[str] = None,
+     token: Optional[str] = None,
+ ) -> LabelData:
+     try:
+         csv_path = hf_hub_download(
+             repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
+         )
+         csv_path = Path(csv_path).resolve()
+     except HfHubHTTPError as e:
+         raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
+
+     df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
+     tag_data = LabelData(
+         names=df["name"].tolist(),
+         rating=list(np.where(df["category"] == 9)[0]),
+         general=list(np.where(df["category"] == 0)[0]),
+         character=list(np.where(df["category"] == 4)[0]),
+     )
+
+     return tag_data
+
+
+ def mcut_threshold(probs: np.ndarray) -> float:
+     """
+     Maximum Cut Thresholding (MCut)
+     Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+     for Multi-label Classification. In 11th International Symposium, IDA 2012
+     (pp. 172-183).
+     """
+     # sort descending, then place the cutoff at the midpoint of the largest
+     # gap between adjacent scores
+     probs = probs[probs.argsort()[::-1]]
+     diffs = probs[:-1] - probs[1:]
+     idx = diffs.argmax()
+     thresh = (probs[idx] + probs[idx + 1]) / 2
+     return float(thresh)
+
+
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
+     # convert to RGB/RGBA if not already (deals with palette images etc.)
+     if image.mode not in ["RGB", "RGBA"]:
+         image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
+     # convert RGBA to RGB with white background
+     if image.mode == "RGBA":
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+     return image
+
+
+ def pil_pad_square(
+     image: Image.Image,
+     fill: tuple[int, int, int] = (255, 255, 255),
+ ) -> Image.Image:
+     w, h = image.size
+     # get the largest dimension so we can pad to a square
+     px = max(image.size)
+     # pad to square with white background
+     canvas = Image.new("RGB", (px, px), fill)
+     canvas.paste(image, ((px - w) // 2, (px - h) // 2))
+     return canvas
+
+
+ def preprocess_image(
+     image: Image.Image,
+     size_px: int | tuple[int, int],
+     upscale: bool = True,
+ ) -> Image.Image:
+     """
+     Preprocess an image to be square and centered on a white background.
+     """
+     if isinstance(size_px, int):
+         size_px = (size_px, size_px)
+
+     # ensure RGB and pad to square
+     image = pil_ensure_rgb(image)
+     image = pil_pad_square(image)
+
+     # resize to target size
+     if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
+         if upscale is False:
+             raise ValueError("Image is smaller than target size, and upscaling is disabled")
+         image = image.resize(size_px, Image.LANCZOS)
+     if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
+         image.thumbnail(size_px, Image.BICUBIC)
+
+     return image
+
+
+ def pil_make_grid(
+     images: list[Image.Image],
+     max_cols: int = 8,
+     padding: int = 4,
+     bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
+     partial_rows: bool = True,
+ ) -> Image.Image:
+     n_cols = min(math.floor(math.sqrt(len(images))), max_cols)
+     n_rows = math.ceil(len(images) / n_cols)
+
+     # if the final row is not full and partial_rows is False, remove a row
+     if n_cols * n_rows > len(images) and not partial_rows:
+         n_rows -= 1
+
+     # assumes all images are same size
+     image_width, image_height = images[0].size
+
+     canvas_width = ((image_width + padding) * n_cols) + padding
+     canvas_height = ((image_height + padding) * n_rows) + padding
+
+     canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
+     for i, img in enumerate(images):
+         x = (i % n_cols) * (image_width + padding) + padding
+         y = (i // n_cols) * (image_height + padding) + padding
+         canvas.paste(img, (x, y))
+
+     return canvas
+
+
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+ kaomojis = [
+     "0_0",
+     "(o)_(o)",
+     "+_+",
+     "+_-",
+     "._.",
+     "<o>_<o>",
+     "<|>_<|>",
+     "=_=",
+     ">_<",
+     "3_3",
+     "6_9",
+     ">_o",
+     "@_@",
+     "^_^",
+     "o_o",
+     "u_u",
+     "x_x",
+     "|_|",
+     "||_||",
+ ]
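
For context, a minimal sketch of how these helpers fit together (not part of the commit; the repo id, file name, and 448 px input size below are placeholder assumptions for illustration):

    import numpy as np
    from PIL import Image

    from tagger.tagger_common import load_labels_hf, mcut_threshold, preprocess_image

    # download the tag list for a (hypothetical) WD tagger checkpoint
    labels = load_labels_hf("SmilingWolf/wd-vit-tagger-v3")

    # square-pad on white and resize to the model's expected input size
    image = preprocess_image(Image.open("input.png"), size_px=448)

    # MCut places the cutoff in the largest gap between sorted scores:
    # for [0.9, 0.85, 0.3, 0.1] the largest gap is 0.85 -> 0.3,
    # so the threshold is (0.85 + 0.3) / 2
    print(mcut_threshold(np.array([0.9, 0.85, 0.3, 0.1])))  # ~0.575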
tagger/tagger_model.py ADDED
@@ -0,0 +1,206 @@
+ import math
+ from pathlib import Path
+
+ import colorcet as cc
+ import cv2
+ import numpy as np
+ import timm
+ import torch
+ from PIL import Image
+ from matplotlib.colors import LinearSegmentedColormap
+ from timm.data import create_transform, resolve_data_config
+ from timm.models import VisionTransformer
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+ from torchvision import transforms as T
+
+ # module name matches the file uploaded in this commit (tagger_common.py)
+ from .tagger_common import Heatmap, ImageLabels, LabelData, pil_make_grid
+
+ # working dir, either file parent dir or cwd if interactive
+ work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
+ temp_dir = work_dir.joinpath("temp")
+ temp_dir.mkdir(exist_ok=True, parents=True)
+
+ # model cache
+ model_cache: dict[str, VisionTransformer] = {}
+ transform_cache: dict[str, T.Compose] = {}
+
+ # device to use
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class RGBtoBGR(nn.Module):
+     def forward(self, x: Tensor) -> Tensor:
+         if x.ndim == 4:
+             return x[:, [2, 1, 0], :, :]
+         return x[[2, 1, 0], :, :]
+
+
+ def model_device(model: nn.Module) -> torch.device:
+     return next(model.parameters()).device
+
+
+ def load_model(repo_id: str) -> VisionTransformer:
+     global model_cache
+
+     if model_cache.get(repo_id, None) is None:
+         # save model to cache
+         model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
+
+     return model_cache[repo_id]
+
+
+ def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
+     global transform_cache
+     global model_cache
+
+     if model_cache.get(repo_id, None) is None:
+         # save model to cache
+         model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
+     model = model_cache[repo_id]
+
+     if transform_cache.get(repo_id, None) is None:
+         transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
+         # hack in the RGBtoBGR transform, save to cache
+         transform_cache[repo_id] = T.Compose(transforms.transforms + [RGBtoBGR()])
+     transform = transform_cache[repo_id]
+
+     return model, transform
+
+
+ def get_tags(
+     probs: Tensor,
+     labels: LabelData,
+     gen_threshold: float,
+     char_threshold: float,
+ ):
+     # Convert indices+probs to labels
+     probs = list(zip(labels.names, probs.numpy()))
+
+     # First 4 labels are actually ratings
+     rating_labels = dict([probs[i] for i in labels.rating])
+
+     # General labels, pick any where prediction confidence > threshold
+     gen_labels = [probs[i] for i in labels.general]
+     gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
+     gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Character labels, pick any where prediction confidence > threshold
+     char_labels = [probs[i] for i in labels.character]
+     char_labels = dict([x for x in char_labels if x[1] > char_threshold])
+     char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Combine general and character labels, sort by confidence
+     combined_names = [x for x in gen_labels]
+     combined_names.extend([x for x in char_labels])
+
+     # Convert to a string suitable for use as a training caption
+     # (double backslash so the parens survive as literal "\(" and "\)")
+     caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
+     booru = caption.replace("_", " ")
+
+     return caption, booru, rating_labels, char_labels, gen_labels
+
+
+ @torch.no_grad()
+ def render_heatmap(
+     image: Tensor,
+     gradients: Tensor,
+     image_feats: Tensor,
+     image_probs: Tensor,
+     image_labels: list[str],
+     cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
+     pos_embed_dim: int = 784,
+     image_size: tuple[int, int] = (448, 448),
+     font_args: dict = {
+         "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
+         "fontScale": 1,
+         "color": (255, 255, 255),
+         "thickness": 2,
+         "lineType": cv2.LINE_AA,
+     },
+     partial_rows: bool = True,
+ ) -> tuple[list[Heatmap], Image.Image]:
+     # hmap_dim = int(math.sqrt(pos_embed_dim))
+
+     # Grad-CAM style: weight patch features by their gradients, one map per label
+     image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
+     hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
+     image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), -1)
+     # keep only the trailing patch tokens (drops class/register tokens)
+     image_hmaps = image_hmaps[..., -hmap_dim ** 2:]
+     image_hmaps = image_hmaps.reshape(len(image_labels), hmap_dim, hmap_dim)
+     # clamp negative contributions to zero
+     image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
+
+     image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
+     # normalize to 0-1
+     image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min()) for x in image_hmaps]).unsqueeze(1)
+     # interpolate to input image size
+     image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)
+
+     hmap_imgs: list[Heatmap] = []
+     for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
+         # undo the [-1, 1] input normalization to recover uint8 pixels
+         image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+         hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
+
+         hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
+         hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
+         if tag is not None:
+             cv2.putText(hmap_image, tag, (10, 30), **font_args)
+             cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)
+
+         hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
+         hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))
+
+     hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
+     hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
+
+     return hmap_imgs, hmap_grid
+
+
+ def process_heatmap(
+     model: VisionTransformer,
+     image: Tensor,
+     labels: LabelData,
+     threshold: float = 0.5,
+     partial_rows: bool = True,
+ ) -> tuple[list[Heatmap], Image.Image, ImageLabels]:
+     torch_device = model_device(model)
+
+     with torch.set_grad_enabled(True):
+         features = model.forward_features(image.to(torch_device))
+         probs = model.forward_head(features)
+         probs = F.sigmoid(probs).squeeze(0)
+
+         probs_mask = probs > threshold
+         heatmap_probs = probs[probs_mask]
+
+         label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
+         image_labels = [labels.names[label_indices[i]] for i in range(len(label_indices))]
+
+         # one gradient per selected label, batched through an identity matrix
+         eye = torch.eye(heatmap_probs.shape[0], device=torch_device)
+         grads = torch.autograd.grad(
+             outputs=heatmap_probs,
+             inputs=features,
+             grad_outputs=eye,
+             is_grads_batched=True,
+             retain_graph=True,
+         )
+         grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)
+
+     with torch.set_grad_enabled(False):
+         hmap_imgs, hmap_grid = render_heatmap(
+             image=image,
+             gradients=grads,
+             image_feats=features,
+             image_probs=heatmap_probs,
+             image_labels=image_labels,
+             partial_rows=partial_rows,
+         )
+
+     caption, booru, ratings, character, general = get_tags(
+         probs=probs.detach().cpu(),  # detach: .numpy() in get_tags rejects grad-tracking tensors
+         labels=labels,
+         gen_threshold=threshold,
+         char_threshold=threshold,
+     )
+     labels = ImageLabels(caption, booru, ratings, general, character)
+
+     return hmap_imgs, hmap_grid, labels
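
Taken together, the two modules form a load → preprocess → classify → explain pipeline. A minimal end-to-end sketch (not part of the commit; the repo id, file names, and 0.35 threshold are placeholder assumptions, and whether to square-pad before the timm transform is an inference from the helpers above):

    from PIL import Image

    from tagger.tagger_common import load_labels_hf, preprocess_image
    from tagger.tagger_model import load_model_and_transform, process_heatmap

    repo_id = "SmilingWolf/wd-vit-tagger-v3"  # placeholder, assumed timm-compatible
    model, transform = load_model_and_transform(repo_id)
    labels = load_labels_hf(repo_id)

    # square-pad first, then apply the model's own transform (which appends RGBtoBGR)
    image = preprocess_image(Image.open("input.png"), size_px=448)
    inputs = transform(image).unsqueeze(0)  # add batch dimension

    heatmaps, grid, tags = process_heatmap(model, inputs, labels, threshold=0.35)
    grid.save("heatmap_grid.png")
    print(tags.caption)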