ura23 committed on
Commit
8d32a56
·
verified ·
1 Parent(s): dc7b29f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -55
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import argparse
2
  import os
3
- from pathlib import Path
4
 
5
  import gradio as gr
6
  import huggingface_hub
@@ -8,12 +7,10 @@ import numpy as np
8
  import onnxruntime as rt
9
  import pandas as pd
10
  from PIL import Image
11
- from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
12
- from tagger.model import load_model_and_transform, process_heatmap
13
 
14
- TITLE = "WaifuDiffusion Tagger with Heatmap"
15
  DESCRIPTION = """
16
- Demo for the WaifuDiffusion tagger models with heatmap and grid visualization.
17
  """
18
 
19
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
@@ -25,17 +22,34 @@ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
25
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
26
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
27
 
28
- MODEL_REPOS = [
29
- SWINV2_MODEL_DSV3_REPO,
30
- CONV_MODEL_DSV3_REPO,
31
- VIT_MODEL_DSV3_REPO,
32
- VIT_LARGE_MODEL_DSV3_REPO,
33
- EVA02_LARGE_MODEL_DSV3_REPO,
34
- ]
35
 
 
 
 
 
 
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class Predictor:
40
  def __init__(self):
41
  self.model_target_size = None
@@ -52,7 +66,7 @@ class Predictor:
52
 
53
  csv_path, model_path = self.download_model(model_repo)
54
  tags_df = pd.read_csv(csv_path)
55
- self.tag_names, self.general_indexes, self.character_indexes = self.load_labels(tags_df)
56
 
57
  model = rt.InferenceSession(model_path)
58
  _, height, width, _ = model.get_inputs()[0].shape
@@ -60,19 +74,21 @@ class Predictor:
60
  self.last_loaded_repo = model_repo
61
  self.model = model
62
 
63
- def load_labels(self, dataframe):
64
- tag_names = dataframe["name"].tolist()
65
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
66
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
67
- return tag_names, general_indexes, character_indexes
68
-
69
  def prepare_image(self, image):
 
70
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
 
 
71
  if image.mode != "RGBA":
72
  image = image.convert("RGBA")
 
 
73
  canvas.alpha_composite(image)
 
 
74
  image = canvas.convert("RGB")
75
 
 
76
  max_dim = max(image.size)
77
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
78
  pad_left = (max_dim - image.width) // 2
@@ -80,7 +96,10 @@ class Predictor:
80
  padded_image.paste(image, (pad_left, pad_top))
81
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
82
 
83
- return np.expand_dims(np.asarray(padded_image, dtype=np.float32)[:, :, ::-1], axis=0)
 
 
 
84
 
85
  def predict(self, images, model_repo, general_thresh, character_thresh):
86
  self.load_model(model_repo)
@@ -99,58 +118,101 @@ class Predictor:
99
 
100
  return results
101
 
102
- def generate_heatmap_and_grid(self, image, model_repo, threshold):
103
- model, transform = load_model_and_transform(model_repo)
104
- labels = load_labels_hf(model_repo)
105
- image = preprocess_image(image, (448, 448))
106
- image = transform(image).unsqueeze(0)
107
- heatmaps, heatmap_grid, _ = process_heatmap(model, image, labels, threshold)
108
- return [(x.image, x.label) for x in heatmaps], heatmap_grid
109
-
110
  def main():
 
111
  predictor = Predictor()
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  with gr.Blocks(title=TITLE) as demo:
114
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
115
  gr.Markdown(DESCRIPTION)
116
 
117
  with gr.Row():
118
  with gr.Column():
