ura23 committed on
Commit
4b081a7
·
verified ·
1 Parent(s): b038885

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -101
app.py CHANGED
@@ -15,6 +15,25 @@ Demo for the WaifuDiffusion tagger models
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  MODEL_FILENAME = "model.onnx"
19
  LABEL_FILENAME = "selected_tags.csv"
20
 
@@ -25,39 +44,11 @@ def parse_args() -> argparse.Namespace:
25
  parser.add_argument("--score-character-threshold", type=float, default=1.0)
26
  return parser.parse_args()
27
 
28
- def parse_replacement_rules(rules_text):
29
- rules = {}
30
- for line in rules_text.strip().split("\n"):
31
- if "->" in line:
32
- old_tags, new_tags = map(str.strip, line.split("->"))
33
- old_tags_list = tuple(map(str.strip, old_tags.lower().split(",")))
34
- new_tags_list = [tag.strip() for tag in new_tags.split(",")]
35
- rules[old_tags_list] = new_tags_list
36
- return rules
37
-
38
- def parse_fallback_rules(fallback_text):
39
- fallback_rules = {}
40
- for line in fallback_text.strip().split("\n"):
41
- if "->" in line:
42
- expected_tags, fallback_tag = map(str.strip, line.split("->"))
43
- expected_tags_list = tuple(map(str.strip, expected_tags.lower().split(",")))
44
- fallback_rules[expected_tags_list] = fallback_tag.strip()
45
- return fallback_rules
46
-
47
- def apply_replacements(tags, replacement_rules):
48
- tags_set = set(tags)
49
- for old_tags, new_tags in replacement_rules.items():
50
- if set(old_tags).issubset(tags_set):
51
- tags_set.difference_update(old_tags)
52
- tags_set.update(new_tags)
53
- return list(tags_set)
54
-
55
- def apply_fallbacks(tags, fallback_rules):
56
- tags_set = set(tags)
57
- for expected_tags, fallback_tag in fallback_rules.items():
58
- if not any(tag in tags_set for tag in expected_tags):
59
- tags_set.add(fallback_tag)
60
- return list(tags_set)
61
 
62
  class Predictor:
63
  def __init__(self):
@@ -75,7 +66,7 @@ class Predictor:
75
 
76
  csv_path, model_path = self.download_model(model_repo)
77
  tags_df = pd.read_csv(csv_path)
78
- self.tag_names, self.general_indexes, self.character_indexes = tags_df["name"].tolist(), list(np.where(tags_df["category"] == 0)[0]), list(np.where(tags_df["category"] == 4)[0])
79
 
80
  model = rt.InferenceSession(model_path)
81
  _, height, width, _ = model.get_inputs()[0].shape
@@ -83,70 +74,183 @@ class Predictor:
83
  self.last_loaded_repo = model_repo
84
  self.model = model
85
 
