"""WaifuDiffusion tagger served as a FastAPI JSON endpoint plus a Gradio UI.

Both front-ends share a single module-level ``Predictor`` instance.
"""
import argparse
import io
import os
import threading
from typing import Optional

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

app = FastAPI()

TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = "Demo for the WaifuDiffusion tagger models"

# Dataset v3 models
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 models
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tag names that keep their underscores; every other tag has "_" replaced
# by " " when the label CSV is loaded.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def _filter_and_sort(labels, threshold):
    """Return ``(name, prob)`` pairs with ``prob > threshold``, highest first.

    Duplicate names collapse to a single entry holding the last probability
    seen, exactly as building a ``dict`` from the pairs would.
    """
    kept = {name: prob for name, prob in labels if prob > threshold}
    return sorted(kept.items(), key=lambda item: item[1], reverse=True)


def _mcut_threshold(probs, default):
    """Return the Maximum Cut (MCut) threshold for a sequence of scores.

    The scores are sorted in descending order and the threshold is placed
    halfway across the largest gap between two consecutive scores
    (Largeron, Moulin & Gery, "MCut: A Thresholding Strategy for
    Multi-label Classification", IDA 2012).

    Args:
        probs: 1-D sequence of probabilities.
        default: value returned when there are fewer than two scores, since
            no gap exists to cut at.

    Returns:
        The threshold as a float (or ``default``).
    """
    scores = np.asarray(probs, dtype=np.float64)
    if scores.size < 2:
        return default
    ordered = np.sort(scores)[::-1]
    gaps = ordered[:-1] - ordered[1:]
    cut = int(gaps.argmax())
    return float((ordered[cut] + ordered[cut + 1]) / 2)


class Predictor:
    """Download, cache and run one WaifuDiffusion ONNX tagger at a time.

    Requesting a different repo replaces the loaded model. ``lock`` serialises
    model swaps and inference, because this instance is shared by the FastAPI
    endpoint and the Gradio UI, which can call it from different threads.
    It is re-entrant so a caller may hold it across ``predict`` and
    subsequent reads of the tag index lists.
    """

    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None
        self.model = None
        self.tag_names = []
        self.rating_indexes = []
        self.general_indexes = []
        self.character_indexes = []
        self.lock = threading.RLock()

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model from the Hub; return their local paths."""
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load ``model_repo`` unless it is already the loaded model.

        All new state is built in locals and committed only after every step
        succeeded. A failed load therefore leaves the previous model, and
        its matching tag lists, intact.
        """
        with self.lock:
            if model_repo == self.last_loaded_repo:
                return

            csv_path, model_path = self.download_model(model_repo)
            tags_df = pd.read_csv(csv_path)
            tag_names = (
                tags_df["name"]
                .map(lambda x: x if x in kaomojis else x.replace("_", " "))
                .tolist()
            )
            categories = tags_df["category"]
            # Category codes in selected_tags.csv as used by this app:
            # 9 = rating, 0 = general, 4 = character.
            rating_indexes = list(np.where(categories == 9)[0])
            general_indexes = list(np.where(categories == 0)[0])
            character_indexes = list(np.where(categories == 4)[0])

            model = rt.InferenceSession(model_path)
            # Input is (batch, height, width, channels); the square target
            # side is taken from the height.
            _, height, _, _ = model.get_inputs()[0].shape

            self.tag_names = tag_names
            self.rating_indexes = rating_indexes
            self.general_indexes = general_indexes
            self.character_indexes = character_indexes
            self.model = model
            self.model_target_size = height
            self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into a ``(1, size, size, 3)`` float32 batch.

        Transparency is flattened onto white. The image is then padded to a
        centred white square, resized with bicubic filtering to the model's
        input size, and its channels reversed (RGB -> BGR).
        """
        if image.mode != "RGBA":
            # alpha_composite requires both operands to be RGBA.
            image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        width, height = image.size
        max_dim = max(width, height)
        padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded.paste(image, ((max_dim - width) // 2, (max_dim - height) // 2))

        target = self.model_target_size
        if max_dim != target:
            padded = padded.resize((target, target), Image.BICUBIC)

        image_array = np.asarray(padded, dtype=np.float32)
        # Reverse channel order; make contiguous so onnxruntime gets a plain buffer.
        image_array = np.ascontiguousarray(image_array[:, :, ::-1])
        return np.expand_dims(image_array, axis=0)

    def predict(self, image, model_repo=SWINV2_MODEL_DSV3_REPO, threshold=0.05):
        """Tag ``image`` with ``model_repo``.

        Returns:
            A tuple ``(sorted_general, labels)``. ``sorted_general`` lists the
            general-category ``(name, prob)`` pairs with ``prob > threshold``,
            highest first. ``labels`` is every ``(name, prob)`` pair in
            label-CSV order.
        """
        with self.lock:
            self.load_model(model_repo)
            batch = self.prepare_image(image)
            input_name = self.model.get_inputs()[0].name
            label_name = self.model.get_outputs()[0].name
            preds = self.model.run([label_name], {input_name: batch})[0]
            labels = list(zip(self.tag_names, preds[0].astype(float)))
            general_labels = [labels[i] for i in self.general_indexes]
        return _filter_and_sort(general_labels, threshold), labels


predictor = Predictor()


@app.post("/tagging")
async def tagging_endpoint(
    image: UploadFile = File(...), threshold: Optional[float] = Form(0.05)
):
    """Return ``{"tags": [...]}``: general tags above ``threshold``, highest first.

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    """
    image_data = await image.read()
    try:
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    if threshold is None:
        threshold = 0.05
    # ONNX inference is CPU-bound; keep it off the event loop.
    sorted_general, _ = await run_in_threadpool(
        predictor.predict, pil_image, threshold=threshold
    )
    return JSONResponse(content={"tags": [name for name, _ in sorted_general]})


def ui_predict(
    image,
    model_repo,
    general_thresh,
    general_mcut_enabled,
    character_thresh,
    character_mcut_enabled,
):
    """Gradio callback returning the four UI outputs.

    Returns:
        ``(tag_string, ratings, characters, general)``.
        ``tag_string`` is the general tags comma-joined, with parentheses
        backslash-escaped. ``ratings`` maps every rating tag to its
        probability. ``characters`` maps character tags above the threshold
        to their probabilities. ``general`` maps the selected general tags to
        their probabilities.

    When an ``*_mcut_enabled`` flag is set, the corresponding slider value is
    replaced by the MCut threshold; character MCut is floored at 0.15.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Snapshot the index lists under the lock so they belong to the same
    # model that produced ``all_labels``.
    with predictor.lock:
        sorted_general, all_labels = predictor.predict(
            image, model_repo, general_thresh
        )
        rating_indexes = predictor.rating_indexes
        general_indexes = predictor.general_indexes
        character_indexes = predictor.character_indexes

    ratings = {all_labels[i][0]: all_labels[i][1] for i in rating_indexes}

    if general_mcut_enabled:
        general_labels = [all_labels[i] for i in general_indexes]
        mcut = _mcut_threshold([p for _, p in general_labels], general_thresh)
        sorted_general = _filter_and_sort(general_labels, mcut)

    character_labels = [all_labels[i] for i in character_indexes]
    if character_mcut_enabled:
        mcut = _mcut_threshold([p for _, p in character_labels], character_thresh)
        character_thresh = max(0.15, mcut)
    character_res = {
        name: prob for name, prob in character_labels if prob > character_thresh
    }

    tag_string = ", ".join(name for name, _ in sorted_general)
    tag_string = tag_string.replace("(", "\\(").replace(")", "\\)")
    return tag_string, ratings, character_res, dict(sorted_general)


