ura23 committed on
Commit fff4a3d · verified · 1 Parent(s): 56f8cd8

Update app.py

Files changed (1)
  1. app.py +104 -177
app.py CHANGED
@@ -1,6 +1,5 @@
import argparse
import os
-
import gradio as gr
import huggingface_hub
import numpy as np
@@ -9,42 +8,24 @@ import pandas as pd
from PIL import Image

TITLE = "WaifuDiffusion Tagger"
- DESCRIPTION = """
- Demo for the WaifuDiffusion tagger models
- """
+ DESCRIPTION = "Demo for the WaifuDiffusion tagger models"

HF_TOKEN = os.environ.get("HF_TOKEN", "")

- # Dataset v3 series of models:
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+ # Model Repositories
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
-
- # Dataset v2 series of models:
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

- # IdolSankaku series of models:
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
-
- # Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

- def parse_args() -> argparse.Namespace:
+ def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.3)
    parser.add_argument("--score-character-threshold", type=float, default=1.0)
    return parser.parse_args()

- def load_labels(dataframe) -> list[str]:
+ def load_labels(dataframe):
    tag_names = dataframe["name"].tolist()
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
@@ -75,20 +56,6 @@ class Predictor:
        self.model = model

    def prepare_image(self, image):
-         # Create a white canvas with the same size as the input image
-         canvas = Image.new("RGBA", image.size, (255, 255, 255))
-
-         # Ensure the input image has an alpha channel for compositing
-         if image.mode != "RGBA":
-             image = image.convert("RGBA")
-
-         # Composite the input image onto the canvas
-         canvas.alpha_composite(image)
-
-         # Convert to RGB (alpha channel is no longer needed)
-         image = canvas.convert("RGB")
-
-         # Resize the image to a square of size (model_target_size x model_target_size)
        max_dim = max(image.size)
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        pad_left = (max_dim - image.width) // 2
@@ -96,10 +63,7 @@ class Predictor:
        padded_image.paste(image, (pad_left, pad_top))
        padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)

-         # Convert the image to a NumPy array
-         image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
-         return np.expand_dims(image_array, axis=0)
-
+         return np.expand_dims(np.asarray(padded_image, dtype=np.float32)[:, :, ::-1], axis=0)

    def predict(self, images, model_repo, general_thresh, character_thresh):
        self.load_model(model_repo)
@@ -118,139 +82,102 @@ class Predictor:

        return results

- def main():
-     args = parse_args()
-     predictor = Predictor()
-
-     model_repos = [
-         SWINV2_MODEL_DSV3_REPO,
-         CONV_MODEL_DSV3_REPO,
-         VIT_MODEL_DSV3_REPO,
-         VIT_LARGE_MODEL_DSV3_REPO,
-         EVA02_LARGE_MODEL_DSV3_REPO,
-         # ---
-         MOAT_MODEL_DSV2_REPO,
-         SWIN_MODEL_DSV2_REPO,
-         CONV_MODEL_DSV2_REPO,
-         CONV2_MODEL_DSV2_REPO,
-         VIT_MODEL_DSV2_REPO,
-         # ---
-         SWINV2_MODEL_IS_DSV1_REPO,
-         EVA02_LARGE_MODEL_IS_DSV1_REPO,
-     ]
-
-     predefined_tags = ["loli",
-                        "oppai_loli",
-                        "onee-shota",
-                        "incest",
-                        "furry",
-                        "furry_female",
-                        "shota",
-                        "male_focus",
-                        "signature",
-                        "lolita_hairband",
-                        "otoko_no_ko",
-                        "minigirl",
-                        "patreon_username",
-                        "babydoll",
-                        "monochrome",
-                        "happy_birthday",
-                        "happy_new_year",
-                        "dated",
-                        "thought_bubble",
-                        "greyscale",
-                        "speech_bubble",
-                        "english_text",
-                        "copyright_name",
-                        "twitter_username",
-                        "patreon username",
-                        "patreon logo",
-                        "cover",
-                        "content_rating"
-                        "cover_page",
-                        "doujin_cover",
-                        "sex",
-                        "artist_name",
-                        "watermark",
-                        "censored",
-                        "bar_censor",
-                        "blank_censor",
-                        "blur_censor",
-                        "light_censor",
-                        "mosaic_censoring"]
-
-     with gr.Blocks(title=TITLE) as demo:
-         gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
-         gr.Markdown(DESCRIPTION)
-
-         with gr.Row():
-             with gr.Column():
-                 image_files = gr.File(
-                     file_types=["image"], label="Upload Images", file_count="multiple",
-                 )
-
-                 # Wrap the model selection and sliders in an Accordion
-                 with gr.Accordion("Advanced Settings", open=False):  # Collapsible by default
-                     model_repo = gr.Dropdown(
-                         model_repos,
-                         value=VIT_MODEL_DSV3_REPO,
-                         label="Select Model",
-                     )
-                     general_thresh = gr.Slider(
-                         0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
-                     )
-                     character_thresh = gr.Slider(
-                         0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
-                     )
-                     filter_tags = gr.Textbox(
-                         value=", ".join(predefined_tags),
-                         label="Filter Tags (comma-separated)",
-                         placeholder="Add tags to filter out (e.g., winter, red, from above)",
-                         lines=3
-                     )
-
-                 submit = gr.Button(
-                     value="Process Images", variant="primary"
-                 )
-
-             with gr.Column():
-                 output = gr.Textbox(label="Output", lines=10)
-
-         def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
-             images = [Image.open(file.name) for file in files]
-             results = predictor.predict(images, model_repo, general_thresh, character_thresh)
-
-             # Parse filter tags
-             filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
-
-             # Generate formatted output
-             prompts = []
-             for i, (general_tags, character_tags) in enumerate(results):
-                 # Replace underscores with spaces for both character and general tags
-                 character_part = ", ".join(
-                     tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
-                 )
-                 general_part = ", ".join(
-                     tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
-                 )
-
-                 # Construct the prompt based on the presence of character_part
-                 if character_part:
-                     prompts.append(f"{character_part}, {general_part}")
-                 else:
-                     prompts.append(general_part)
-
-             # Join all prompts with blank lines
-             return "\n\n".join(prompts)
-
-         submit.click(
-             process_images,
-             inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
-             outputs=output
-         )
-
-     demo.queue(max_size=10)
-     demo.launch()
-
- if __name__ == "__main__":
-     main()
+ predictor = Predictor()
+
+ def parse_replacement_rules(rules_text):
+     rules = {}
+     for line in rules_text.strip().split("\n"):
+         if "->" in line:
+             old_tags, new_tags = map(str.strip, line.split("->"))
+             old_tags_list = tuple(map(str.strip, old_tags.lower().split(",")))
+             new_tags_list = [tag.strip() for tag in new_tags.split(",")]
+             rules[old_tags_list] = new_tags_list
+     return rules
+
+ def parse_fallback_rules(fallback_text):
+     fallback_rules = {}
+     for line in fallback_text.strip().split("\n"):
+         if "->" in line:
+             expected_tags, fallback_tag = map(str.strip, line.split("->"))
+             expected_tags_list = tuple(map(str.strip, expected_tags.lower().split(",")))
+             fallback_rules[expected_tags_list] = fallback_tag.strip()
+     return fallback_rules
+
+ def apply_replacements(tags, replacement_rules):
+     tags_set = set(tags)
+
+     for old_tags, new_tags in replacement_rules.items():
+         if set(old_tags).issubset(tags_set):
+             tags_set.difference_update(old_tags)
+             tags_set.update(new_tags)
+
+     return list(tags_set)
+
+ def apply_fallbacks(tags, fallback_rules):
+     tags_set = set(tags)
+
+     for expected_tags, fallback_tag in fallback_rules.items():
+         if not any(tag in tags_set for tag in expected_tags):
+             tags_set.add(fallback_tag)
+
+     return list(tags_set)
+
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text, fallback_rules_text):
+     images = [Image.open(file.name) for file in files]
+     results = predictor.predict(images, model_repo, general_thresh, character_thresh)
+
+     filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
+     replacement_rules = parse_replacement_rules(replacement_rules_text)
+     fallback_rules = parse_fallback_rules(fallback_rules_text)
+
+     prompts = []
+     for general_tags, character_tags in results:
+         general_tags = apply_replacements(general_tags, replacement_rules)
+         character_tags = apply_replacements(character_tags, replacement_rules)
+
+         general_tags = apply_fallbacks(general_tags, fallback_rules)
+         character_tags = apply_fallbacks(character_tags, fallback_rules)
+
+         general_tags = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
+         character_tags = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]
+
+         if character_tags:
+             prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}")
+         else:
+             prompts.append(", ".join(general_tags))
+
+     return "\n\n".join(prompts)
+
+ args = parse_args()
+
+ with gr.Blocks(title=TITLE) as demo:
+     gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 model_repo = gr.Dropdown([VIT_MODEL_DSV3_REPO], value=VIT_MODEL_DSV3_REPO, label="Select Model")
+                 general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
+                 character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
+                 filter_tags = gr.Textbox(label="Filter Tags (comma-separated)", lines=3)
+
+             submit = gr.Button(value="Process Images", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(label="Output", lines=10)
+
+             with gr.Accordion("Tag Replacements", open=False):
+                 replacement_rules_text = gr.Textbox(label="Replacement Rules", placeholder="e.g., 1boy -> 1girl", lines=5)
+
+             with gr.Accordion("Fallback Rules", open=False):
+                 fallback_rules_text = gr.Textbox(label="Fallback Rules", placeholder="e.g., sad, happy -> smile", lines=5)
+
+     submit.click(process_images,
+                  inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text, fallback_rules_text],
+                  outputs=output)
+
+ demo.queue(max_size=10)
+ demo.launch()
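
Note on the new rule syntax: the commit drops the hard-coded predefined_tags filter list and instead drives tag rewriting from two free-text boxes. The sketch below is not part of the commit; it restates the logic of parse_replacement_rules and apply_replacements from the diff above with made-up tags, to show what a rule line like "a, b -> c, d" does, and the fallback check at the end mirrors apply_fallbacks and the "sad, happy -> smile" placeholder.

# Sketch only: mirrors parse_replacement_rules/apply_replacements from the diff above, with hypothetical tags.
def parse_replacement_rules(rules_text):
    rules = {}
    for line in rules_text.strip().split("\n"):
        if "->" in line:
            old_tags, new_tags = map(str.strip, line.split("->"))
            rules[tuple(map(str.strip, old_tags.lower().split(",")))] = [t.strip() for t in new_tags.split(",")]
    return rules

tags = {"1boy", "short_hair", "outdoors"}  # hypothetical predicted tags
rules = parse_replacement_rules("1boy -> 1girl\nshort_hair, outdoors -> long_hair, indoors")
for old_tags, new_tags in rules.items():
    if set(old_tags).issubset(tags):       # every left-hand tag must be present
        tags.difference_update(old_tags)   # drop the matched tags
        tags.update(new_tags)              # add the replacement tags
print(sorted(tags))                        # ['1girl', 'indoors', 'long_hair']

# A fallback rule adds a tag only when none of the expected tags were predicted,
# e.g. the UI placeholder "sad, happy -> smile":
if not any(t in tags for t in ("sad", "happy")):
    tags.add("smile")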