ura23 committed on
Commit afffbea · verified · 1 Parent(s): 242b9de

Update app.py

Files changed (1)
  1. app.py +18 -41
app.py CHANGED
@@ -16,9 +16,9 @@ Demo for the WaifuDiffusion tagger models
 HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
 # Dataset v3 series of models:
-VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
 SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
 VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
 EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
@@ -123,9 +123,9 @@ def main():
     predictor = Predictor()
 
     model_repos = [
-        VIT_MODEL_DSV3_REPO,
         SWINV2_MODEL_DSV3_REPO,
         CONV_MODEL_DSV3_REPO,
+        VIT_MODEL_DSV3_REPO,
         VIT_LARGE_MODEL_DSV3_REPO,
         EVA02_LARGE_MODEL_DSV3_REPO,
         # ---
@@ -177,12 +177,7 @@ def main():
         "blank_censor",
         "blur_censor",
         "light_censor",
-        "mosaic_censoring"],
-
-    predefined_tags2 = [
-        "big, small:medium", # If either "big" or "small" is missing, add "medium"
-        "small hand, large hand:medium hand" # If either "small hand" or "large hand" is missing, add "medium hand"
-    ]
+        "mosaic_censoring"]
 
     with gr.Blocks(title=TITLE) as demo:
         gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
@@ -213,33 +208,21 @@ def main():
                    placeholder="Add tags to filter out (e.g., winter, red, from above)",
                    lines=5
                )
-                custom_tags = gr.Textbox(
-                    value=", ".join(predefined_tags2),
-                    label="Custom Tags (comma-separated)",
-                    placeholder="Enter custom tags to ensure they are in the output (e.g., shy, happy:sad)",
-                    lines=3
-                )
-                submit = gr.Button(
-                    value="Process Images", variant="primary"
-                )
+
+                submit = gr.Button(
+                    value="Process Images", variant="primary"
+                )
 
            with gr.Column():
                output = gr.Textbox(label="Output", lines=10)
 
-        def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, custom_tags_input):
+        def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
            images = [Image.open(file.name) for file in files]
            results = predictor.predict(images, model_repo, general_thresh, character_thresh)
-
+
            # Parse filter tags
            filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
-
-            # Parse custom tags and their fallback pairs
-            fallback_tags = {}
-            for pair in custom_tags_input.split(","):
-                if ":" in pair:
-                    tag, fallback = pair.split(":")
-                    fallback_tags[tag.strip().lower()] = fallback.strip().lower()
-
+
            # Generate formatted output
            prompts = []
            for i, (general_tags, character_tags) in enumerate(results):
@@ -250,30 +233,24 @@ def main():
                general_part = ", ".join(
                    tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
                )
-
-                # Check if custom tags are missing and apply fallback tags
-                all_tags = set(general_tags + character_tags)
-                for tag, fallback in fallback_tags.items():
-                    if tag not in all_tags:
-                        all_tags.add(fallback)
-
-                # Construct the final prompt
-                final_tags = ", ".join(tag.replace('_', ' ') for tag in all_tags if tag.lower() not in filter_set)
-                prompts.append(final_tags)
 
+                # Construct the prompt based on the presence of character_part
+                if character_part:
+                    prompts.append(f"{character_part}, {general_part}")
+                else:
+                    prompts.append(general_part)
+
            # Join all prompts with blank lines
            return "\n\n".join(prompts)
 
-
        submit.click(
            process_images,
-            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, custom_tags],
+            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
            outputs=output
        )
 
-
    demo.queue(max_size=10)
    demo.launch()
 
 if __name__ == "__main__":
-    main()
+    main()
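
For reference, below is a minimal standalone sketch of the tag-filtering and prompt-joining behaviour that the updated process_images implements, with Gradio and the tagger model stripped out. The helper name build_prompts, the sample tag lists, and the assumption that character_part is built the same way as general_part (its construction sits in context lines not shown in this diff) are illustrative only, not part of app.py.

# Illustrative sketch only: reproduces the prompt assembly of the updated
# process_images without Gradio or the ONNX tagger. build_prompts and the
# sample data are hypothetical names, not from the repository.

def build_prompts(results, filter_tags):
    # results: one (general_tags, character_tags) pair per image
    filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))

    prompts = []
    for general_tags, character_tags in results:
        # Assumed to mirror general_part; the real construction is outside this diff's hunks
        character_part = ", ".join(
            tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
        )
        general_part = ", ".join(
            tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
        )
        # Prepend character tags only when any survived filtering
        if character_part:
            prompts.append(f"{character_part}, {general_part}")
        else:
            prompts.append(general_part)

    # One prompt per image, separated by a blank line
    return "\n\n".join(prompts)


if __name__ == "__main__":
    sample = [(["long_hair", "red_eyes", "from_above"], ["hatsune_miku"])]
    print(build_prompts(sample, "winter, red_eyes"))
    # -> hatsune miku, long hair, from above

Note that, as in app.py, the filter comparison happens before underscores are replaced with spaces, so a filter entry must match the raw tag form (e.g. "red_eyes") to take effect.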