Spaces:
Running
on
Zero
Running
on
Zero
refactor
Browse files- scripts/generate_prompt.py +48 -119
- scripts/process_utils.py +3 -5
scripts/generate_prompt.py
CHANGED
@@ -10,145 +10,74 @@ from tensorflow.keras.layers import TFSMLayer
|
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from pathlib import Path
|
12 |
|
13 |
-
#
|
14 |
IMAGE_SIZE = 448
|
15 |
|
16 |
-
#
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
CSV_FILE =
|
22 |
|
23 |
def preprocess_image(image):
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
pad_y = size - image.shape[0]
|
31 |
-
pad_l = pad_x // 2
|
32 |
-
pad_t = pad_y // 2
|
33 |
-
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
34 |
|
35 |
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
36 |
-
|
37 |
-
|
38 |
-
image = image.astype(np.float32)
|
39 |
-
return image
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def load_wd14_tagger_model():
|
|
|
43 |
model_dir = "wd14_tagger_model"
|
44 |
-
repo_id = DEFAULT_WD14_TAGGER_REPO
|
45 |
-
|
46 |
if not os.path.exists(model_dir):
|
47 |
-
|
48 |
-
for file in FILES:
|
49 |
-
hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
|
50 |
-
for file in SUB_DIR_FILES:
|
51 |
-
hf_hub_download(
|
52 |
-
repo_id,
|
53 |
-
file,
|
54 |
-
subfolder=SUB_DIR,
|
55 |
-
cache_dir=model_dir + "/" + SUB_DIR,
|
56 |
-
force_download=True,
|
57 |
-
force_filename=file,
|
58 |
-
)
|
59 |
else:
|
60 |
-
print("
|
61 |
-
|
62 |
-
# モデルを読み込む
|
63 |
-
model = TFSMLayer(model_dir, call_endpoint='serving_default')
|
64 |
-
return model
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
with open(
|
69 |
reader = csv.reader(f)
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
assert header[
|
74 |
-
|
75 |
-
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
76 |
-
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
tag_freq = {}
|
79 |
-
undesired_tags =
|
80 |
-
'swimsuit',
|
81 |
-
'leotard',
|
82 |
-
'saitama_(one-punch_man)',
|
83 |
-
'1boy',
|
84 |
-
]
|
85 |
-
|
86 |
-
probs = model(images, training=False)
|
87 |
-
probs = probs['predictions_sigmoid'].numpy()
|
88 |
|
|
|
89 |
tag_text_list = []
|
|
|
90 |
for prob in probs:
|
91 |
-
|
92 |
-
general_tag_text = ""
|
93 |
-
character_tag_text = ""
|
94 |
-
thresh = 0.35
|
95 |
for i, p in enumerate(prob[4:]):
|
96 |
-
if i < len(general_tags)
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
combined_tags.append(tag_name)
|
102 |
-
elif i >= len(general_tags) and p >= thresh:
|
103 |
-
tag_name = character_tags[i - len(general_tags)]
|
104 |
-
if tag_name not in undesired_tags:
|
105 |
-
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
106 |
-
character_tag_text += ", " + tag_name
|
107 |
-
combined_tags.append(tag_name)
|
108 |
|
109 |
-
|
110 |
-
general_tag_text = general_tag_text[2:]
|
111 |
-
if len(character_tag_text) > 0:
|
112 |
-
character_tag_text = character_tag_text[2:]
|
113 |
-
|
114 |
-
tag_text = ", ".join(combined_tags)
|
115 |
-
tag_text_list.append(tag_text)
|
116 |
return tag_text_list
|
117 |
-
|
118 |
-
|
119 |
-
def generate_prompt_json(target_folder, prompt_file, model_dir, model):
|
120 |
-
image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
|
121 |
-
image_count = len(image_files)
|
122 |
-
|
123 |
-
prompt_list = []
|
124 |
-
|
125 |
-
for i, filename in enumerate(image_files, 1):
|
126 |
-
source_path = "source/" + filename
|
127 |
-
target_path = os.path.join(target_folder, filename) # Use absolute path
|
128 |
-
target_path2 = "target/" + filename
|
129 |
-
|
130 |
-
prompt = generate_tags(target_path, model_dir, model)
|
131 |
-
|
132 |
-
for j in range(4):
|
133 |
-
prompt_data = {
|
134 |
-
"source": f"{source_path.split('.')[0]}_{j}.jpg",
|
135 |
-
"target": f"{target_path2.split('.')[0]}_{j}.jpg",
|
136 |
-
"prompt": prompt
|
137 |
-
}
|
138 |
-
|
139 |
-
prompt_list.append(prompt_data)
|
140 |
-
|
141 |
-
print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)
|
142 |
-
|
143 |
-
with open(prompt_file, "w") as file:
|
144 |
-
for prompt_data in prompt_list:
|
145 |
-
json.dump(prompt_data, file)
|
146 |
-
file.write("\n")
|
147 |
-
|
148 |
-
print(f"Processing completed. Total Images: {image_count}")
|
149 |
-
|
150 |
-
|
151 |
-
if __name__ == '__main__':
|
152 |
-
model_dir = "wd14_tagger_model"
|
153 |
-
model = load_wd14_tagger_model()
|
154 |
-
prompt = generate_tags(target_path, model_dir, model)
|
|
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from pathlib import Path
|
12 |
|
13 |
+
# Image size setting
|
14 |
IMAGE_SIZE = 448
|
15 |
|
16 |
+
# Default tagger repository and file layout
|
17 |
+
DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
18 |
+
MODEL_FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
19 |
+
VAR_DIR = "variables"
|
20 |
+
VAR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
21 |
+
CSV_FILE = MODEL_FILES[-1]
|
22 |
|
23 |
def preprocess_image(image):
    """Pad an image to a square and resize it to IMAGE_SIZE x IMAGE_SIZE.

    The input is converted to an array and its last axis is reversed
    (RGB -> BGR). The shorter side is padded with the constant value 255 so
    the array becomes square, then the square is resized with cv2.

    Args:
        image: Anything ``np.array`` turns into a three-axis array
            (rows, columns, channels).

    Returns:
        The resized array cast to ``np.float32``.
    """
    # Reverse the channel axis: RGB -> BGR
    bgr = np.array(image)[:, :, ::-1]

    height, width = bgr.shape[:2]
    side = max(height, width)

    # Split the missing rows/columns between both edges; an odd leftover
    # pixel goes to the bottom/right edge.
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    padding = ((top, extra_rows - top), (left, extra_cols - left), (0, 0))
    squared = np.pad(bgr, padding, mode="constant", constant_values=255)

    # INTER_AREA when the square is larger than the target, LANCZOS4 otherwise
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4

    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return resized.astype(np.float32)
|
|
|
|
|
34 |
|
35 |
+
def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
    """Download the model files from the Hugging Face Hub.

    Args:
        repo_id: Hub repository to download from.
        model_dir: Local directory passed as ``cache_dir`` for top-level files.
        sub_dir: Repository subfolder holding ``sub_files``; the same name is
            joined onto ``model_dir`` for their ``cache_dir``.
        files: File names located at the repository root.
        sub_files: File names located inside ``sub_dir``.

    NOTE(review): ``force_filename`` is believed to be deprecated in newer
    huggingface_hub releases - confirm the pinned version still accepts it.
    """
    # Files at the repository root
    for name in files:
        hf_hub_download(
            repo_id,
            name,
            cache_dir=model_dir,
            force_download=True,
            force_filename=name,
        )

    # Files inside the repository subfolder
    for name in sub_files:
        nested_cache = os.path.join(model_dir, sub_dir)
        hf_hub_download(
            repo_id,
            name,
            subfolder=sub_dir,
            cache_dir=nested_cache,
            force_download=True,
            force_filename=name,
        )
|
41 |
|
42 |
def load_wd14_tagger_model():
    """Load the WD14 tagger model, downloading its files if needed.

    The model files are downloaded when the ``wd14_tagger_model`` directory
    does not exist; otherwise the existing directory is reused.

    Returns:
        A ``TFSMLayer`` built from the model directory using the
        ``serving_default`` endpoint.
    """
    model_dir = "wd14_tagger_model"
    if os.path.exists(model_dir):
        print("Using existing model")
    else:
        download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)
    tagger = TFSMLayer(model_dir, call_endpoint='serving_default')
    return tagger
