# Spaces: Running on Zero
import argparse | |
import csv | |
import os | |
import json | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from tensorflow.keras.layers import TFSMLayer | |
from huggingface_hub import hf_hub_download | |
from pathlib import Path | |
# Settings taken from the wd14 tagger.

# Side length, in pixels, of the square image fed to the tagger.
IMAGE_SIZE = 448

# Hub repository to download the tagger from. Other compatible repos:
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2 /
# wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"

# Tag vocabulary shipped next to the model (header: tag_id,name,category,count).
CSV_FILE = "selected_tags.csv"

# Files at the repository root; the tag CSV is deliberately the last entry.
FILES = [
    "keras_metadata.pb",
    "saved_model.pb",
    CSV_FILE,
]

# Weight files that live in the "variables" subfolder of the SavedModel.
SUB_DIR = "variables"
SUB_DIR_FILES = [
    "variables.data-00000-of-00001",
    "variables.index",
]
def preprocess_image(image):
    """Convert an image into the input format used by the WD14 tagger.

    Steps:
    1. Convert to a NumPy array and reorder channels from RGB to BGR.
    2. Pad to a square with white (255) borders, centring the content.
    3. Resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
    4. Cast to ``float32``. Pixel values are not rescaled.

    Args:
        image: a PIL image of any mode (non-RGB modes such as "L", "P" or
            "RGBA" are converted to RGB first), or an array-like of shape
            (H, W, 3) already in RGB channel order.

    Returns:
        np.ndarray of shape (IMAGE_SIZE, IMAGE_SIZE, 3), dtype float32, BGR.
    """
    # The channel flip below needs exactly three RGB channels. A grayscale or
    # palette image would otherwise fail to index, and RGBA would produce four
    # channels the model cannot accept.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    image = np.array(image)
    image = image[:, :, ::-1]  # RGB -> BGR

    # Pad to a square, splitting the padding between both sides of each axis.
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(
        image,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, Lanczos when enlarging.
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    return image.astype(np.float32)
def load_wd14_tagger_model(model_dir="wd14_tagger_model", repo_id=DEFAULT_WD14_TAGGER_REPO):
    """Load the WD14 tagger SavedModel, downloading it from the Hugging Face Hub if needed.

    Args:
        model_dir: local directory that holds (or will receive) the SavedModel
            files listed in ``FILES`` and ``SUB_DIR/SUB_DIR_FILES``.
        repo_id: Hub repository to download the model from.

    Returns:
        A Keras ``TFSMLayer`` bound to the SavedModel's ``serving_default``
        endpoint.
    """
    required = [os.path.join(model_dir, name) for name in FILES]
    required += [os.path.join(model_dir, SUB_DIR, name) for name in SUB_DIR_FILES]

    # Download if any required file is missing, not only when the whole
    # directory is absent. Otherwise an interrupted download leaves a broken
    # model that is never repaired.
    if not all(os.path.exists(path) for path in required):
        print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
        for file in FILES:
            # `local_dir` writes the file to <model_dir>/<file>. This is the
            # same layout the deprecated `cache_dir` + `force_filename`
            # combination produced.
            hf_hub_download(repo_id, file, local_dir=model_dir, force_download=True)
        for file in SUB_DIR_FILES:
            # With `subfolder`, the file lands at <model_dir>/<SUB_DIR>/<file>.
            hf_hub_download(
                repo_id,
                file,
                subfolder=SUB_DIR,
                local_dir=model_dir,
                force_download=True,
            )
    else:
        print("using existing wd14 tagger model")

    # Load the model.
    model = TFSMLayer(model_dir, call_endpoint='serving_default')
    return model
def generate_tags(
    images,
    model_dir,
    model,
    thresh=0.35,
    undesired_tags=(
        "one-piece_swimsuit",
        "swimsuit",
        "leotard",
        "saitama_(one-punch_man)",
        "1boy",
    ),
):
    """Predict a comma-separated WD14 tag string for each image.

    Args:
        images: the images to tag. Either a prepared batch (for example an
            ``np.ndarray`` stacked from ``preprocess_image`` outputs), which is
            passed to the model unchanged, or an image path / PIL image / list
            of paths or PIL images, which are loaded and preprocessed here.
        model_dir: directory containing the tagger's ``CSV_FILE``
            (header ``tag_id,name,category,...``).
        model: callable returning a mapping whose ``'predictions_sigmoid'``
            entry has a ``.numpy()`` method giving one probability row per
            image, e.g. the layer returned by ``load_wd14_tagger_model``.
        thresh: minimum probability for a tag to be emitted.
        undesired_tags: tag names never emitted.

    Returns:
        list[str]: one ``", "``-joined tag string per image. General tags
        (category "0") come first, then character tags (category "4"), each
        group in CSV order. Other categories, such as ratings, are never emitted.
    """
    with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])
        rows = list(reader)
    assert header[:3] == ["tag_id", "name", "category"], f"unexpected csv format: {header}"

    # Probability column i corresponds to CSV data row i. Mapping by the row's
    # own index avoids the old fixed offset (prob[4:]), which silently relied
    # on exactly four leading rating rows and contiguous tag categories.
    skip = set(undesired_tags)
    general = [(i, row[1]) for i, row in enumerate(rows) if row[2] == "0" and row[1] not in skip]
    character = [(i, row[1]) for i, row in enumerate(rows) if row[2] == "4" and row[1] not in skip]
    candidates = general + character  # general tags first, then character tags

    probs = model(_as_model_batch(images), training=False)
    probs = probs['predictions_sigmoid'].numpy()

    return [
        ", ".join(name for i, name in candidates if prob[i] >= thresh)
        for prob in probs
    ]


