File size: 5,305 Bytes
c9cc441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a157ab
c9cc441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
import csv
import os
import json

from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path

# from wd14 tagger
IMAGE_SIZE = 448

# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]

def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # RGB->BGR

    # pad to square
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image


def load_wd14_tagger_model():
    model_dir = "wd14_tagger_model"
    repo_id = DEFAULT_WD14_TAGGER_REPO

    if not os.path.exists(model_dir):
        print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
        for file in FILES:
            hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
        for file in SUB_DIR_FILES:
            hf_hub_download(
                repo_id,
                file,
                subfolder=SUB_DIR,
                cache_dir=model_dir + "/" + SUB_DIR,
                force_download=True,
                force_filename=file,
            )
    else:
        print("using existing wd14 tagger model")

    # モデルを読み込む
    model = TFSMLayer(model_dir, call_endpoint='serving_default')
    return model


def generate_tags(images, model_dir, model):
    with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]

    tag_freq = {}
    undesired_tags = ['one-piece_swimsuit',
                      'swimsuit',
                      'leotard',
                      'saitama_(one-punch_man)',
                      '1boy',
    ]

    probs = model(images, training=False)
    probs = probs['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        combined_tags = []
        general_tag_text = ""
        character_tag_text = ""
        thresh = 0.35
        for i, p in enumerate(prob[4:]):
            if i < len(general_tags) and p >= thresh:
                tag_name = general_tags[i]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    general_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)
            elif i >= len(general_tags) and p >= thresh:
                tag_name = character_tags[i - len(general_tags)]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    character_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)

        if len(general_tag_text) > 0:
            general_tag_text = general_tag_text[2:]
        if len(character_tag_text) > 0:
            character_tag_text = character_tag_text[2:]

        tag_text = ", ".join(combined_tags)
        tag_text_list.append(tag_text)
    return tag_text_list
        

def generate_prompt_json(target_folder, prompt_file, model_dir, model):
    image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
    image_count = len(image_files)

    prompt_list = []

    for i, filename in enumerate(image_files, 1):
        source_path = "source/" + filename
        target_path = os.path.join(target_folder, filename)  # Use absolute path
        target_path2 = "target/" + filename

        prompt = generate_tags(target_path, model_dir, model)

        for j in range(4):
            prompt_data = {
                "source": f"{source_path.split('.')[0]}_{j}.jpg",
                "target": f"{target_path2.split('.')[0]}_{j}.jpg",
                "prompt": prompt
            }

            prompt_list.append(prompt_data)

        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)

    with open(prompt_file, "w") as file:
        for prompt_data in prompt_list:
            json.dump(prompt_data, file)
            file.write("\n")

    print(f"Processing completed. Total Images: {image_count}")


if __name__ == '__main__':
    model_dir = "wd14_tagger_model"
    model = load_wd14_tagger_model()
    prompt = generate_tags(target_path, model_dir, model)