# wd-tagger / app.py
# ura23's picture
# Update app.py
# 7211311 verified
import argparse
import os
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
# Title and description rendered at the top of the Gradio UI.
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Demo for the WaifuDiffusion tagger models
"""
# Hugging Face access token read from the environment (may be empty);
# passed along when downloading model files from the Hub.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "ura23/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# IdolSankaku series of models:
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that seed the UI's score sliders."""
    parser = argparse.ArgumentParser()
    # Flag -> default; all three options are floats.
    float_options = {
        "--score-slider-step": 0.05,
        "--score-general-threshold": 0.25,
        "--score-character-threshold": 1.0,
    }
    for flag, default in float_options.items():
        parser.add_argument(flag, type=float, default=default)
    return parser.parse_args()
def load_labels(dataframe) -> tuple[list[str], list[int], list[int]]:
    """Split the tags CSV into names plus general/character index lists.

    Args:
        dataframe: tags table with at least "name" and "category" columns
            (as read from selected_tags.csv).

    Returns:
        (tag_names, general_indexes, character_indexes) — rows with
        category == 0 are collected as general tags, category == 4 as
        character tags.

    Note: the original annotation claimed ``-> list[str]`` but the function
    has always returned this 3-tuple; the annotation is now accurate.
    """
    tag_names = dataframe["name"].tolist()
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, general_indexes, character_indexes
class Predictor:
    """Downloads WaifuDiffusion ONNX tagger models on demand and runs inference.

    One instance caches the most recently loaded model: repeated calls with
    the same repo reuse the session, switching repos triggers a new load.
    """

    def __init__(self):
        # Populated by load_model(); None until a model has been loaded.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX weights for *model_repo*.

        Returns:
            (csv_path, model_path) — local paths in the Hugging Face cache.
        """
        # `token=` replaces the deprecated `use_auth_token=` argument; an
        # empty env value is normalized to None so anonymous downloads stay
        # anonymous instead of sending an empty credential.
        token = HF_TOKEN or None
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, token=token)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, token=token)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load labels + ONNX session for *model_repo*, skipping if already active."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)
        tags_df = pd.read_csv(csv_path)
        self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df)

        model = rt.InferenceSession(model_path)
        # Input layout is NHWC; these taggers take square inputs, so the
        # height alone determines the target size.
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Convert a PIL image into the model's input tensor.

        Flattens transparency onto white, pads to a centered square, resizes
        to the model input size, and returns a (1, H, W, 3) float32 array
        with channels reversed (RGB -> BGR; presumably the channel order the
        taggers were trained with — matches the original implementation).
        """
        # Flatten any alpha channel onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad with white to a centered square, then resize to the model size.
        max_dim = max(image.size)
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        pad_left = (max_dim - image.width) // 2
        pad_top = (max_dim - image.height) // 2
        padded_image.paste(image, (pad_left, pad_top))
        padded_image = padded_image.resize(
            (self.model_target_size, self.model_target_size), Image.BICUBIC
        )

        # float32, RGB -> BGR, plus a leading batch dimension.
        image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(self, images, model_repo, general_thresh, character_thresh):
        """Tag each PIL image in *images* using *model_repo*.

        Returns:
            A list of (general_tags, character_tags) tuples, one per image,
            holding the tag names whose scores exceed the given thresholds.
        """
        self.load_model(model_repo)

        # Set lookups make the per-tag membership test O(1); the original
        # list membership made each image's filtering pass O(n^2) in the
        # (several-thousand-entry) tag count.
        general_idx = set(self.general_indexes)
        character_idx = set(self.character_indexes)

        # I/O names are fixed for a session; hoist them out of the loop.
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name

        results = []
        for image in images:
            batch = self.prepare_image(image)
            preds = self.model.run([label_name], {input_name: batch})[0]
            scores = preds[0].astype(float)

            general_res = [
                name
                for i, (name, score) in enumerate(zip(self.tag_names, scores))
                if i in general_idx and score > general_thresh
            ]
            character_res = [
                name
                for i, (name, score) in enumerate(zip(self.tag_names, scores))
                if i in character_idx and score > character_thresh
            ]
            results.append((general_res, character_res))
        return results
