yeq6x commited on
Commit
ce12cd7
·
1 Parent(s): 2c91007
scripts/generate_prompt.py CHANGED
@@ -10,145 +10,74 @@ from tensorflow.keras.layers import TFSMLayer
10
  from huggingface_hub import hf_hub_download
11
  from pathlib import Path
12
 
13
- # from wd14 tagger
14
  IMAGE_SIZE = 448
15
 
16
- # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
17
- DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
18
- FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
19
- SUB_DIR = "variables"
20
- SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
21
- CSV_FILE = FILES[-1]
22
 
23
  def preprocess_image(image):
24
- image = np.array(image)
25
- image = image[:, :, ::-1] # RGB->BGR
26
 
27
- # pad to square
28
- size = max(image.shape[0:2])
29
- pad_x = size - image.shape[1]
30
- pad_y = size - image.shape[0]
31
- pad_l = pad_x // 2
32
- pad_t = pad_y // 2
33
- image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
34
 
35
  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
36
- image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
37
-
38
- image = image.astype(np.float32)
39
- return image
40
 
 
 
 
 
 
 
41
 
42
  def load_wd14_tagger_model():
 
43
  model_dir = "wd14_tagger_model"
44
- repo_id = DEFAULT_WD14_TAGGER_REPO
45
-
46
  if not os.path.exists(model_dir):
47
- print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
48
- for file in FILES:
49
- hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
50
- for file in SUB_DIR_FILES:
51
- hf_hub_download(
52
- repo_id,
53
- file,
54
- subfolder=SUB_DIR,
55
- cache_dir=model_dir + "/" + SUB_DIR,
56
- force_download=True,
57
- force_filename=file,
58
- )
59
  else:
60
- print("using existing wd14 tagger model")
61
-
62
- # モデルを読み込む
63
- model = TFSMLayer(model_dir, call_endpoint='serving_default')
64
- return model
65
 
66
-
67
- def generate_tags(images, model_dir, model):
68
- with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
69
  reader = csv.reader(f)
70
- l = [row for row in reader]
71
- header = l[0] # tag_id,name,category,count
72
- rows = l[1:]
73
- assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
74
-
75
- general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
76
- character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
77
 
 
 
 
 
 
 
78
  tag_freq = {}
79
- undesired_tags = ['one-piece_swimsuit',
80
- 'swimsuit',
81
- 'leotard',
82
- 'saitama_(one-punch_man)',
83
- '1boy',
84
- ]
85
-
86
- probs = model(images, training=False)
87
- probs = probs['predictions_sigmoid'].numpy()
88
 
 
89
  tag_text_list = []
 
90
  for prob in probs:
91
- combined_tags = []
92
- general_tag_text = ""
93
- character_tag_text = ""
94
- thresh = 0.35
95
  for i, p in enumerate(prob[4:]):
96
- if i < len(general_tags) and p >= thresh:
97
- tag_name = general_tags[i]
98
- if tag_name not in undesired_tags:
99
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
100
- general_tag_text += ", " + tag_name
101
- combined_tags.append(tag_name)
102
- elif i >= len(general_tags) and p >= thresh:
103
- tag_name = character_tags[i - len(general_tags)]
104
- if tag_name not in undesired_tags:
105
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
106
- character_tag_text += ", " + tag_name
107
- combined_tags.append(tag_name)
108
 
109
- if len(general_tag_text) > 0:
110
- general_tag_text = general_tag_text[2:]
111
- if len(character_tag_text) > 0:
112
- character_tag_text = character_tag_text[2:]
113
-
114
- tag_text = ", ".join(combined_tags)
115
- tag_text_list.append(tag_text)
116
  return tag_text_list
117
-
118
-
119
- def generate_prompt_json(target_folder, prompt_file, model_dir, model):
120
- image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
121
- image_count = len(image_files)
122
-
123
- prompt_list = []
124
-
125
- for i, filename in enumerate(image_files, 1):
126
- source_path = "source/" + filename
127
- target_path = os.path.join(target_folder, filename) # Use absolute path
128
- target_path2 = "target/" + filename
129
-
130
- prompt = generate_tags(target_path, model_dir, model)
131
-
132
- for j in range(4):
133
- prompt_data = {
134
- "source": f"{source_path.split('.')[0]}_{j}.jpg",
135
- "target": f"{target_path2.split('.')[0]}_{j}.jpg",
136
- "prompt": prompt
137
- }
138
-
139
- prompt_list.append(prompt_data)
140
-
141
- print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)
142
-
143
- with open(prompt_file, "w") as file:
144
- for prompt_data in prompt_list:
145
- json.dump(prompt_data, file)
146
- file.write("\n")
147
-
148
- print(f"Processing completed. Total Images: {image_count}")
149
-
150
-
151
- if __name__ == '__main__':
152
- model_dir = "wd14_tagger_model"
153
- model = load_wd14_tagger_model()
154
- prompt = generate_tags(target_path, model_dir, model)
 
10
  from huggingface_hub import hf_hub_download
11
  from pathlib import Path
12
 
13
+ # 画像サイズの設定
14
  IMAGE_SIZE = 448
15
 
