ura23 commited on
Commit
b038885
·
verified ·
1 Parent(s): d1d64c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -15,24 +15,6 @@ Demo for the WaifuDiffusion tagger models
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
18
- # Dataset v3 series of models:
19
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
-
25
- # Dataset v2 series of models:
26
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
-
32
- # IdolSankaku series of models:
33
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
-
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
@@ -105,10 +87,22 @@ def process_images(files, model_repo, general_thresh, character_thresh, filter_t
105
  images = [Image.open(file.name) for file in files]
106
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
107
 
 
 
 
 
 
 
 
 
 
 
 
108
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
109
  replacement_rules = parse_replacement_rules(replacement_rules_text)
110
  fallback_rules = parse_fallback_rules(fallback_rules_text)
111
 
 
112
  prompts = []
113
  for general_tags, character_tags in results:
114
  general_tags = apply_replacements(general_tags, replacement_rules)
@@ -126,7 +120,7 @@ def process_images(files, model_repo, general_thresh, character_thresh, filter_t
126
  args = parse_args()
127
  predictor = Predictor()
128
 
129
- model_repos = [SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO, VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO, EVA02_LARGE_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO, SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO, CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO, SWINV2_MODEL_IS_DSV1_REPO, EVA02_LARGE_MODEL_IS_DSV1_REPO]
130
 
131
  with gr.Blocks(title=TITLE) as demo:
132
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
@@ -137,7 +131,7 @@ with gr.Blocks(title=TITLE) as demo:
137
  image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
138
 
139
  with gr.Accordion("Advanced Settings", open=False):
140
- model_repo = gr.Dropdown(model_repos, value=VIT_MODEL_DSV3_REPO, label="Select Model")
141
  general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
142
  character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
143
 
@@ -147,10 +141,10 @@ with gr.Blocks(title=TITLE) as demo:
147
  output = gr.Textbox(label="Output", lines=10)
148
 
149
  with gr.Accordion("Tag Replacements", open=False):
150
- replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5)
151
 
152
  with gr.Accordion("Fallback Rules", open=False):
153
- fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5)
154
 
155
  submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)
156
 
 
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  MODEL_FILENAME = "model.onnx"
19
  LABEL_FILENAME = "selected_tags.csv"
20
 
 
87
  images = [Image.open(file.name) for file in files]
88
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
89
 
90
+ # Predefined examples
91
+ predefined_filter_tags = "watermark" # This tag will be removed if detected
92
+ predefined_replacement_rules = "1boy -> 1girl" # "1boy" will be replaced with "1girl"
93
+ predefined_fallback_rules = "sad, happy -> smile" # If neither "sad" nor "happy" are present, add "smile"
94
+
95
+ # Combine predefined rules with user input
96
+ filter_tags = f"{predefined_filter_tags}, {filter_tags}".strip()
97
+ replacement_rules_text = f"{predefined_replacement_rules}\n{replacement_rules_text}".strip()
98
+ fallback_rules_text = f"{predefined_fallback_rules}\n{fallback_rules_text}".strip()
99
+
100
+ # Parse user-defined rules
101
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
102
  replacement_rules = parse_replacement_rules(replacement_rules_text)
103
  fallback_rules = parse_fallback_rules(fallback_rules_text)
104
 
105
+ # Generate formatted output
106
  prompts = []
107
  for general_tags, character_tags in results:
108
  general_tags = apply_replacements(general_tags, replacement_rules)
 
120
  args = parse_args()
121
  predictor = Predictor()
122
 
123
+ model_repos = ["SmilingWolf/wd-swinv2-tagger-v3", "SmilingWolf/wd-convnext-tagger-v3", "SmilingWolf/wd-vit-tagger-v3"]
124
 
125
  with gr.Blocks(title=TITLE) as demo:
126
  gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
 
131
  image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
132
 
133
  with gr.Accordion("Advanced Settings", open=False):
134
+ model_repo = gr.Dropdown(model_repos, value="SmilingWolf/wd-vit-tagger-v3", label="Select Model")
135
  general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
136
  character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
137
 
 
141
  output = gr.Textbox(label="Output", lines=10)
142
 
143
  with gr.Accordion("Tag Replacements", open=False):
144
+ replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5, value="1boy -> 1girl")
145
 
146
  with gr.Accordion("Fallback Rules", open=False):
147
+ fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5, value="sad, happy -> smile")
148
 
149
  submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)
150