import csv
import functools
import os

import onnx
import onnxruntime as ort
import numpy as np
import cv2

# Hugging Face repo id of the tagger model (not referenced elsewhere in this file).
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
# NOTE(review): relative to the current working directory, not to this file.
MODEL_DIR = "sd/prompt_helper/model"
# Created once at import time; inference runs on the CPU provider only.
ORT_SESSION = ort.InferenceSession(
    f"{MODEL_DIR}/model.onnx", providers=["CPUExecutionProvider"]
)
IMAGE_SIZE = 448

TAGS_CSV = os.path.join(MODEL_DIR, "selected_tags.csv")
# Values of the CSV "category" column that get_help turns into tags;
# rows with any other category are ignored.
_GENERAL_CATEGORY = "0"
_CHARACTER_CATEGORY = "4"
# Tags never emitted, compared after underscore replacement.
UNDESIRED_TAGS = frozenset({"transparent background"})


@functools.lru_cache(maxsize=1)
def _load_tags():
    """Read the tag CSV once and return a tuple of ``(name, category)`` pairs.

    Pairs are returned in CSV row order. The model's output vector is assumed
    to follow that same order (one probability per data row).

    Raises:
        ValueError: if the header does not start with ``tag_id,name,category``.
    """
    with open(TAGS_CSV, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)  # tag_id,name,category,count
        rows = [row for row in reader if row]  # skip blank lines
    # Explicit raise instead of assert so the check survives `python -O`.
    if not header or header[:3] != ["tag_id", "name", "category"]:
        raise ValueError(f"unexpected csv format: {header}")
    return tuple((row[1], row[2]) for row in rows)


def _format_tag(name, remove_underscore):
    """Return *name* with underscores turned into spaces when requested."""
    # ignore emoji tags like >_< and ^_^ (length <= 3 keeps its underscores)
    if remove_underscore and len(name) > 3:
        return name.replace("_", " ")
    return name


def img_preporation(img):
    """Convert an image into the tagger's input tensor (without batch axis).

    The steps are: RGB -> BGR, pad to a square with white (255), resize to
    IMAGE_SIZE x IMAGE_SIZE, cast to float32.

    Args:
        img: anything ``np.array`` turns into an (H, W, C) array in RGB channel
            order, e.g. a PIL RGB image (RGB order is assumed, not checked).
            (H, W) grayscale is replicated to three channels. A fourth
            (alpha) channel is dropped.

    Returns:
        numpy float32 array of shape (IMAGE_SIZE, IMAGE_SIZE, 3) in BGR order.
    """
    rgb = np.array(img)
    if rgb.ndim == 2:
        # Grayscale: replicate into three identical channels.
        rgb = np.stack([rgb, rgb, rgb], axis=-1)
    elif rgb.shape[2] == 4:
        # The model takes three channels; discard alpha.
        rgb = rgb[:, :, :3]
    bgr_img = rgb[:, :, ::-1].copy()  # RGB -> BGR, contiguous for OpenCV

    height, width = bgr_img.shape[:2]
    size = max(height, width)
    pad_y = size - height
    pad_x = size - width
    pad_t = pad_y // 2
    pad_l = pad_x // 2
    # Pad to a square (np.pad returns a new array, so reassign) so that the
    # resize below keeps the aspect ratio.
    bgr_img = np.pad(
        bgr_img,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, LANCZOS4 when enlarging.
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return bgr_img.astype(np.float32)


def get_help(
    img,
    *,
    general_threshold=0.35,
    character_threshold=0.35,
    remove_underscore=True,
    caption_separator=", ",
):
    """Run the tagger on *img* and return its tags as one caption string.

    Args:
        img: image accepted by ``img_preporation``.
        general_threshold: minimum model score for a general (category "0") tag.
        character_threshold: minimum model score for a character
            (category "4") tag.
        remove_underscore: replace "_" with " " in tags longer than 3 chars.
        caption_separator: string placed between tags.

    Returns:
        General tags followed by character tags, each group in CSV order,
        joined with *caption_separator*. Returns "" when nothing passes.

    Raises:
        ValueError: if the tag CSV has an unexpected header.
    """
    tags = _load_tags()

    input_name = ORT_SESSION.get_inputs()[0].name
    batch = np.expand_dims(img_preporation(img), axis=0)
    # First output, first (only) batch item: one score per CSV row.
    prob = ORT_SESSION.run(None, {input_name: batch})[0][0]

    thresholds = {
        _GENERAL_CATEGORY: general_threshold,
        _CHARACTER_CATEGORY: character_threshold,
    }
    selected = {_GENERAL_CATEGORY: [], _CHARACTER_CATEGORY: []}
    # Walk tags and scores in lockstep so each score is matched by its own
    # row's category rather than by hard-coded offsets.
    for (name, category), p in zip(tags, prob):
        threshold = thresholds.get(category)
        if threshold is None or p < threshold:
            continue
        tag_name = _format_tag(name, remove_underscore)
        if tag_name not in UNDESIRED_TAGS:
            selected[category].append(tag_name)

    combined_tags = selected[_GENERAL_CATEGORY] + selected[_CHARACTER_CATEGORY]
    return caption_separator.join(combined_tags)