ura23 commited on
Commit
ce8d28d
·
verified ·
1 Parent(s): ac2c511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -117
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import argparse
2
  import os
 
3
 
4
  import gradio as gr
5
  import huggingface_hub
@@ -7,10 +8,12 @@ import numpy as np
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
 
 
10
 
11
- TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
13
- Demo for the WaifuDiffusion tagger models
14
  """
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
@@ -22,34 +25,17 @@ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
 
25
- # Dataset v2 series of models:
26
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
 
31
 
32
- # IdolSankaku series of models:
33
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
-
36
- # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
38
  LABEL_FILENAME = "selected_tags.csv"
39
 
40
- def parse_args() -> argparse.Namespace:
41
- parser = argparse.ArgumentParser()
42
- parser.add_argument("--score-slider-step", type=float, default=0.05)
43
- parser.add_argument("--score-general-threshold", type=float, default=0.3)
44
- parser.add_argument("--score-character-threshold", type=float, default=1.0)
45
- return parser.parse_args()
46
-
47
- def load_labels(dataframe) -> list[str]:
48
- tag_names = dataframe["name"].tolist()
49
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
50
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
51
- return tag_names, general_indexes, character_indexes
52
-
53
  class Predictor:
54
  def __init__(self):
55
  self.model_target_size = None
@@ -66,7 +52,7 @@ class Predictor:
66
 
67
  csv_path, model_path = self.download_model(model_repo)
68
  tags_df = pd.read_csv(csv_path)
69
- self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df)
70
 
71
  model = rt.InferenceSession(model_path)
72
  _, height, width, _ = model.get_inputs()[0].shape
@@ -74,21 +60,19 @@ class Predictor:
74
  self.last_loaded_repo = model_repo
75
  self.model = model
76
 
 
 
 
 
 
 
77
  def prepare_image(self, image):
78
- # Create a white canvas with the same size as the input image
79
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
80
-
81
- # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
83
  image = image.convert("RGBA")
84
-
85
- # Composite the input image onto the canvas
86
  canvas.alpha_composite(image)
87
-
88
- # Convert to RGB (alpha channel is no longer needed)
89
  image = canvas.convert("RGB")
90
 
91
- # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
@@ -96,10 +80,7 @@ class Predictor:
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
99
- # Convert the image to a NumPy array
100
- image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
101
- return np.expand_dims(image_array, axis=0)
102
-
103
 
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
@@ -118,101 +99,58 @@ class Predictor:
118
 
119
  return results
120
 
 
 
 
 
 
 
 
 
121
  def main():
122
- args = parse_args()
123
  predictor = Predictor()
124
 
125
- model_repos = [
126
- SWINV2_MODEL_DSV3_REPO,
127
- CONV_MODEL_DSV3_REPO,
128
- VIT_MODEL_DSV3_REPO,
129
- VIT_LARGE_MODEL_DSV3_REPO,
130
- EVA02_LARGE_MODEL_DSV3_REPO,
131
- # ---
132
- MOAT_MODEL_DSV2_REPO,
133
- SWIN_MODEL_DSV2_REPO,
134
- CONV_MODEL_DSV2_REPO,
135
- CONV2_MODEL_DSV2_REPO,
136
- VIT_MODEL_DSV2_REPO,
137
- # ---
138
- SWINV2_MODEL_IS_DSV1_REPO,
139
- EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
- ]
141
-
142
- predefined_tags = ["loli", "oppai_loli", "minigirl", "babydoll", "monochrome", "greyscale", "speech_bubble", "english_text", "copyright_name", "twitter_username", "artist_name", "watermark", "censored", "bar_censor", "blank_censor", "blur_censor", "light_censor", "mosaic_censoring"] # Default tags to filter out
143
-
144
  with gr.Blocks(title=TITLE) as demo:
145
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
146
  gr.Markdown(DESCRIPTION)
147
 
148
  with gr.Row():
149
  with gr.Column():
150
- image_files = gr.File(
151
- file_types=["image"], label="Upload Images", file_count="multiple",
152
- )
153
-
154
- # Wrap the model selection and sliders in an Accordion
155
- with gr.Accordion("Advanced Settings", open=False): # Collapsible by default
156
- model_repo = gr.Dropdown(
157
- model_repos,
158
- value=VIT_MODEL_DSV3_REPO,
159
- label="Select Model",
160
- )
161
- general_thresh = gr.Slider(
162
- 0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
163
- )
164
- character_thresh = gr.Slider(
165
- 0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
166
- )
167
- filter_tags = gr.Textbox(
168
- value=", ".join(predefined_tags),
169
- label="Filter Tags (comma-separated)",
170
- placeholder="Add tags to filter out (e.g., winter, red, from above)",
171
- lines=3
172
- )
173
-
174
- submit = gr.Button(
175
- value="Process Images", variant="primary"
176
- )
177
 
178
  with gr.Column():
179
- output = gr.Textbox(label="Output", lines=10)
180
-
181
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
 
 
 
 
 
182
  images = [Image.open(file.name) for file in files]
183
- results = predictor.predict(images, model_repo, general_thresh, character_thresh)
184
-
185
- # Parse filter tags
186
- filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
187
-
188
- # Generate formatted output
189
- prompts = []
190
- for i, (general_tags, character_tags) in enumerate(results):
191
- # Replace underscores with spaces for both character and general tags
192
- character_part = ", ".join(
193
- tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
194
- )
195
- general_part = ", ".join(
196
- tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
197
- )
198
-
199
- # Construct the prompt based on the presence of character_part
200
- if character_part:
201
- prompts.append(f"{character_part}, {general_part}")
202
- else:
203
- prompts.append(general_part)
204
-
205
- # Join all prompts with blank lines
206
- return "\n\n".join(prompts)
207
 
208
  submit.click(
209
  process_images,
210
- inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
211
- outputs=output
212
  )
213
 
214
- demo.queue(max_size=10)
215
  demo.launch()
216
 
217
  if __name__ == "__main__":
218
- main()
 
1
  import argparse
2
  import os
3
+ from pathlib import Path
4
 
5
  import gradio as gr
6
  import huggingface_hub
 
8
  import onnxruntime as rt
9
  import pandas as pd
10
  from PIL import Image
11
+ from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
12
+ from tagger.model import load_model_and_transform, process_heatmap
13
 
14
+ TITLE = "WaifuDiffusion Tagger with Heatmap"
15
  DESCRIPTION = """