119
- image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
120
- model_repo = gr.Dropdown(MODEL_REPOS, value=VIT_MODEL_DSV3_REPO, label="Select Model")
121
- threshold = gr.Slider(0, 1, step=0.01, value=0.35, label="Heatmap Threshold")
122
- general_thresh = gr.Slider(0, 1, step=0.05, value=0.3, label="General Tags Threshold")
123
- character_thresh = gr.Slider(0, 1, step=0.05, value=1.0, label="Character Tags Threshold")
124
- submit = gr.Button(value="Process Images", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  with gr.Column():
127
- with gr.Tab(label="Tags"):
128
- output_tags = gr.Textbox(label="Output Tags", lines=10)
129
- with gr.Tab(label="Heatmaps"):
130
- heatmap_gallery = gr.Gallery(label="Heatmap Gallery")
131
- with gr.Tab(label="Grid"):
132
- heatmap_grid = gr.Image(label="Heatmap Grid")
133
-
134
- def process_images(files, model_repo, general_thresh, character_thresh, threshold):
135
- images = [Image.open(file.name) for file in files]
136
- tag_results = predictor.predict(images, model_repo, general_thresh, character_thresh)
137
- heatmap_results, grid_result = predictor.generate_heatmap_and_grid(images[0], model_repo, threshold)
138
 
139
- tag_output = []
140
- for general_tags, character_tags in tag_results:
141
- general_str = ", ".join(general_tags)
142
- character_str = ", ".join(character_tags)
143
- tag_output.append(f"Characters: {character_str}\nGeneral: {general_str}")
144
-
145
- return "\n\n".join(tag_output), heatmap_results, grid_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  submit.click(
148
  process_images,
149
- inputs=[image_files, model_repo, general_thresh, character_thresh, threshold],
150
- outputs=[output_tags, heatmap_gallery, heatmap_grid]
151
  )
152
 
 
153
  demo.launch()
154
 
155
  if __name__ == "__main__":
156
- main()
 
1
  import argparse
2
  import os
 
3
 
4
  import gradio as gr
5
  import huggingface_hub
 
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
 
 
10
 
11
+ TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
13
+ Demo for the WaifuDiffusion tagger models
14
  """
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
22
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
 
25
+ # Dataset v2 series of models:
26
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
 
31
 
32
+ # IdolSankaku series of models:
33
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
+
36
+ # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
38
  LABEL_FILENAME = "selected_tags.csv"
39
 
40
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options that configure the UI sliders.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior callers relied on before this parameter existed.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold`` and
        ``score_character_threshold`` as floats.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.3)
    parser.add_argument("--score-character-threshold", type=float, default=1.0)
    return parser.parse_args(argv)
46
+
47
def load_labels(dataframe: pd.DataFrame) -> tuple[list[str], list[np.integer], list[np.integer]]:
    """Split a tag table into tag names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column and a ``category`` column.

    Returns:
        A ``(tag_names, general_indexes, character_indexes)`` tuple:
        every value of ``name`` in row order, the positional indexes of the
        rows whose ``category`` equals 0, and the positional indexes of the
        rows whose ``category`` equals 4.
    """
    tag_names = dataframe["name"].tolist()
    categories = dataframe["category"]
    # np.where on a boolean mask yields a 1-tuple; [0] is the position array.
    general_indexes = list(np.where(categories == 0)[0])
    character_indexes = list(np.where(categories == 4)[0])
    return tag_names, general_indexes, character_indexes
52
+
53
  class Predictor:
54
  def __init__(self):
55
  self.model_target_size = None
 
66
 
67
  csv_path, model_path = self.download_model(model_repo)
68
  tags_df = pd.read_csv(csv_path)
69
+ self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df)
70
 
71
  model = rt.InferenceSession(model_path)
72
  _, height, width, _ = model.get_inputs()[0].shape
 
74
  self.last_loaded_repo = model_repo
75
  self.model = model
76
 
 
 
 
 
 
 
77
  def prepare_image(self, image):
78
+ # Create a white canvas with the same size as the input image
79
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
80
+
81
+ # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
83
  image = image.convert("RGBA")
84
+
85
+ # Composite the input image onto the canvas
86
  canvas.alpha_composite(image)
87
+
88
+ # Convert to RGB (alpha channel is no longer needed)
89
  image = canvas.convert("RGB")
90
 
91
+ # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
 
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
99
+ # Convert the image to a NumPy array
100
+ image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
101
+ return np.expand_dims(image_array, axis=0)
102
+
103
 
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
 
118
 
119
  return results
120
 
 
 
 
 
 
 
 
 
121
  def main():
122
+ args = parse_args()
123
  predictor = Predictor()
124
 
125
+ model_repos = [
126
+ SWINV2_MODEL_DSV3_REPO,
127
+ CONV_MODEL_DSV3_REPO,
128
+ VIT_MODEL_DSV3_REPO,
129
+ VIT_LARGE_MODEL_DSV3_REPO,
130
+ EVA02_LARGE_MODEL_DSV3_REPO,
131
+ # ---
132
+ MOAT_MODEL_DSV2_REPO,
133
+ SWIN_MODEL_DSV2_REPO,
134
+ CONV_MODEL_DSV2_REPO,
135
+ CONV2_MODEL_DSV2_REPO,
136
+ VIT_MODEL_DSV2_REPO,
137
+ # ---
138
+ SWINV2_MODEL_IS_DSV1_REPO,
139
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
+ ]
141
+
142
+ predefined_tags = ["loli", "oppai_loli", "minigirl", "babydoll", "monochrome", "greyscale", "speech_bubble", "english_text", "copyright_name", "twitter_username", "artist_name", "watermark", "censored", "bar_censor", "blank_censor", "blur_censor", "light_censor", "mosaic_censoring"] # Default tags to filter out
143
+
144
  with gr.Blocks(title=TITLE) as demo:
145
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
146
  gr.Markdown(DESCRIPTION)
147
 
148
  with gr.Row():
149
  with gr.Column():
150
+ image_files = gr.File(
151
+ file_types=["image"], label="Upload Images", file_count="multiple",
152
+ )
153
+
154
+ # Wrap the model selection and sliders in an Accordion
155
+ with gr.Accordion("Advanced Settings", open=False): # Collapsible by default
156
+ model_repo = gr.Dropdown(
157
+ model_repos,
158
+ value=VIT_MODEL_DSV3_REPO,
159
+ label="Select Model",
160
+ )
161
+ general_thresh = gr.Slider(
162
+ 0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
163
+ )
164
+ character_thresh = gr.Slider(
165
+ 0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
166
+ )
167
+ filter_tags = gr.Textbox(
168
+ value=", ".join(predefined_tags),
169
+ label="Filter Tags (comma-separated)",
170
+ placeholder="Add tags to filter out (e.g., winter, red, from above)",
171
+ lines=3
172
+ )
173
+
174
+ submit = gr.Button(
175
+ value="Process Images", variant="primary"
176
+ )
177
 
178
  with gr.Column():
179
+ output = gr.Textbox(label="Output", lines=10)
 
 
 
 
 
 
 
 
 
 
180
 
181
def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
    """Tag the uploaded images and format one prompt per image.

    Args:
        files: Uploaded file objects; each is opened through its ``.name``
            path. ``None`` or an empty list (nothing uploaded) yields an
            empty result instead of an error.
        model_repo: Repository id forwarded to ``predictor.predict``.
        general_thresh: General-tag threshold forwarded to the predictor.
        character_thresh: Character-tag threshold forwarded to the predictor.
        filter_tags: Comma-separated tags to drop from the output. Matching
            ignores case and treats ``_`` and spaces as the same character,
            so both ``from_above`` and ``from above`` filter that tag.

    Returns:
        One line per image (character tags first, then general tags, with
        underscores shown as spaces), images separated by a blank line.
    """
    if not files:
        return ""

    images = [Image.open(file.name) for file in files]
    results = predictor.predict(images, model_repo, general_thresh, character_thresh)

    def canonical(tag):
        # Tags arrive in underscore form while users may type spaces;
        # compare both sides in the same lowercase, space-separated form.
        return tag.lower().replace("_", " ")

    filter_set = {canonical(entry.strip()) for entry in (filter_tags or "").split(",")}
    filter_set.discard("")  # produced by empty input or stray commas

    prompts = []
    for general_tags, character_tags in results:
        # Joining the surviving tags in one pass avoids a dangling ", "
        # when every general tag has been filtered out.
        kept = [
            tag.replace("_", " ")
            for tag in (*character_tags, *general_tags)
            if canonical(tag) not in filter_set
        ]
        prompts.append(", ".join(kept))

    # Join all prompts with blank lines
    return "\n\n".join(prompts)
207
 
208
  submit.click(
209
  process_images,
210
+ inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
211
+ outputs=output
212
  )
213
 
214
+ demo.queue(max_size=10)
215
  demo.launch()
216
 
217
  if __name__ == "__main__":
218
+ main()