# Hugging Face Spaces page header (status: Running).
import argparse | |
import os | |
import gradio as gr | |
import huggingface_hub | |
import numpy as np | |
import onnxruntime as rt | |
import pandas as pd | |
from PIL import Image | |
# Text shown at the top of the demo page.
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = "\nDemo for the WaifuDiffusion tagger models\n"

# Hugging Face access token taken from the environment ("" when unset).
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Model repositories trained on dataset v3.
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "ura23/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Model repositories trained on dataset v2.
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# IdolSankaku model repositories.
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"

# Names of the two files fetched from every model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the score sliders.

    Returns:
        A namespace with the float attributes ``score_slider_step``,
        ``score_general_threshold`` and ``score_character_threshold``.
    """
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a float.
    float_options = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.25),
        ("--score-character-threshold", 1.0),
    )
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)
    return cli.parse_args()
def load_labels(dataframe: pd.DataFrame) -> tuple[list[str], list[int], list[int]]:
    """Split a tag table into tag names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column and a ``category`` column.

    Returns:
        A ``(tag_names, general_indexes, character_indexes)`` tuple where
        ``tag_names`` holds every value of the ``name`` column in row order,
        ``general_indexes`` the positions of rows whose category equals 0 and
        ``character_indexes`` the positions of rows whose category equals 4.
        Positions are plain Python ints in ascending order.
    """
    tag_names = dataframe["name"].tolist()
    categories = dataframe["category"].to_numpy()
    # flatnonzero yields positional (0-based) row numbers regardless of the
    # dataframe's index labels; tolist() converts them to built-in ints.
    general_indexes = np.flatnonzero(categories == 0).tolist()
    character_indexes = np.flatnonzero(categories == 4).tolist()
    return tag_names, general_indexes, character_indexes
class Predictor:
    """Loads a tagger ONNX model on demand and runs it over PIL images.

    The most recently loaded repository is remembered so that consecutive
    predictions with the same model skip the download and session creation.
    """

    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None
        self.model = None
        self.tag_names = []
        self.general_indexes = []
        self.character_indexes = []

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX file of ``model_repo`` from the Hub.

        Returns:
            ``(csv_path, model_path)`` local file paths.
        """
        # An unset HF_TOKEN is stored as ""; pass None in that case so no
        # empty credential is sent with the request.
        token = HF_TOKEN or None
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, token=token)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, token=token)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model; no-op if it already is."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)
        tag_names, general_indexes, character_indexes = load_labels(pd.read_csv(csv_path))
        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as four dimensions with height second;
        # only the height is kept and later used for both sides of the resize.
        # NOTE(review): assumes a square NHWC input - confirm for new models.
        _, height, _width, _ = model.get_inputs()[0].shape

        # Commit all state together, only after every step above succeeded,
        # so a failed load never leaves labels and model out of sync.
        self.tag_names = tag_names
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into the float32 batch array fed to the model.

        Transparent areas are flattened onto white, the image is padded to a
        centred square with white borders, resized to
        ``model_target_size`` x ``model_target_size`` and its channel order is
        reversed (RGB -> BGR).

        Returns:
            Array with a leading batch axis of length 1.
        """
        # Flatten any transparency onto a white canvas of the same size.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad to a square so the resize below does not distort the content.
        max_dim = max(image.size)
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        pad_left = (max_dim - image.width) // 2
        pad_top = (max_dim - image.height) // 2
        padded_image.paste(image, (pad_left, pad_top))
        padded_image = padded_image.resize(
            (self.model_target_size, self.model_target_size), Image.BICUBIC
        )

        # Reverse the last axis: RGB -> BGR.
        image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(self, images, model_repo, general_thresh, character_thresh):
        """Tag every image with the model from ``model_repo``.

        Args:
            images: Iterable of PIL images.
            model_repo: Hub repository id of the model to use.
            general_thresh: A general tag is kept when its score is strictly
                greater than this value.
            character_thresh: Same, for character tags.

        Returns:
            One ``(general_tags, character_tags)`` tuple of tag-name lists per
            input image, in input order.
        """
        self.load_model(model_repo)

        # Loop invariants: model I/O names and set views of the index lists
        # (set membership is O(1) instead of scanning a list for every tag).
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        general_positions = set(self.general_indexes)
        character_positions = set(self.character_indexes)

        results = []
        for image in images:
            batch = self.prepare_image(image)
            preds = self.model.run([label_name], {input_name: batch})[0]
            scores = preds[0].astype(float)

            general_res = []
            character_res = []
            for position, (tag, score) in enumerate(zip(self.tag_names, scores)):
                if position in general_positions and score > general_thresh:
                    general_res.append(tag)
                if position in character_positions and score > character_thresh:
                    character_res.append(tag)
            results.append((general_res, character_res))
        return results
def _open_image(file):
    """Load one uploaded file fully into memory and release its file handle.

    ``file`` may be a path (str / os.PathLike) or an object exposing the path
    as ``.name``.
    """
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    with Image.open(path) as img:
        # copy() forces the pixel data to be read, so the returned image stays
        # usable after the underlying file is closed.
        return img.copy()


def _parse_filter_tags(filter_tags):
    """Turn a comma-separated string into a set of normalized tag keys.

    Entries are stripped, lower-cased and have spaces replaced by underscores;
    blank entries are dropped.
    """
    return {
        entry.strip().lower().replace(" ", "_")
        for entry in (filter_tags or "").split(",")
        if entry.strip()
    }


def _build_prompt(general_tags, character_tags, filter_set):
    """Format one image's tags as ``"<character tags>, <general tags>"``.

    Tags whose normalized form is in ``filter_set`` are dropped, underscores
    are shown as spaces, and an empty part is omitted without leaving a
    dangling separator.
    """
    def visible(tags):
        return ", ".join(
            tag.replace("_", " ")
            for tag in tags
            if tag.lower().replace(" ", "_") not in filter_set
        )

    parts = (visible(character_tags), visible(general_tags))
    return ", ".join(part for part in parts if part)


def main():
    """Build the Gradio interface for the tagger and launch it."""
    args = parse_args()
    predictor = Predictor()

    model_repos = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # ---
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
        # ---
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]

    # Default content of the "Filter Tags" box: tags removed from the output.
    predefined_tags = [
        "loli",
        "oppai_loli",
        "2024",
        "2023",
        "2025",
        "head-mounted_display",
        "2022",
        "2021",
        "peeing",
        "pee",
        "round_eyewear",
        "yellow-framed_eyewear",
        "hetero",
        "vaginal",
        "straddling",
        "girl_on_top",
        "male_pubic_hair",
        "cowgirl_position",
        "happy_sex",
        "vibrator_under_panties",
        "vibrator_in_thighhighs",
        "anal_beads",
        "butt_plug",
        "sex_toy",
        "anal",
        "object_insertion",
        "dildo",
        "anal_object_insertion",
        "vaginal_object_insertion",
        "semi-rimless_eyewear",
        "red-framed_eyewear",
        "under-rim_eyewear",
        "3d_background",
        "sample_watermark",
        "onee-shota",
        "incest",
        "furry",
        "can",
        "drinking_can",
        "holding_can",
        "twitter_strip_game_(meme)",
        "like_and_retweet",
        "furry_female",
        "realistic",
        "egg_vibrator",
        "tongue_piercing",
        "handheld_game_console",
        "game_controller",
        "nintendo_switch",
        "talking",
        "swastika",
        "character_name",
        "vibrator",
        "black-framed_eyewear",
        "heterochromia",
        "chibi",
        "mini_person",
        "controller",
        "remote_control_vibrator",
        "vibrator_under_clothes",
        "thank_you",
        "vibrator_cord",
        "shota",
        "cropped_legs",
        "cropped_torso",
        "traditional_media",
        "color_guide",
        "photorealistic",
        "male_focus",
        "black_babydoll",
        "signature",
        "web_address",
        "censored_nipples",
        "rhodes_island_logo_(arknights)",
        "gothic_lolita",
        "glasses",
        "reference_inset",
        "twitter_logo",
        "mother_and_daughter",
        "holding_controller",
        "holding_game_controller",
        "baby",
        "heart_censor",
        "pixiv_username",
        "korean_text",
        "pixiv_logo",
        "greyscale_with_colored_background",
        "water_bottle",
        "body_writing",
        "used_condom",
        "multiple_condoms",
        "condom_belt",
        "holding_phone",
        "multiple_views",
        "phone",
        "cellphone",
        "zoom_layer",
        "smartphone",
        "lolita_hairband",
        "lactation",
        "otoko_no_ko",
        "minigirl",
        "babydoll",
        "domino_mask",
        "pixiv_id",
        "qr_code",
        "monochrome",
        "trick_or_treat",
        "happy_birthday",
        "lolita_fashion",
        "arrow_(symbol)",
        "happy_new_year",
        "dated",
        "thought_bubble",
        "greyscale",
        "speech_bubble",
        "mask",
        "comic",
        "bottle",
        "holding_bottle",
        "milk",
        "milk_bottle",
        "english_text",
        "copyright_name",
        "twitter_username",
        "fanbox_username",
        "patreon_username",
        "patreon_logo",
        "cover",
        "weibo_logo",
        "weibo_username",
        "signature",
        "content_rating",
        "cover_page",
        "doujin_cover",
        "sex",
        "artist_name",
        "watermark",
        "censored",
        "bar_censor",
        "blank_censor",
        "blur_censor",
        "light_censor",
        "mosaic_censoring",
    ]

    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                submit = gr.Button(
                    value="Process Images", variant="primary"
                )
                image_files = gr.File(
                    file_types=["image"], label="Upload Images", file_count="multiple",
                )
                # Wrap the model selection and sliders in an Accordion
                # NOTE(review): the nesting of the widgets below is
                # reconstructed; confirm that the filter box belongs inside
                # the accordion.
                with gr.Accordion("Advanced Settings", open=False):  # Collapsible by default
                    model_repo = gr.Dropdown(
                        model_repos,
                        value=VIT_MODEL_DSV3_REPO,
                        label="Select Model",
                    )
                    general_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
                    )
                    character_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
                    )
                    filter_tags = gr.Textbox(
                        value=", ".join(predefined_tags),
                        label="Filter Tags (comma-separated)",
                        placeholder="Add tags to filter out (e.g., winter, red, from above)",
                        lines=9
                    )
            with gr.Column():
                output = gr.Textbox(label="Output", lines=10)

        def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
            """Tag the uploaded files and return one prompt per image.

            Prompts are separated by blank lines; an empty string is returned
            when nothing was uploaded.
            """
            if not files:
                return ""

            images = [_open_image(file) for file in files]
            results = predictor.predict(images, model_repo, general_thresh, character_thresh)
            filter_set = _parse_filter_tags(filter_tags)

            return "\n\n".join(
                _build_prompt(general_tags, character_tags, filter_set)
                for general_tags, character_tags in results
            )

        submit.click(
            process_images,
            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
            outputs=output
        )

    demo.queue(max_size=10)
    demo.launch()


# Start the demo only when the file is executed as a script.
if __name__ == "__main__":
    main()