86
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text, fallback_rules_text):
87
- images = [Image.open(file.name) for file in files]
88
- results = predictor.predict(images, model_repo, general_thresh, character_thresh)
89
-
90
- # Predefined examples
91
- predefined_filter_tags = "watermark" # This tag will be removed if detected
92
- predefined_replacement_rules = "1boy -> 1girl" # "1boy" will be replaced with "1girl"
93
- predefined_fallback_rules = "sad, happy -> smile" # If neither "sad" nor "happy" are present, add "smile"
94
-
95
- # Combine predefined rules with user input
96
- filter_tags = f"{predefined_filter_tags}, {filter_tags}".strip()
97
- replacement_rules_text = f"{predefined_replacement_rules}\n{replacement_rules_text}".strip()
98
- fallback_rules_text = f"{predefined_fallback_rules}\n{fallback_rules_text}".strip()
99
-
100
- # Parse user-defined rules
101
- filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
102
- replacement_rules = parse_replacement_rules(replacement_rules_text)
103
- fallback_rules = parse_fallback_rules(fallback_rules_text)
104
-
105
- # Generate formatted output
106
- prompts = []
107
- for general_tags, character_tags in results:
108
- general_tags = apply_replacements(general_tags, replacement_rules)
109
- character_tags = apply_replacements(character_tags, replacement_rules)
110
- general_tags = apply_fallbacks(general_tags, fallback_rules)
111
- character_tags = apply_fallbacks(character_tags, fallback_rules)
112
-
113
- general_tags = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
114
- character_tags = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]
115
-
116
- prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}" if character_tags else ", ".join(general_tags))
117
-
118
- return "\n\n".join(prompts)
119
-
120
- args = parse_args()
121
- predictor = Predictor()
122
-
123
- model_repos = ["SmilingWolf/wd-swinv2-tagger-v3", "SmilingWolf/wd-convnext-tagger-v3", "SmilingWolf/wd-vit-tagger-v3"]
124
-
125
- with gr.Blocks(title=TITLE) as demo:
126
- gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
127
- gr.Markdown(DESCRIPTION)
128
-
129
- with gr.Row():
130
- with gr.Column():
131
- image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
132
-
133
- with gr.Accordion("Advanced Settings", open=False):
134
- model_repo = gr.Dropdown(model_repos, value="SmilingWolf/wd-vit-tagger-v3", label="Select Model")
135
- general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
136
- character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
137
-
138
- submit = gr.Button(value="Process Images", variant="primary")
139
-
140
- with gr.Column():
141
- output = gr.Textbox(label="Output", lines=10)
142
-
143
- with gr.Accordion("Tag Replacements", open=False):
144
- replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5, value="1boy -> 1girl")
145
-
146
- with gr.Accordion("Fallback Rules", open=False):
147
- fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5, value="sad, happy -> smile")
148
-
149
- submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)
150
-
151
- demo.queue()
152
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
18
+ # Dataset v3 series of models:
19
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
+
25
+ # Dataset v2 series of models:
26
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
+
32
+ # IdolSankaku series of models:
33
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
+
36
+ # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
38
  LABEL_FILENAME = "selected_tags.csv"
39
 
 
44
  parser.add_argument("--score-character-threshold", type=float, default=1.0)
45
  return parser.parse_args()
46
 
47
def load_labels(dataframe) -> tuple[list[str], list[int], list[int]]:
    """Extract tag names and per-category index lists from the tags CSV dataframe.

    Args:
        dataframe: pandas DataFrame with at least "name" and "category" columns
            (the WaifuDiffusion `selected_tags.csv` layout).

    Returns:
        A tuple ``(tag_names, general_indexes, character_indexes)`` where the
        index lists locate the rows with category 0 (general tags) and
        category 4 (character tags) respectively.
    """
    tag_names = dataframe["name"].tolist()
    # Category codes follow the selected_tags.csv convention: 0 = general, 4 = character.
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, general_indexes, character_indexes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  class Predictor:
54
  def __init__(self):
 
66
 
67
  csv_path, model_path = self.download_model(model_repo)
68
  tags_df = pd.read_csv(csv_path)
69
+ self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df)
70
 
71
  model = rt.InferenceSession(model_path)
72
  _, height, width, _ = model.get_inputs()[0].shape
 
74
  self.last_loaded_repo = model_repo
75
  self.model = model
76
 