def _as_model_batch(images):
    """Turn a path, PIL image, or list of them into a float32 batch for the model.

    Returns inputs with no path or PIL image (e.g. a prepared ndarray batch
    or a tensor) unchanged, as ``generate_tags`` always did.
    """
    loadable = (str, os.PathLike, Image.Image)
    if isinstance(images, loadable):
        images = [images]
    if not (
        isinstance(images, (list, tuple))
        and any(isinstance(item, loadable) for item in images)
    ):
        return images

    batch = []
    for item in images:
        if isinstance(item, (str, os.PathLike)):
            with Image.open(item) as img:
                batch.append(preprocess_image(img.convert("RGB")))
        elif isinstance(item, Image.Image):
            batch.append(preprocess_image(item.convert("RGB")))
        else:
            # Assumed to be an already-preprocessed (H, W, 3) array.
            batch.append(np.asarray(item, dtype=np.float32))
    return np.stack(batch)
def generate_prompt_json(target_folder, prompt_file, model_dir, model):
    """Tag every image in *target_folder* and write a JSON-lines prompt file.

    For each readable image ``<stem>.<ext>``, four records (suffixes
    ``_0`` .. ``_3``) are written, one JSON object per line::

        {"source": "source/<stem>_<j>.jpg", "target": "target/<stem>_<j>.jpg", "prompt": "<tags>"}

    Files that Pillow cannot open are reported and skipped.

    Args:
        target_folder: directory whose regular files are the images to tag.
        prompt_file: output path; overwritten if it exists.
        model_dir: directory holding the tagger's tag CSV (passed to
            ``generate_tags``).
        model: the loaded tagger model (passed to ``generate_tags``).
    """
    # Sorted so the output order is deterministic; os.listdir order is arbitrary.
    image_files = sorted(
        f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))
    )
    image_count = len(image_files)
    prompt_list = []
    for i, filename in enumerate(image_files, 1):
        target_path = os.path.join(target_folder, filename)
        try:
            with Image.open(target_path) as img:
                batch = np.expand_dims(preprocess_image(img.convert("RGB")), axis=0)
        except OSError as err:  # covers PIL's UnidentifiedImageError (non-image files)
            print(f"\nSkipping {filename}: {err}")
            continue

        # generate_tags returns one tag string per image in the batch.
        prompt = generate_tags(batch, model_dir, model)[0]

        # splitext strips only the final extension, so "a.b.png" -> "a.b".
        # The old split('.')[0] truncated this to "a".
        stem = os.path.splitext(filename)[0]
        for j in range(4):
            prompt_list.append({
                "source": f"source/{stem}_{j}.jpg",
                "target": f"target/{stem}_{j}.jpg",
                "prompt": prompt,
            })
        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)

    with open(prompt_file, "w", encoding="utf-8") as file:
        for prompt_data in prompt_list:
            json.dump(prompt_data, file)
            file.write("\n")
    print(f"Processing completed. Total Images: {image_count}")
if __name__ == '__main__':
    # Tag a single image (tags are printed) or a whole folder (a JSONL prompt
    # file is written).
    parser = argparse.ArgumentParser(description="Tag images with the WD14 tagger.")
    parser.add_argument(
        "target",
        help="an image file to tag, or a folder of images to tag into a prompt file",
    )
    parser.add_argument(
        "--prompt_file",
        default="prompt.json",
        help="output JSON-lines file used when TARGET is a folder (default: prompt.json)",
    )
    args = parser.parse_args()

    # Directory where load_wd14_tagger_model() stores the model and its tag CSV.
    model_dir = "wd14_tagger_model"
    model = load_wd14_tagger_model()

    if os.path.isdir(args.target):
        generate_prompt_json(args.target, args.prompt_file, model_dir, model)
    else:
        with Image.open(args.target) as img:
            batch = np.expand_dims(preprocess_image(img.convert("RGB")), axis=0)
        prompt = generate_tags(batch, model_dir, model)[0]
        print(prompt)