Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,9 +16,9 @@ Demo for the WaifuDiffusion tagger models
|
|
16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
17 |
|
18 |
# Dataset v3 series of models:
|
19 |
-
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
20 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
21 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
|
|
22 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
23 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
24 |
|
@@ -123,9 +123,9 @@ def main():
|
|
123 |
predictor = Predictor()
|
124 |
|
125 |
model_repos = [
|
126 |
-
VIT_MODEL_DSV3_REPO,
|
127 |
SWINV2_MODEL_DSV3_REPO,
|
128 |
CONV_MODEL_DSV3_REPO,
|
|
|
129 |
VIT_LARGE_MODEL_DSV3_REPO,
|
130 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
131 |
# ---
|
@@ -177,12 +177,7 @@ def main():
|
|
177 |
"blank_censor",
|
178 |
"blur_censor",
|
179 |
"light_censor",
|
180 |
-
"mosaic_censoring"]
|
181 |
-
|
182 |
-
predefined_tags2 = [
|
183 |
-
"big, small:medium", # If either "big" or "small" is missing, add "medium"
|
184 |
-
"small hand, large hand:medium hand" # If either "small hand" or "large hand" is missing, add "medium hand"
|
185 |
-
]
|
186 |
|
187 |
with gr.Blocks(title=TITLE) as demo:
|
188 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
@@ -213,33 +208,21 @@ def main():
|
|
213 |
placeholder="Add tags to filter out (e.g., winter, red, from above)",
|
214 |
lines=5
|
215 |
)
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
lines=3
|
221 |
-
)
|
222 |
-
submit = gr.Button(
|
223 |
-
value="Process Images", variant="primary"
|
224 |
-
)
|
225 |
|
226 |
with gr.Column():
|
227 |
output = gr.Textbox(label="Output", lines=10)
|
228 |
|
229 |
-
def process_images(files, model_repo, general_thresh, character_thresh, filter_tags
|
230 |
images = [Image.open(file.name) for file in files]
|
231 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
232 |
-
|
233 |
# Parse filter tags
|
234 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
235 |
-
|
236 |
-
# Parse custom tags and their fallback pairs
|
237 |
-
fallback_tags = {}
|
238 |
-
for pair in custom_tags_input.split(","):
|
239 |
-
if ":" in pair:
|
240 |
-
tag, fallback = pair.split(":")
|
241 |
-
fallback_tags[tag.strip().lower()] = fallback.strip().lower()
|
242 |
-
|
243 |
# Generate formatted output
|
244 |
prompts = []
|
245 |
for i, (general_tags, character_tags) in enumerate(results):
|
@@ -250,30 +233,24 @@ def main():
|
|
250 |
general_part = ", ".join(
|
251 |
tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
|
252 |
)
|
253 |
-
|
254 |
-
# Check if custom tags are missing and apply fallback tags
|
255 |
-
all_tags = set(general_tags + character_tags)
|
256 |
-
for tag, fallback in fallback_tags.items():
|
257 |
-
if tag not in all_tags:
|
258 |
-
all_tags.add(fallback)
|
259 |
-
|
260 |
-
# Construct the final prompt
|
261 |
-
final_tags = ", ".join(tag.replace('_', ' ') for tag in all_tags if tag.lower() not in filter_set)
|
262 |
-
prompts.append(final_tags)
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
# Join all prompts with blank lines
|
265 |
return "\n\n".join(prompts)
|
266 |
|
267 |
-
|
268 |
submit.click(
|
269 |
process_images,
|
270 |
-
inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags
|
271 |
outputs=output
|
272 |
)
|
273 |
|
274 |
-
|
275 |
demo.queue(max_size=10)
|
276 |
demo.launch()
|
277 |
|
278 |
if __name__ == "__main__":
|
279 |
-
main()
|
|
|
16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
17 |
|
18 |
# Dataset v3 series of models:
|
|
|
19 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
20 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
21 |
+
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
22 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
23 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
24 |
|
|
|
123 |
predictor = Predictor()
|
124 |
|
125 |
model_repos = [
|
|
|
126 |
SWINV2_MODEL_DSV3_REPO,
|
127 |
CONV_MODEL_DSV3_REPO,
|
128 |
+
VIT_MODEL_DSV3_REPO,
|
129 |
VIT_LARGE_MODEL_DSV3_REPO,
|
130 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
131 |
# ---
|
|
|
177 |
"blank_censor",
|
178 |
"blur_censor",
|
179 |
"light_censor",
|
180 |
+
"mosaic_censoring"]
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
with gr.Blocks(title=TITLE) as demo:
|
183 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
|
|
208 |
placeholder="Add tags to filter out (e.g., winter, red, from above)",
|
209 |
lines=5
|
210 |
)
|
211 |
+
|
212 |
+
submit = gr.Button(
|
213 |
+
value="Process Images", variant="primary"
|
214 |
+
)
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
with gr.Column():
|
217 |
output = gr.Textbox(label="Output", lines=10)
|
218 |
|
219 |
+
def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
|
220 |
images = [Image.open(file.name) for file in files]
|
221 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
222 |
+
|
223 |
# Parse filter tags
|
224 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
225 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
# Generate formatted output
|
227 |
prompts = []
|
228 |
for i, (general_tags, character_tags) in enumerate(results):
|
|
|
233 |
general_part = ", ".join(
|
234 |
tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
|
235 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
+
# Construct the prompt based on the presence of character_part
|
238 |
+
if character_part:
|
239 |
+
prompts.append(f"{character_part}, {general_part}")
|
240 |
+
else:
|
241 |
+
prompts.append(general_part)
|
242 |
+
|
243 |
# Join all prompts with blank lines
|
244 |
return "\n\n".join(prompts)
|
245 |
|
|
|
246 |
submit.click(
|
247 |
process_images,
|
248 |
+
inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
|
249 |
outputs=output
|
250 |
)
|
251 |
|
|
|
252 |
demo.queue(max_size=10)
|
253 |
demo.launch()
|
254 |
|
255 |
if __name__ == "__main__":
|
256 |
+
main()
|