16
+ Demo for the WaifuDiffusion tagger models with heatmap and grid visualization.
17
  """
18
 
19
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
25
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
26
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
27
 
28
+ MODEL_REPOS = [
29
+ SWINV2_MODEL_DSV3_REPO,
30
+ CONV_MODEL_DSV3_REPO,
31
+ VIT_MODEL_DSV3_REPO,
32
+ VIT_LARGE_MODEL_DSV3_REPO,
33
+ EVA02_LARGE_MODEL_DSV3_REPO,
34
+ ]
35
 
 
 
 
 
 
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class Predictor:
40
  def __init__(self):
41
  self.model_target_size = None
 
52
 
53
  csv_path, model_path = self.download_model(model_repo)
54
  tags_df = pd.read_csv(csv_path)
55
+ self.tag_names, self.general_indexes, self.character_indexes = self.load_labels(tags_df)
56
 
57
  model = rt.InferenceSession(model_path)
58
  _, height, width, _ = model.get_inputs()[0].shape
 
60
  self.last_loaded_repo = model_repo
61
  self.model = model
62
 
63
+ def load_labels(self, dataframe):
64
+ tag_names = dataframe["name"].tolist()
65
+ general_indexes = list(np.where(dataframe["category"] == 0)[0])
66
+ character_indexes = list(np.where(dataframe["category"] == 4)[0])
67
+ return tag_names, general_indexes, character_indexes
68
+
69
  def prepare_image(self, image):
 
70
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
 
 
71
  if image.mode != "RGBA":
72
  image = image.convert("RGBA")
 
 
73
  canvas.alpha_composite(image)
 
 
74
  image = canvas.convert("RGB")
75
 
 
76
  max_dim = max(image.size)
77
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
78
  pad_left = (max_dim - image.width) // 2
 
80
  padded_image.paste(image, (pad_left, pad_top))
81
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
82
 
83
+ return np.expand_dims(np.asarray(padded_image, dtype=np.float32)[:, :, ::-1], axis=0)
 
 
 
84
 
85
  def predict(self, images, model_repo, general_thresh, character_thresh):
86
  self.load_model(model_repo)
 
99
 
100
  return results
101
 
102
+ def generate_heatmap_and_grid(self, image, model_repo, threshold):
103
+ model, transform = load_model_and_transform(model_repo)
104
+ labels = load_labels_hf(model_repo)
105
+ image = preprocess_image(image, (448, 448))
106
+ image = transform(image).unsqueeze(0)
107
+ heatmaps, heatmap_grid, _ = process_heatmap(model, image, labels, threshold)
108
+ return [(x.image, x.label) for x in heatmaps], heatmap_grid
109
+
110
  def main():
 
111
  predictor = Predictor()
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  with gr.Blocks(title=TITLE) as demo:
114
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
115
  gr.Markdown(DESCRIPTION)
116
 
117
  with gr.Row():
118
  with gr.Column():
119
+ image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
120
+ model_repo = gr.Dropdown(MODEL_REPOS, value=VIT_MODEL_DSV3_REPO, label="Select Model")
121
+ threshold = gr.Slider(0, 1, step=0.01, value=0.35, label="Heatmap Threshold")
122
+ general_thresh = gr.Slider(0, 1, step=0.05, value=0.3, label="General Tags Threshold")
123
+ character_thresh = gr.Slider(0, 1, step=0.05, value=1.0, label="Character Tags Threshold")
124
+ submit = gr.Button(value="Process Images", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  with gr.Column():
127
+ with gr.Tab(label="Tags"):
128
+ output_tags = gr.Textbox(label="Output Tags", lines=10)
129
+ with gr.Tab(label="Heatmaps"):
130
+ heatmap_gallery = gr.Gallery(label="Heatmap Gallery")
131
+ with gr.Tab(label="Grid"):
132
+ heatmap_grid = gr.Image(label="Heatmap Grid")
133
+
134
+ def process_images(files, model_repo, general_thresh, character_thresh, threshold):
135
  images = [Image.open(file.name) for file in files]
136
+ tag_results = predictor.predict(images, model_repo, general_thresh, character_thresh)
137
+ heatmap_results, grid_result = predictor.generate_heatmap_and_grid(images[0], model_repo, threshold)
138
+
139
+ tag_output = []
140
+ for general_tags, character_tags in tag_results:
141
+ general_str = ", ".join(general_tags)
142
+ character_str = ", ".join(character_tags)
143
+ tag_output.append(f"Characters: {character_str}\nGeneral: {general_str}")
144
+
145
+ return "\n\n".join(tag_output), heatmap_results, grid_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  submit.click(
148
  process_images,
149
+ inputs=[image_files, model_repo, general_thresh, character_thresh, threshold],
150
+ outputs=[output_tags, heatmap_gallery, heatmap_grid]
151
  )
152
 
 
153
  demo.launch()
154
 
155
  if __name__ == "__main__":
156
+ main()