# -*- coding: utf-8 -*-
# https://github.com/kohya-ss/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py
import csv
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
from PIL import Image
import cv2
import numpy as np
from pathlib import Path
import onnxruntime as ort
# from wd14 tagger
IMAGE_SIZE = 448
model = None # Initialize model variable
def convert_array_to_bgr(array):
    """
    Convert a NumPy array image to BGR format regardless of its original format.

    Parameters:
    - array: NumPy array of the image.

    Returns:
    - A NumPy array representing the image in BGR format.
    """
    # Grayscale image (2-D array)
    if array.ndim == 2:
        # Convert grayscale to BGR by stacking the single channel three times
        bgr_array = np.stack((array,) * 3, axis=-1)
    # RGBA or RGB image (3-D array)
    elif array.ndim == 3:
        # For RGBA images, drop the alpha channel
        if array.shape[2] == 4:
            array = array[:, :, :3]
        # Convert RGB to BGR by reversing the channel order
        bgr_array = array[:, :, ::-1]
    else:
        raise ValueError("Unsupported array shape.")
    return bgr_array
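# A quick sanity check (illustrative sketch, not part of the original script):
# a grayscale array gains a channel axis, and an RGBA array loses alpha and
# has its channels reversed.
#
#     convert_array_to_bgr(np.zeros((2, 2), dtype=np.uint8)).shape    # -> (2, 2, 3)
#     convert_array_to_bgr(np.zeros((2, 2, 4), dtype=np.uint8)).shape # -> (2, 2, 3)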
def preprocess_image(image):
    image = np.array(image)
    image = convert_array_to_bgr(image)

    # Pad to a square with white (255) borders, keeping the image centered
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # INTER_AREA for downscaling, LANCZOS4 for upscaling
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image
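# Hedged usage note: preprocess_image yields a float32 BGR array of shape
# (IMAGE_SIZE, IMAGE_SIZE, 3), whatever the input size, e.g.
#
#     arr = preprocess_image(Image.new("RGB", (640, 360)))
#     assert arr.shape == (IMAGE_SIZE, IMAGE_SIZE, 3) and arr.dtype == np.float32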
def modelLoad(model_dir):
    onnx_path = os.path.join(model_dir, "model.onnx")
    # Restrict execution to the CPU provider
    providers = ['CPUExecutionProvider']
    # Pass the provider list when creating the InferenceSession
    ort_session = ort.InferenceSession(onnx_path, providers=providers)
    input_name = ort_session.get_inputs()[0].name

    # Report which provider is actually in use
    actual_provider = ort_session.get_providers()[0]
    print(f"Using provider: {actual_provider}")

    return [ort_session, input_name]
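# Hedged sketch of how modelLoad is expected to be used; "wd14_model" is a
# hypothetical directory containing model.onnx and selected_tags.csv:
#
#     session, input_name = modelLoad("wd14_model")
#     print(session.get_inputs()[0].shape)  # typically [1, 448, 448, 3] for WD14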
def analysis(image_path, model_dir, model):
    ort_session = model[0]
    input_name = model[1]

    with open(os.path.join(model_dir, "selected_tags.csv"), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

    # Category 0 = general tags, 4 = character tags (rating rows use category 9)
    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]

    tag_freq = {}
    undesired_tags = ["transparent background"]

    image_pil = Image.open(image_path)
    image_preprocessed = preprocess_image(image_pil)
    image_preprocessed = np.expand_dims(image_preprocessed, axis=0)

    # Run inference
    prob = ort_session.run(None, {input_name: image_preprocessed})[0][0]

    # Generate tags; the first 4 outputs are rating tags, so skip them
    combined_tags = []
    general_tag_text = ""
    character_tag_text = ""
    remove_underscore = True
    caption_separator = ", "
    general_threshold = 0.35
    character_threshold = 0.35
    for i, p in enumerate(prob[4:]):
        if i < len(general_tags) and p >= general_threshold:
            tag_name = general_tags[i]
            if remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
                tag_name = tag_name.replace("_", " ")
            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                general_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)
        elif i >= len(general_tags) and p >= character_threshold:
            tag_name = character_tags[i - len(general_tags)]
            if remove_underscore and len(tag_name) > 3:
                tag_name = tag_name.replace("_", " ")
            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                character_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)

    # Strip the leading separator
    if len(general_tag_text) > 0:
        general_tag_text = general_tag_text[len(caption_separator):]
    if len(character_tag_text) > 0:
        character_tag_text = character_tag_text[len(caption_separator):]

    tag_text = caption_separator.join(combined_tags)
    return tag_text
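if __name__ == "__main__":
    # Minimal end-to-end sketch under assumed inputs: the paths below are
    # hypothetical, not part of the original script. model_dir must hold
    # model.onnx and selected_tags.csv; image_path can be any local image.
    model_dir = "wd14_model"   # hypothetical model directory
    image_path = "sample.png"  # hypothetical input image
    model = modelLoad(model_dir)
    print(analysis(image_path, model_dir, model))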