ura23 commited on
Commit
4218142
·
verified ·
1 Parent(s): a6786ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -33
app.py CHANGED
@@ -17,12 +17,21 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 
 
 
 
20
 
21
  # Dataset v2 series of models:
22
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
 
 
 
 
23
 
24
  # IdolSankaku series of models:
25
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
 
26
 
27
  # Files to download from the repos
28
  MODEL_FILENAME = "model.onnx"
@@ -115,14 +124,58 @@ def main():
115
 
116
  model_repos = [
117
  SWINV2_MODEL_DSV3_REPO,
 
 
 
 
118
  # ---
119
  MOAT_MODEL_DSV2_REPO,
 
 
 
 
120
  # ---
121
  SWINV2_MODEL_IS_DSV1_REPO,
 
122
  ]
123
 
124
- predefined_tags = ["monochrome",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  "happy_birthday",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  "light_censor",
127
  "mosaic_censoring"]
128
 
@@ -163,43 +216,82 @@ def main():
163
  with gr.Column():
164
  output = gr.Textbox(label="Output", lines=10)
165
 
166
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, expected_tags_input, default_tag_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  images = [Image.open(file.name) for file in files]
168
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
169
-
170
  # Parse filter tags
171
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
172
- expected_tags = set(tag.strip().lower() for tag in expected_tags_input.split(",")) # Define expected tags dynamically
173
- default_tag = default_tag_input.strip() # Define default tag dynamically
174
-
 
175
  # Generate formatted output
176
  prompts = []
177
- for i, (general_tags, character_tags) in enumerate(results):
178
- # Replace underscores with spaces for both character and general tags
179
- character_part = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]
180
- general_part = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
181
-
182
- # Check if expected tags are present
183
- if not any(tag in expected_tags for tag in character_part + general_part):
184
- general_part.append(default_tag)
185
-
186
- # Construct the prompt based on the presence of character_part
187
- prompt = ", ".join(character_part + general_part)
188
- prompts.append(prompt)
189
-
190
- # Join all prompts with blank lines
 
191
  return "\n\n".join(prompts)
192
 
193
-
194
-
195
- submit.click(
196
- process_images,
197
- inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
198
- outputs=output
199
- )
200
-
201
- demo.queue(max_size=10)
202
- demo.launch()
203
-
204
- if __name__ == "__main__":
205
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
 
25
  # Dataset v2 series of models:
26
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
 
32
  # IdolSankaku series of models:
33
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
 
36
  # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
 
124
 
125
  model_repos = [
126
  SWINV2_MODEL_DSV3_REPO,
127
+ CONV_MODEL_DSV3_REPO,
128
+ VIT_MODEL_DSV3_REPO,
129
+ VIT_LARGE_MODEL_DSV3_REPO,
130
+ EVA02_LARGE_MODEL_DSV3_REPO,
131
  # ---
132
  MOAT_MODEL_DSV2_REPO,
133
+ SWIN_MODEL_DSV2_REPO,
134
+ CONV_MODEL_DSV2_REPO,
135
+ CONV2_MODEL_DSV2_REPO,
136
+ VIT_MODEL_DSV2_REPO,
137
  # ---
138
  SWINV2_MODEL_IS_DSV1_REPO,
139
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
  ]
141
 
142
+ predefined_tags = ["loli",
143
+ "oppai_loli",
144
+ "onee-shota",
145
+ "incest",
146
+ "furry",
147
+ "furry_female",
148
+ "shota",
149
+ "male_focus",
150
+ "signature",
151
+ "lolita_hairband",
152
+ "otoko_no_ko",
153
+ "minigirl",
154
+ "patreon_username",
155
+ "babydoll",
156
+ "monochrome",
157
  "happy_birthday",
158
+ "happy_new_year",
159
+ "dated",
160
+ "thought_bubble",
161
+ "greyscale",
162
+ "speech_bubble",
163
+ "english_text",
164
+ "copyright_name",
165
+ "twitter_username",
166
+ "patreon username",
167
+ "patreon logo",
168
+ "cover",
169
+ "content_rating"
170
+ "cover_page",
171
+ "doujin_cover",
172
+ "sex",
173
+ "artist_name",
174
+ "watermark",
175
+ "censored",
176
+ "bar_censor",
177
+ "blank_censor",
178
+ "blur_censor",
179
  "light_censor",
180
  "mosaic_censoring"]
181
 
 
216
  with gr.Column():
217
  output = gr.Textbox(label="Output", lines=10)
218
 
219
+ def parse_replacement_rules(rules_text):
220
+ """Parse user-defined tag replacement rules into a dictionary."""
221
+ rules = {}
222
+ for line in rules_text.strip().split("\n"):
223
+ if "->" in line:
224
+ old_tags, new_tags = map(str.strip, line.split("->"))
225
+ old_tags_list = tuple(map(str.strip, old_tags.lower().split(",")))
226
+ new_tags_list = [tag.strip() for tag in new_tags.split(",")]
227
+ rules[old_tags_list] = new_tags_list
228
+ return rules
229
+
230
+ def apply_replacements(tags, replacement_rules):
231
+ """Apply replacement rules to a set of tags."""
232
+ tags_set = set(tags)
233
+
234
+ for old_tags, new_tags in replacement_rules.items():
235
+ if set(old_tags).issubset(tags_set): # If all old tags exist in the set
236
+ tags_set.difference_update(old_tags) # Remove old tags
237
+ tags_set.update(new_tags) # Add new ones
238
+
239
+ return list(tags_set)
240
+
241
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text):
242
  images = [Image.open(file.name) for file in files]
243
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
244
+
245
  # Parse filter tags
246
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
247
+
248
+ # Parse user-defined replacements
249
+ replacement_rules = parse_replacement_rules(replacement_rules_text)
250
+
251
  # Generate formatted output
252
  prompts = []
253
+ for general_tags, character_tags in results:
254
+ # Apply replacements
255
+ general_tags = apply_replacements(general_tags, replacement_rules)
256
+ character_tags = apply_replacements(character_tags, replacement_rules)
257
+
258
+ # Remove filtered tags and format
259
+ general_tags = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
260
+ character_tags = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]
261
+
262
+ # Construct final prompt
263
+ if character_tags:
264
+ prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}")
265
+ else:
266
+ prompts.append(", ".join(general_tags))
267
+
268
  return "\n\n".join(prompts)
269
 
270
+ # Modify UI to include replacement rules input
271
+ with gr.Blocks(title=TITLE) as demo:
272
+ gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
273
+ gr.Markdown(DESCRIPTION)
274
+
275
+ with gr.Row():
276
+ with gr.Column():
277
+ image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
278
+
279
+ with gr.Accordion("Advanced Settings", open=False):
280
+ model_repo = gr.Dropdown(model_repos, value=VIT_MODEL_DSV3_REPO, label="Select Model")
281
+ general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
282
+ character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
283
+ filter_tags = gr.Textbox(value=", ".join(predefined_tags), label="Filter Tags (comma-separated)", lines=3)
284
+
285
+ submit = gr.Button(value="Process Images", variant="primary")
286
+
287
+ with gr.Column():
288
+ output = gr.Textbox(label="Output", lines=10)
289
+
290
+ # Separate input for tag replacements
291
+ with gr.Accordion("Tag Replacements", open=False):
292
+ replacement_rules_text = gr.Textbox(label="Enter replacement rules (one per line)", placeholder="e.g.,\n1boy -> 1girl\nwinter, indoors, living room -> summer, outdoors", lines=5)
293
+
294
+ submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text], outputs=output)
295
+
296
+ demo.queue(max_size=10)
297
+ demo.launch()