from os import getenv
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import onnxruntime as rt
from PIL import Image

from tagger.common import LabelData, load_labels, preprocess_image
from tagger.model import create_session
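
# NOTE: tagger.common and tagger.model are local helper modules bundled with this
# Space (label metadata, image padding/resizing, and ONNX session setup).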

HF_TOKEN = getenv("HF_TOKEN", None)
WORK_DIR = Path.cwd().resolve()

MODEL_VARIANTS: dict[str, str] = {
    "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
    "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
    "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
    "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
    "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
}

# allowed image extensions
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
# model input is a square image of this size
IMAGE_SIZE = 448

# gather example images for the Gradio examples gallery
example_images = sorted(
    [
        str(x.relative_to(WORK_DIR))
        for x in WORK_DIR.joinpath("examples").iterdir()
        if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
    ]
)

# per-variant session cache, populated lazily on first use
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in MODEL_VARIANTS}


def load_model(variant: str) -> rt.InferenceSession:
    global loaded_models

    # resolve the repo name
    model_repo = MODEL_VARIANTS.get(variant, None)
    if model_repo is None:
        raise ValueError(f"Unknown model variant: {variant}")

    if loaded_models.get(variant, None) is None:
        # load the model and save it to the cache
        loaded_models[variant] = create_session(model_repo, token=HF_TOKEN)
    return loaded_models[variant]
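
# For reference: create_session (imported from tagger.model above) is assumed to
# fetch the ONNX graph from the Hub and open an inference session. A minimal
# sketch of the idea, not the actual implementation:
#
#   from huggingface_hub import hf_hub_download
#
#   def create_session(repo_id: str, token: Optional[str] = None) -> rt.InferenceSession:
#       model_path = hf_hub_download(repo_id, filename="model.onnx", token=token)
#       return rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])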


def predict(
    image: Image.Image,
    variant: str,
    general_threshold: float = 0.35,
    character_threshold: float = 0.85,
):
    # load model and labels
    model: rt.InferenceSession = load_model(variant)
    labels: LabelData = load_labels()

    # get input size and tensor names
    _, h, w, _ = model.get_inputs()[0].shape
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name

    # preprocess image to the model's input size
    image = preprocess_image(image, (h, w))

    # turn into a BGR float32 array of shape (N, H, W, C), since that's what these
    # models expect. Reversing the channel axis with numpy avoids Pillow's
    # non-standard "BGR;24" mode, which is deprecated in recent Pillow versions.
    inputs = np.array(image.convert("RGB")).astype(np.float32)
    inputs = inputs[:, :, ::-1]  # RGB -> BGR
    inputs = np.expand_dims(inputs, axis=0)

    # run the ONNX model
    probs = model.run([output_name], {input_name: inputs})
    # pair label names with predicted confidences
    probs = list(zip(labels.names, probs[0][0].astype(float)))

    # the first four labels are actually ratings
    rating_labels = dict([probs[i] for i in labels.rating])

    # general labels: pick any with confidence above the threshold
    gen_labels = [probs[i] for i in labels.general]
    gen_labels = dict([x for x in gen_labels if x[1] > general_threshold])
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # character labels: pick any with confidence above the threshold
    char_labels = [probs[i] for i in labels.character]
    char_labels = dict([x for x in char_labels if x[1] > character_threshold])
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # combine general and character labels (each already sorted by confidence)
    combined_names = [x for x in gen_labels]
    combined_names.extend([x for x in char_labels])

    # convert to a string suitable for use as a training caption
    caption = ", ".join(combined_names)
    # escape parens for booru-style prompts (the double backslash avoids Python's
    # invalid-escape-sequence warning)
    booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return image, caption, booru, rating_labels, char_labels, gen_labels
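
# Example of calling predict() directly, e.g. from a REPL (the path is
# hypothetical; any image under examples/ will do):
#
#   img = Image.open("examples/sample.jpg")
#   _, caption, booru, ratings, chars, gens = predict(img, "MOAT", 0.35, 0.85)
#   print(caption)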


with gr.Blocks(title="pi-chan's tagger") as demo:
    with gr.Row(equal_height=False):
        with gr.Column():
            img_input = gr.Image(
                label="Input",
                type="pil",
                image_mode="RGB",
                sources=["upload", "clipboard"],
            )
            variant = gr.Radio(choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="MOAT")
            gen_thresh = gr.Slider(0.0, 1.0, value=0.35, label="General Tag Threshold")
            char_thresh = gr.Slider(0.0, 1.0, value=0.85, label="Character Tag Threshold")
            show_processed = gr.Checkbox(label="Show Preprocessed", value=False)
            with gr.Row():
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
            with gr.Row():
                examples = gr.Examples(
                    examples=[
                        [imgpath, var, 0.35, 0.85]
                        for imgpath in example_images
                        for var in ["MOAT", "ConvNeXTv2"]
                    ],
                    inputs=[img_input, variant, gen_thresh, char_thresh],
                )
        with gr.Column():
            img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False)
            with gr.Group():
                tags_string = gr.Textbox(
                    label="Caption", placeholder="Caption will appear here", show_copy_button=True
                )
                tags_booru = gr.Textbox(
                    label="Tags", placeholder="Tag string will appear here", show_copy_button=True
                )
            rating = gr.Label(label="Rating")
            character = gr.Label(label="Character")
            general = gr.Label(label="General")

    # tell the clear button which components to clear (including tags_booru,
    # which predict() also populates)
    clear.add([img_input, img_output, tags_string, tags_booru, rating, character, general])

    # show/hide the preprocessed image
    def on_select_show_processed(evt: gr.SelectData):
        return gr.update(visible=evt.selected)

    show_processed.select(on_select_show_processed, inputs=[], outputs=[img_output])

    submit.click(
        predict,
        inputs=[img_input, variant, gen_thresh, char_thresh],
        outputs=[img_output, tags_string, tags_booru, rating, character, general],
        api_name="predict",
    )


if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(server_name="0.0.0.0", server_port=7871)
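
# Once the app is running, the endpoint registered via api_name="predict" above
# can also be queried programmatically. A sketch using gradio_client -- the URL
# mirrors the launch settings above, and newer gradio_client versions may require
# wrapping the image path in handle_file():
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7871")
#   result = client.predict("examples/sample.jpg", "MOAT", 0.35, 0.85, api_name="/predict")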