import argparse
import csv
import json
import os

import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.layers import TFSMLayer

# from wd14 tagger
IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]

def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # RGB -> BGR, as expected by the wd14 models
    # pad to a square with white borders so the aspect ratio is preserved
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # INTER_AREA when shrinking, INTER_LANCZOS4 when enlarging
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image
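
# Usage sketch (hypothetical path): preprocess_image takes an RGB PIL image and
# returns a float32 448x448 BGR array, so a model-ready batch looks like:
#   img = Image.open("example.png").convert("RGB")
#   batch = np.expand_dims(preprocess_image(img), axis=0)  # shape (1, 448, 448, 3)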

def load_wd14_tagger_model():
    model_dir = "wd14_tagger_model"
    repo_id = DEFAULT_WD14_TAGGER_REPO

    if not os.path.exists(model_dir):
        print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
        for file in FILES:
            hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
        for file in SUB_DIR_FILES:
            hf_hub_download(
                repo_id,
                file,
                subfolder=SUB_DIR,
                cache_dir=model_dir + "/" + SUB_DIR,
                force_download=True,
                force_filename=file,
            )
    else:
        print("using existing wd14 tagger model")

    # Load the SavedModel as an inference-only Keras layer
    model = TFSMLayer(model_dir, call_endpoint="serving_default")
    return model
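
# Usage sketch: the layer maps a float32 (N, 448, 448, 3) batch to a dict of
# serving outputs, with per-tag probabilities under "predictions_sigmoid":
#   model = load_wd14_tagger_model()
#   probs = model(batch)["predictions_sigmoid"].numpy()  # shape (N, num_tags)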

def generate_tags(images, model_dir, model):
    with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        lines = [row for row in reader]
    header = lines[0]  # tag_id,name,category,count
    rows = lines[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

    # category 0 = general tags, category 4 = character tags; the leading
    # rating rows (category 9) are excluded by these filters
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]

    tag_freq = {}
    undesired_tags = [
        "one-piece_swimsuit",
        "swimsuit",
        "leotard",
        "saitama_(one-punch_man)",
        "1boy",
    ]

    probs = model(images, training=False)
    probs = probs["predictions_sigmoid"].numpy()

    tag_text_list = []
    thresh = 0.35
    for prob in probs:
        combined_tags = []
        general_tag_text = ""
        character_tag_text = ""
        # skip the first four outputs (ratings); the rest line up with
        # general_tags followed by character_tags
        for i, p in enumerate(prob[4:]):
            if i < len(general_tags) and p >= thresh:
                tag_name = general_tags[i]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    general_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)
            elif i >= len(general_tags) and p >= thresh:
                tag_name = character_tags[i - len(general_tags)]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    character_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)
        if len(general_tag_text) > 0:
            general_tag_text = general_tag_text[2:]
        if len(character_tag_text) > 0:
            character_tag_text = character_tag_text[2:]
        tag_text = ", ".join(combined_tags)
        tag_text_list.append(tag_text)
    return tag_text_list
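
# Each returned entry is a comma-separated tag string thresholded at 0.35,
# e.g. (illustrative output only): "1girl, solo, smile, long_hair"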

def generate_prompt_json(target_folder, prompt_file, model_dir, model):
    image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
    image_count = len(image_files)

    prompt_list = []
    for i, filename in enumerate(image_files, 1):
        source_path = "source/" + filename
        target_path = os.path.join(target_folder, filename)
        target_path2 = "target/" + filename

        # load and preprocess the image, then tag it as a batch of one
        image = Image.open(target_path).convert("RGB")
        batch = np.expand_dims(preprocess_image(image), axis=0)
        prompt = generate_tags(batch, model_dir, model)[0]

        for j in range(4):
            prompt_data = {
                "source": f"{os.path.splitext(source_path)[0]}_{j}.jpg",
                "target": f"{os.path.splitext(target_path2)[0]}_{j}.jpg",
                "prompt": prompt,
            }
            prompt_list.append(prompt_data)

        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)

    # write one JSON object per line (JSON Lines format)
    with open(prompt_file, "w") as file:
        for prompt_data in prompt_list:
            json.dump(prompt_data, file)
            file.write("\n")

    print(f"Processing completed. Total Images: {image_count}")

if __name__ == "__main__":
    # minimal CLI (argument names are our own choice)
    parser = argparse.ArgumentParser(description="Tag a folder of images with the WD14 tagger")
    parser.add_argument("target_folder", help="folder containing the images to tag")
    parser.add_argument("--prompt_file", default="prompt.json", help="output JSON Lines file")
    args = parser.parse_args()
    model = load_wd14_tagger_model()
    generate_prompt_json(args.target_folder, args.prompt_file, "wd14_tagger_model", model)