import onnx
import onnxruntime as ort
import numpy as np
import cv2
import os
import csv

VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"

MODEL_DIR='sd/prompt_helper/model'
ORT_SESSION=ort.InferenceSession(f'{MODEL_DIR}/model.onnx', providers=['CPUExecutionProvider'])
#ORT_INPUT_NAME=ort_session.get_inputs()[0].name

IMAGE_SIZE = 448

def img_preporation(img):
    bgr_img = np.array(img)[:, :, ::-1].copy()
    size = max(bgr_img.shape[0:2])
    pad_x = size - bgr_img.shape[1]
    pad_y = size - bgr_img.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2

    # Pad the image to a square with white borders.
    # np.pad returns a new array, so the result must be assigned back.
    bgr_img = np.pad(bgr_img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # Resize to the model input size: INTER_AREA when shrinking, INTER_LANCZOS4 when enlarging
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    return bgr_img.astype(np.float32)
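

# Illustrative sketch (not part of the original module): the pad-to-square step
# on its own, using only NumPy. `pad_to_square` is a hypothetical helper name.
# It centres an HxWxC image on a white square canvas, which is the layout
# img_preporation aims for before resizing.
import numpy as np


def pad_to_square(img, fill=255):
    """Return img centred on a square canvas filled with `fill`."""
    h, w = img.shape[:2]
    size = max(h, w)
    pad_y, pad_x = size - h, size - w
    pad_t, pad_l = pad_y // 2, pad_x // 2
    # np.pad returns a new array; the input is left unchanged.
    return np.pad(
        img,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=fill,
    )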


def get_help(img):
    """Tag an RGB image with the ONNX tagger model and return a comma-separated tag string."""
    ort_input_name = ORT_SESSION.get_inputs()[0].name

    with open(os.path.join(MODEL_DIR, "selected_tags.csv"), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

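    # rows already excludes the header, so filter every data row by category:
    # "0" marks general tags and "4" marks character tags.
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]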

    tag_freq = {}
    undesired_tags = ["transparent background"]

    # Preprocess the image and add a batch dimension: (H, W, C) -> (1, H, W, C)
    preped_img = img_preporation(img)
    preped_img = np.expand_dims(preped_img, axis=0)

    # Run inference
    prob = ORT_SESSION.run(None, {ort_input_name: preped_img})[0][0]
    # Generate Tags
    combined_tags = []
    general_tag_text = ""
    character_tag_text = ""
    remove_underscore = True
    caption_separator = ", "
    general_threshold = 0.35
    character_threshold = 0.35

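    # The first 4 outputs are skipped; the remaining ones are treated as
    # general tags followed by character tags.
    for i, p in enumerate(prob[4:]):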
        if i < len(general_tags) and p >= general_threshold:
            tag_name = general_tags[i]
            if remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
                tag_name = tag_name.replace("_", " ")

            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                general_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)
        elif i >= len(general_tags) and p >= character_threshold:
            tag_name = character_tags[i - len(general_tags)]
            if remove_underscore and len(tag_name) > 3:
                tag_name = tag_name.replace("_", " ")

            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                character_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)

    # Remove leading comma
    if len(general_tag_text) > 0:
        general_tag_text = general_tag_text[len(caption_separator) :]
    if len(character_tag_text) > 0:
        character_tag_text = character_tag_text[len(caption_separator) :]
    tag_text = caption_separator.join(combined_tags)
    return tag_text
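

# Illustrative sketch (not part of the original module): the thresholding step
# written as a small standalone function. `select_tags` and its parameters are
# hypothetical names; it assumes `probs` and `names` are aligned one-to-one.
import numpy as np


def select_tags(probs, names, threshold=0.35, undesired=()):
    """Return the names whose probability is at least `threshold`, in order."""
    probs = np.asarray(probs, dtype=np.float32)
    keep = np.flatnonzero(probs >= threshold)
    selected = []
    for i in keep:
        name = names[i]
        # Leave short emoticon-style tags such as >_< and ^_^ untouched.
        if len(name) > 3:
            name = name.replace("_", " ")
        if name not in undesired:
            selected.append(name)
    return selected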