import argparse import os import gradio as gr import huggingface_hub import numpy as np import onnxruntime as rt import pandas as pd from PIL import Image TITLE = "WaifuDiffusion Tagger" DESCRIPTION = """ Demo for the WaifuDiffusion tagger models """ HF_TOKEN = os.environ.get("HF_TOKEN", "") # Dataset v3 series of models: SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3" CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3" VIT_MODEL_DSV3_REPO = "ura23/wd-vit-tagger-v3" VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3" EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" # Dataset v2 series of models: MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" # IdolSankaku series of models: EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1" SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1" # Files to download from the repos MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--score-slider-step", type=float, default=0.05) parser.add_argument("--score-general-threshold", type=float, default=0.25) parser.add_argument("--score-character-threshold", type=float, default=1.0) return parser.parse_args() def load_labels(dataframe) -> list[str]: tag_names = dataframe["name"].tolist() general_indexes = list(np.where(dataframe["category"] == 0)[0]) character_indexes = list(np.where(dataframe["category"] == 4)[0]) return tag_names, general_indexes, character_indexes class Predictor: def __init__(self): self.model_target_size = None self.last_loaded_repo = None def download_model(self, model_repo): csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN) model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN) return csv_path, model_path def load_model(self, model_repo): if model_repo == self.last_loaded_repo: return csv_path, model_path = self.download_model(model_repo) tags_df = pd.read_csv(csv_path) self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df) model = rt.InferenceSession(model_path) _, height, width, _ = model.get_inputs()[0].shape self.model_target_size = height self.last_loaded_repo = model_repo self.model = model def prepare_image(self, image): # Create a white canvas with the same size as the input image canvas = Image.new("RGBA", image.size, (255, 255, 255)) # Ensure the input image has an alpha channel for compositing if image.mode != "RGBA": image = image.convert("RGBA") # Composite the input image onto the canvas canvas.alpha_composite(image) # Convert to RGB (alpha channel is no longer needed) image = canvas.convert("RGB") # Resize the image to a square of size (model_target_size x model_target_size) max_dim = max(image.size) padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) pad_left = (max_dim - image.width) // 2 pad_top = (max_dim - image.height) // 2 padded_image.paste(image, (pad_left, pad_top)) padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC) # Convert the image to a NumPy array image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1] return np.expand_dims(image_array, axis=0) def predict(self, images, model_repo, general_thresh, character_thresh): self.load_model(model_repo) results = [] for image in images: image = self.prepare_image(image) input_name = self.model.get_inputs()[0].name label_name = self.model.get_outputs()[0].name preds = self.model.run([label_name], {input_name: image})[0] labels = list(zip(self.tag_names, preds[0].astype(float))) general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh] character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh] results.append((general_res, character_res)) return results def main(): args = parse_args() predictor = Predictor() model_repos = [ SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO, VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO, EVA02_LARGE_MODEL_DSV3_REPO, # --- MOAT_MODEL_DSV2_REPO, SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO, CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO, # --- SWINV2_MODEL_IS_DSV1_REPO, EVA02_LARGE_MODEL_IS_DSV1_REPO, ] predefined_tags = ["loli", "oppai_loli", "2024", "2023", "2025", "head-mounted_display", "2022", "muscular_female", "muscular", "abs", "2021", "peeing", "pee", "round_eyewear", "yellow-framed_eyewear", "hetero", "vaginal", "straddling", "girl_on_top", "male_pubic_hair", "cowgirl_position", "happy_sex", "vibrator_under_panties", "vibrator_in_thighhighs", "anal_beads", "butt_plug", "sex_toy", "anal", "object_insertion", "dildo", "anal_object_insertion", "vaginal_object_insertion", "semi-rimless_eyewear", "red-framed_eyewear", "under-rim_eyewear", "3d_background", "sample_watermark", "onee-shota", "incest", "furry", "can", "drinking_can", "holding_can", "twitter_strip_game_(meme)", "like_and_retweet", "furry_female", "realistic", "egg_vibrator", "tongue_piercing", "handheld_game_console", "game_controller", "nintendo_switch", "talking", "swastika", "character_name", "vibrator", "black-framed_eyewear", "heterochromia", "chibi", "mini_person", "controller", "remote_control_vibrator", "vibrator_under_clothes", "thank_you", "vibrator_cord", "shota", "cropped_legs", "cropped_torso", "traditional_media", "color_guide", "photorealistic", "male_focus", "black_babydoll", "signature", "web_address", "censored_nipples", "rhodes_island_logo_(arknights)", "gothic_lolita", "glasses", "reference_inset", "twitter_logo", "mother_and_daughter", "holding_controller", "holding_game_controller", "baby", "heart_censor", "pixiv_username", "korean_text", "pixiv_logo", "greyscale_with_colored_background", "water_bottle", "body_writing", "used_condom", "multiple_condoms", "condom_belt", "holding_phone", "multiple_views", "phone", "cellphone", "zoom_layer", "smartphone", "lolita_hairband", "lactation", "otoko_no_ko", "minigirl", "babydoll", "domino_mask", "pixiv_id", "qr_code", "monochrome", "trick_or_treat", "happy_birthday", "lolita_fashion", "arrow_(symbol)", "happy_new_year", "dated", "thought_bubble", "greyscale", "speech_bubble", "mask", "comic", "bottle", "holding_bottle", "milk", "milk_bottle", "english_text", "copyright_name", "twitter_username", "fanbox_username", "patreon_username", "patreon_logo", "cover", "weibo_logo", "weibo_username", "signature", "content_rating", "cover_page", "doujin_cover", "sex", "artist_name", "watermark", "censored", "bar_censor", "blank_censor", "blur_censor", "light_censor", "mosaic_censoring"] with gr.Blocks(title=TITLE) as demo: gr.Markdown(f"