def create_demo():
    """Build the Gradio Blocks UI wired to ``ui_predict``."""
    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(
            f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
        )
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(variant="panel"):
                image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                model_repo = gr.Dropdown(
                    choices=[
                        SWINV2_MODEL_DSV3_REPO,
                        CONV_MODEL_DSV3_REPO,
                        VIT_MODEL_DSV3_REPO,
                        VIT_LARGE_MODEL_DSV3_REPO,
                        EVA02_LARGE_MODEL_DSV3_REPO,
                        MOAT_MODEL_DSV2_REPO,
                        SWIN_MODEL_DSV2_REPO,
                        CONV_MODEL_DSV2_REPO,
                        CONV2_MODEL_DSV2_REPO,
                        VIT_MODEL_DSV2_REPO,
                    ],
                    value=SWINV2_MODEL_DSV3_REPO,
                    label="Model",
                )
                with gr.Row():
                    general_thresh = gr.Slider(
                        0, 1, value=0.35, step=0.05, label="General Tags Threshold"
                    )
                    general_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
                with gr.Row():
                    character_thresh = gr.Slider(
                        0, 1, value=0.85, step=0.05, label="Character Tags Threshold"
                    )
                    character_mcut = gr.Checkbox(
                        value=False, label="Use MCut threshold"
                    )
                submit = gr.Button(value="Submit", variant="primary")
            with gr.Column(variant="panel"):
                text_output = gr.Textbox(label="Output (string)")
                rating_output = gr.Label(label="Rating")
                character_output = gr.Label(label="Characters")
                general_output = gr.Label(label="Tags")

        submit.click(
            ui_predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut,
                character_thresh,
                character_mcut,
            ],
            outputs=[text_output, rating_output, character_output, general_output],
        )

    demo.queue(max_size=10)
    return demo


app = gr.mount_gradio_app(app, create_demo(), path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)