16
+ # デフォルトのタグ付けリポジトリとファイル構成
17
+ DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
18
+ MODEL_FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
19
+ VAR_DIR = "variables"
20
+ VAR_FILES = ["variables.data-00000-of-00001", "variables.index"]
21
+ CSV_FILE = MODEL_FILES[-1]
22
 
23
  def preprocess_image(image):
24
+ """画像を前処理して正方形に変換"""
25
+ img = np.array(image)[:, :, ::-1] # RGB->BGR
26
 
27
+ size = max(img.shape[:2])
28
+ pad_x, pad_y = size - img.shape[1], size - img.shape[0]
29
+ img = np.pad(img, ((pad_y // 2, pad_y - pad_y // 2), (pad_x // 2, pad_x - pad_x // 2), (0, 0)), mode="constant", constant_values=255)
 
 
 
 
30
 
31
  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
32
+ img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
33
+ return img.astype(np.float32)
 
 
34
 
35
+ def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
36
+ """モデルファイルをHugging Face Hubからダウンロード"""
37
+ for file in files:
38
+ hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
39
+ for file in sub_files:
40
+ hf_hub_download(repo_id, file, subfolder=sub_dir, cache_dir=os.path.join(model_dir, sub_dir), force_download=True, force_filename=file)
41
 
42
  def load_wd14_tagger_model():
43
+ """WD14タグ付けモデルをロード"""
44
  model_dir = "wd14_tagger_model"
 
 
45
  if not os.path.exists(model_dir):
46
+ download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)
 
 
 
 
 
 
 
 
 
 
 
47
  else:
48
+ print("Using existing model")
49
+ return TFSMLayer(model_dir, call_endpoint='serving_default')
 
 
 
50
 
51
+ def read_tags_from_csv(csv_path):
52
+ """CSVファイルからタグを読み取る"""
53
+ with open(csv_path, "r", encoding="utf-8") as f:
54
  reader = csv.reader(f)
55
+ tags = [row for row in reader]
56
+ header = tags[0]
57
+ rows = tags[1:]
58
+ assert header[:3] == ["tag_id", "name", "category"], f"Unexpected CSV format: {header}"
59
+ return rows
 
 
60
 
61
+ def generate_tags(images, model_dir, model):
62
+ """画像にタグを生成"""
63
+ rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))
64
+ general_tags = [row[1] for row in rows if row[2] == "0"]
65
+ character_tags = [row[1] for row in rows if row[2] == "4"]
66
+
67
  tag_freq = {}
68
+ undesired_tags = {'one-piece_swimsuit', 'swimsuit', 'leotard', 'saitama_(one-punch_man)', '1boy'}
 
 
 
 
 
 
 
 
69
 
70
+ probs = model(images, training=False)['predictions_sigmoid'].numpy()
71
  tag_text_list = []
72
+
73
  for prob in probs:
74
+ tags_combined = []
 
 
 
75
  for i, p in enumerate(prob[4:]):
76
+ tag_list = general_tags if i < len(general_tags) else character_tags
77
+ tag = tag_list[i - len(general_tags)] if i >= len(general_tags) else tag_list[i]
78
+ if p >= 0.35 and tag not in undesired_tags:
79
+ tag_freq[tag] = tag_freq.get(tag, 0) + 1
80
+ tags_combined.append(tag)
 
 
 
 
 
 
 
81
 
82
+ tag_text_list.append(", ".join(tags_combined))
 
 
 
 
 
 
83
  return tag_text_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/process_utils.py CHANGED
@@ -40,9 +40,9 @@ def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
40
  device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
41
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
42
  use_local = _use_local
43
- print('')
44
- print(f"Device: {device}, Local model: {_use_local}")
45
- print('')
46
  init_model(use_local)
47
  model = load_wd14_tagger_model()
48
  sotai_gen_pipe = initialize_sotai_model()
@@ -59,7 +59,6 @@ def initialize_sotai_model():
59
  controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
60
  # controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
61
  controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
62
- print(use_local, controlnet_path1)
63
 
64
  # Load the Stable Diffusion model
65
  sd_pipe = StableDiffusionPipeline.from_single_file(
@@ -294,7 +293,6 @@ def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float =
294
  image_np = np.array(ensure_rgb(input_image))
295
  prompt = get_wd_tags([image_np])[0]
296
  prompt = f"{prompt}"
297
- print(prompt)
298
 
299
  refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
300
  refined_image = refined_image.convert('RGB')
 
40
  device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
41
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
42
  use_local = _use_local
43
+
44
+ print(f"\nDevice: {device}, Local model: {_use_local}\n")
45
+
46
  init_model(use_local)
47
  model = load_wd14_tagger_model()
48
  sotai_gen_pipe = initialize_sotai_model()
 
59
  controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
60
  # controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
61
  controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
 
62
 
63
  # Load the Stable Diffusion model
64
  sd_pipe = StableDiffusionPipeline.from_single_file(
 
293
  image_np = np.array(ensure_rgb(input_image))
294
  prompt = get_wd_tags([image_np])[0]
295
  prompt = f"{prompt}"
 
296
 
297
  refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
298
  refined_image = refined_image.convert('RGB')