Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -15,24 +15,6 @@ Demo for the WaifuDiffusion tagger models
|
|
15 |
|
16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
17 |
|
18 |
-
# Dataset v3 series of models:
|
19 |
-
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
20 |
-
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
21 |
-
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
22 |
-
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
23 |
-
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
24 |
-
|
25 |
-
# Dataset v2 series of models:
|
26 |
-
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
27 |
-
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
28 |
-
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
29 |
-
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
30 |
-
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
31 |
-
|
32 |
-
# IdolSankaku series of models:
|
33 |
-
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
|
34 |
-
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
|
35 |
-
|
36 |
MODEL_FILENAME = "model.onnx"
|
37 |
LABEL_FILENAME = "selected_tags.csv"
|
38 |
|
@@ -105,10 +87,22 @@ def process_images(files, model_repo, general_thresh, character_thresh, filter_t
|
|
105 |
images = [Image.open(file.name) for file in files]
|
106 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
109 |
replacement_rules = parse_replacement_rules(replacement_rules_text)
|
110 |
fallback_rules = parse_fallback_rules(fallback_rules_text)
|
111 |
|
|
|
112 |
prompts = []
|
113 |
for general_tags, character_tags in results:
|
114 |
general_tags = apply_replacements(general_tags, replacement_rules)
|
@@ -126,7 +120,7 @@ def process_images(files, model_repo, general_thresh, character_thresh, filter_t
|
|
126 |
args = parse_args()
|
127 |
predictor = Predictor()
|
128 |
|
129 |
-
model_repos = [
|
130 |
|
131 |
with gr.Blocks(title=TITLE) as demo:
|
132 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
@@ -137,7 +131,7 @@ with gr.Blocks(title=TITLE) as demo:
|
|
137 |
image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
|
138 |
|
139 |
with gr.Accordion("Advanced Settings", open=False):
|
140 |
-
model_repo = gr.Dropdown(model_repos, value=
|
141 |
general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
|
142 |
character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
|
143 |
|
@@ -147,10 +141,10 @@ with gr.Blocks(title=TITLE) as demo:
|
|
147 |
output = gr.Textbox(label="Output", lines=10)
|
148 |
|
149 |
with gr.Accordion("Tag Replacements", open=False):
|
150 |
-
replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5)
|
151 |
|
152 |
with gr.Accordion("Fallback Rules", open=False):
|
153 |
-
fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5)
|
154 |
|
155 |
submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)
|
156 |
|
|
|
15 |
|
16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
MODEL_FILENAME = "model.onnx"
|
19 |
LABEL_FILENAME = "selected_tags.csv"
|
20 |
|
|
|
87 |
images = [Image.open(file.name) for file in files]
|
88 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
89 |
|
90 |
+
# Predefined examples
|
91 |
+
predefined_filter_tags = "watermark" # This tag will be removed if detected
|
92 |
+
predefined_replacement_rules = "1boy -> 1girl" # "1boy" will be replaced with "1girl"
|
93 |
+
predefined_fallback_rules = "sad, happy -> smile" # If neither "sad" nor "happy" are present, add "smile"
|
94 |
+
|
95 |
+
# Combine predefined rules with user input
|
96 |
+
filter_tags = f"{predefined_filter_tags}, {filter_tags}".strip()
|
97 |
+
replacement_rules_text = f"{predefined_replacement_rules}\n{replacement_rules_text}".strip()
|
98 |
+
fallback_rules_text = f"{predefined_fallback_rules}\n{fallback_rules_text}".strip()
|
99 |
+
|
100 |
+
# Parse user-defined rules
|
101 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
102 |
replacement_rules = parse_replacement_rules(replacement_rules_text)
|
103 |
fallback_rules = parse_fallback_rules(fallback_rules_text)
|
104 |
|
105 |
+
# Generate formatted output
|
106 |
prompts = []
|
107 |
for general_tags, character_tags in results:
|
108 |
general_tags = apply_replacements(general_tags, replacement_rules)
|
|
|
120 |
args = parse_args()
|
121 |
predictor = Predictor()
|
122 |
|
123 |
+
model_repos = ["SmilingWolf/wd-swinv2-tagger-v3", "SmilingWolf/wd-convnext-tagger-v3", "SmilingWolf/wd-vit-tagger-v3"]
|
124 |
|
125 |
with gr.Blocks(title=TITLE) as demo:
|
126 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
|
|
131 |
image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
|
132 |
|
133 |
with gr.Accordion("Advanced Settings", open=False):
|
134 |
+
model_repo = gr.Dropdown(model_repos, value="SmilingWolf/wd-vit-tagger-v3", label="Select Model")
|
135 |
general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
|
136 |
character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
|
137 |
|
|
|
141 |
output = gr.Textbox(label="Output", lines=10)
|
142 |
|
143 |
with gr.Accordion("Tag Replacements", open=False):
|
144 |
+
replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5, value="1boy -> 1girl")
|
145 |
|
146 |
with gr.Accordion("Fallback Rules", open=False):
|
147 |
+
fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5, value="sad, happy -> smile")
|
148 |
|
149 |
submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)
|
150 |
|