def main():
    """Build and launch the Gradio batch-tagging UI."""
    args = parse_args()
    predictor = Predictor()

    # Model choices offered in the dropdown, grouped by dataset generation.
    model_repos = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # ---
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
        # ---
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]

    # Default blocklist pre-filled into the "Filter Tags" textbox (order kept
    # from the original; a duplicate "signature" entry has been removed).
    predefined_tags = [
        "loli", "oppai_loli", "2024", "2023", "2025", "head-mounted_display",
        "2022", "2021", "peeing", "pee", "round_eyewear",
        "yellow-framed_eyewear", "hetero", "vaginal", "straddling",
        "girl_on_top", "male_pubic_hair", "cowgirl_position", "happy_sex",
        "vibrator_under_panties", "vibrator_in_thighhighs", "anal_beads",
        "butt_plug", "sex_toy", "anal", "object_insertion", "dildo",
        "anal_object_insertion", "vaginal_object_insertion",
        "semi-rimless_eyewear", "red-framed_eyewear", "under-rim_eyewear",
        "3d_background", "sample_watermark", "onee-shota", "incest", "furry",
        "can", "drinking_can", "holding_can", "twitter_strip_game_(meme)",
        "like_and_retweet", "furry_female", "realistic", "egg_vibrator",
        "tongue_piercing", "handheld_game_console", "game_controller",
        "nintendo_switch", "talking", "swastika", "character_name", "vibrator",
        "black-framed_eyewear", "heterochromia", "chibi", "mini_person",
        "controller", "remote_control_vibrator", "vibrator_under_clothes",
        "thank_you", "vibrator_cord", "shota", "cropped_legs", "cropped_torso",
        "traditional_media", "color_guide", "photorealistic", "male_focus",
        "black_babydoll", "signature", "web_address", "censored_nipples",
        "rhodes_island_logo_(arknights)", "gothic_lolita", "glasses",
        "reference_inset", "twitter_logo", "mother_and_daughter",
        "holding_controller", "holding_game_controller", "baby", "heart_censor",
        "pixiv_username", "korean_text", "pixiv_logo",
        "greyscale_with_colored_background", "water_bottle", "body_writing",
        "used_condom", "multiple_condoms", "condom_belt", "holding_phone",
        "multiple_views", "phone", "cellphone", "zoom_layer", "smartphone",
        "lolita_hairband", "lactation", "otoko_no_ko", "minigirl", "babydoll",
        "domino_mask", "pixiv_id", "qr_code", "monochrome", "trick_or_treat",
        "happy_birthday", "lolita_fashion", "arrow_(symbol)", "happy_new_year",
        "dated", "thought_bubble", "greyscale", "speech_bubble", "mask",
        "comic", "bottle", "holding_bottle", "milk", "milk_bottle",
        "english_text", "copyright_name", "twitter_username", "fanbox_username",
        "patreon_username", "patreon_logo", "cover", "weibo_logo",
        "weibo_username", "content_rating", "cover_page", "doujin_cover", "sex",
        "artist_name", "watermark", "censored", "bar_censor", "blank_censor",
        "blur_censor", "light_censor", "mosaic_censoring",
    ]

    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                submit = gr.Button(value="Process Images", variant="primary")
                image_files = gr.File(
                    file_types=["image"],
                    label="Upload Images",
                    file_count="multiple",
                )
                # Model selection and thresholds are collapsed by default.
                with gr.Accordion("Advanced Settings", open=False):
                    model_repo = gr.Dropdown(
                        model_repos,
                        value=VIT_MODEL_DSV3_REPO,
                        label="Select Model",
                    )
                    general_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_general_threshold,
                        label="General Tags Threshold",
                    )
                    character_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_character_threshold,
                        label="Character Tags Threshold",
                    )
                    filter_tags = gr.Textbox(
                        value=", ".join(predefined_tags),
                        label="Filter Tags (comma-separated)",
                        placeholder="Add tags to filter out (e.g., winter, red, from above)",
                        lines=9,
                    )
            with gr.Column():
                output = gr.Textbox(label="Output", lines=10)

        def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
            """Tag the uploaded files and return one comma-joined prompt per image."""
            images = [Image.open(file.name) for file in files]
            results = predictor.predict(images, model_repo, general_thresh, character_thresh)

            # Normalize the blocklist to lowercase; skip empty tokens so a
            # trailing comma (or an empty box) filters nothing.
            filter_set = {
                tag.strip().lower()
                for tag in filter_tags.split(",")
                if tag.strip()
            }

            prompts = []
            for general_tags, character_tags in results:
                # Underscores become spaces for prompt-style output.
                character_part = ", ".join(
                    tag.replace("_", " ")
                    for tag in character_tags
                    if tag.lower() not in filter_set
                )
                general_part = ", ".join(
                    tag.replace("_", " ")
                    for tag in general_tags
                    if tag.lower() not in filter_set
                )
                # Character tags lead the prompt when any survived filtering.
                prompts.append(
                    f"{character_part}, {general_part}" if character_part else general_part
                )

            # Blank line between per-image prompts.
            return "\n\n".join(prompts)

        submit.click(
            process_images,
            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
            outputs=output,
        )

    demo.queue(max_size=10)
    demo.launch()
if __name__ == "__main__":
main()