77
+ def prepare_image(self, image):
78
+ # Create a white canvas with the same size as the input image
79
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
80
+
81
+ # Ensure the input image has an alpha channel for compositing
82
+ if image.mode != "RGBA":
83
+ image = image.convert("RGBA")
84
+
85
+ # Composite the input image onto the canvas
86
+ canvas.alpha_composite(image)
87
+
88
+ # Convert to RGB (alpha channel is no longer needed)
89
+ image = canvas.convert("RGB")
90
+
91
+ # Resize the image to a square of size (model_target_size x model_target_size)
92
+ max_dim = max(image.size)
93
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
+ pad_left = (max_dim - image.width) // 2
95
+ pad_top = (max_dim - image.height) // 2
96
+ padded_image.paste(image, (pad_left, pad_top))
97
+ padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
+
99
+ # Convert the image to a NumPy array
100
+ image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
101
+ return np.expand_dims(image_array, axis=0)
102
+
103
+
104
+ def predict(self, images, model_repo, general_thresh, character_thresh):
105
+ self.load_model(model_repo)
106
+ results = []
107
+
108
+ for image in images:
109
+ image = self.prepare_image(image)
110
+ input_name = self.model.get_inputs()[0].name
111
+ label_name = self.model.get_outputs()[0].name
112
+ preds = self.model.run([label_name], {input_name: image})[0]
113
+
114
+ labels = list(zip(self.tag_names, preds[0].astype(float)))
115
+ general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
116
+ character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
117
+ results.append((general_res, character_res))
118
+
119
+ return results
120
+
121
def main():
    """Build and launch the Gradio batch-tagging demo."""
    args = parse_args()
    predictor = Predictor()

    # All selectable tagger repos, grouped by dataset generation.
    model_repos = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # ---
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
        # ---
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]

    # Tags removed from the output by default; the user can edit the list in
    # the "Filter Tags" textbox before processing.
    # NOTE(review): "patreon username" / "patreon logo" contain a space, but
    # filtering runs BEFORE underscores are replaced with spaces, so these
    # entries never match raw tags — confirm whether "patreon_username" /
    # "patreon_logo" were intended.
    predefined_tags = [
        "loli",
        "oppai_loli",
        "onee-shota",
        "incest",
        "furry",
        "furry_female",
        "shota",
        "male_focus",
        "signature",
        "lolita_hairband",
        "otoko_no_ko",
        "minigirl",
        "patreon_username",
        "babydoll",
        "monochrome",
        "happy_birthday",
        "happy_new_year",
        "dated",
        "thought_bubble",
        "greyscale",
        "speech_bubble",
        "english_text",
        "copyright_name",
        "twitter_username",
        "patreon username",
        "patreon logo",
        "cover",
        "content_rating",  # BUG FIX: a missing comma here previously fused
        "cover_page",      # "content_rating" and "cover_page" into one bogus tag.
        "doujin_cover",
        "sex",
        "artist_name",
        "watermark",
        "censored",
        "bar_censor",
        "blank_censor",
        "blur_censor",
        "light_censor",
        "mosaic_censoring",
    ]

    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                image_files = gr.File(
                    file_types=["image"], label="Upload Images", file_count="multiple",
                )

                # Model selection and thresholds live in a collapsed accordion.
                with gr.Accordion("Advanced Settings", open=False):
                    model_repo = gr.Dropdown(
                        model_repos,
                        value=VIT_MODEL_DSV3_REPO,
                        label="Select Model",
                    )
                    general_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
                    )
                    character_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
                    )
                    filter_tags = gr.Textbox(
                        value=", ".join(predefined_tags),
                        label="Filter Tags (comma-separated)",
                        placeholder="Add tags to filter out (e.g., winter, red, from above)",
                        lines=3
                    )

                submit = gr.Button(
                    value="Process Images", variant="primary"
                )

            with gr.Column():
                output = gr.Textbox(label="Output", lines=10)

        def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
            """Tag each uploaded image and return one filtered prompt per image."""
            images = [Image.open(file.name) for file in files]
            results = predictor.predict(images, model_repo, general_thresh, character_thresh)

            # Parse the comma-separated filter list, normalized to lowercase.
            filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))

            prompts = []
            for general_tags, character_tags in results:
                # Drop filtered tags, then replace underscores with spaces.
                character_part = ", ".join(
                    tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
                )
                general_part = ", ".join(
                    tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
                )

                # Character tags, when present, lead the prompt.
                if character_part:
                    prompts.append(f"{character_part}, {general_part}")
                else:
                    prompts.append(general_part)

            # One prompt per image, separated by blank lines.
            return "\n\n".join(prompts)

        submit.click(
            process_images,
            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
            outputs=output
        )

    demo.queue(max_size=10)
    demo.launch()

if __name__ == "__main__":
    main()