|
|
|
|
|
|
|
50 |
|
51 |
+
def read_tags_from_csv(csv_path):
    """Read the tag table from a CSV file and return its data rows.

    Args:
        csv_path: Path to a UTF-8 CSV file whose first three header columns
            are ``tag_id``, ``name`` and ``category``.

    Returns:
        A list of rows (each a list of strings), excluding the header row.

    Raises:
        AssertionError: If the file is empty or the header does not start
            with the expected three columns.
    """
    # newline="" lets the csv module handle line endings itself, as its
    # documentation recommends for file objects passed to csv.reader.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])  # an empty file yields an empty header
        rows = list(reader)

    # Raised explicitly rather than via `assert` so the check still runs
    # under `python -O`; the exception type is kept for existing callers.
    if header[:3] != ["tag_id", "name", "category"]:
        raise AssertionError(f"Unexpected CSV format: {header}")
    return rows
|
|
|
|
|
60 |
|
61 |
+
def generate_tags(images, model_dir, model, threshold=0.35):
    """Generate a comma-separated tag string for each image.

    Args:
        images: Batch passed straight to ``model``.
        model_dir: Directory containing the tag CSV named by ``CSV_FILE``.
        model: Callable invoked as ``model(images, training=False)``; its
            result is indexed with ``'predictions_sigmoid'`` and converted
            with ``.numpy()``.
        threshold: Minimum score for a tag to be kept. Defaults to 0.35,
            the value that was previously hard-coded.

    Returns:
        A list with one string per score row: the kept tag names joined
        with ``", "``.
    """
    rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    # Scores after the first four are matched by position against the
    # general tags followed by the character tags.
    tag_names = general_tags + character_tags

    undesired_tags = {'one-piece_swimsuit', 'swimsuit', 'leotard', 'saitama_(one-punch_man)', '1boy'}

    probs = model(images, training=False)['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        kept = []
        # The first four scores are skipped - presumably the rating entries
        # at the top of the CSV; TODO confirm against selected_tags.csv.
        for i, p in enumerate(prob[4:]):
            tag = tag_names[i]
            if p >= threshold and tag not in undesired_tags:
                kept.append(tag)
        tag_text_list.append(", ".join(kept))
    return tag_text_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/process_utils.py
CHANGED
@@ -40,9 +40,9 @@ def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
|
|
40 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
41 |
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
42 |
use_local = _use_local
|
43 |
-
|
44 |
-
print(f"
|
45 |
-
|
46 |
init_model(use_local)
|
47 |
model = load_wd14_tagger_model()
|
48 |
sotai_gen_pipe = initialize_sotai_model()
|
@@ -59,7 +59,6 @@ def initialize_sotai_model():
|
|
59 |
controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
|
60 |
# controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
61 |
controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
62 |
-
print(use_local, controlnet_path1)
|
63 |
|
64 |
# Load the Stable Diffusion model
|
65 |
sd_pipe = StableDiffusionPipeline.from_single_file(
|
@@ -294,7 +293,6 @@ def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float =
|
|
294 |
image_np = np.array(ensure_rgb(input_image))
|
295 |
prompt = get_wd_tags([image_np])[0]
|
296 |
prompt = f"{prompt}"
|
297 |
-
print(prompt)
|
298 |
|
299 |
refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
|
300 |
refined_image = refined_image.convert('RGB')
|
|
|
40 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
41 |
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
42 |
use_local = _use_local
|
43 |
+
|
44 |
+
print(f"\nDevice: {device}, Local model: {_use_local}\n")
|
45 |
+
|
46 |
init_model(use_local)
|
47 |
model = load_wd14_tagger_model()
|
48 |
sotai_gen_pipe = initialize_sotai_model()
|
|
|
59 |
controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
|
60 |
# controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
61 |
controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
|
|
62 |
|
63 |
# Load the Stable Diffusion model
|
64 |
sd_pipe = StableDiffusionPipeline.from_single_file(
|
|
|
293 |
image_np = np.array(ensure_rgb(input_image))
|
294 |
prompt = get_wd_tags([image_np])[0]
|
295 |
prompt = f"{prompt}"
|
|
|
296 |
|
297 |
refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
|
298 |
refined_image = refined_image.convert('RGB')
|