Image2Body_gradio / scripts /generate_prompt.py
yeq6x's picture
.
d02e371
raw
history blame
3.29 kB
import argparse
import csv
import os
import json
from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path
import spaces
# Side length, in pixels, of the square image handed to the tagger.
IMAGE_SIZE = 448

# Default tagger repository on the Hugging Face Hub and its file layout.
DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"

# Tag table shipped with the model; it is also the last entry of MODEL_FILES.
CSV_FILE = "selected_tags.csv"
MODEL_FILES = ["keras_metadata.pb", "saved_model.pb", CSV_FILE]

# Sub-directory of the repository and the files fetched from it.
VAR_DIR = "variables"
VAR_FILES = [
    "variables.data-00000-of-00001",
    "variables.index",
]
def preprocess_image(image):
    """Pad an image to a square and resize it to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image: Anything ``np.array`` turns into a 3-D array indexed as
            (row, column, channel). Assumes RGB channel order — TODO confirm
            against callers.

    Returns:
        A ``float32`` array of shape (IMAGE_SIZE, IMAGE_SIZE, channels) with
        the channel axis reversed relative to the input.
    """
    # Reverse the channel axis (RGB -> BGR).
    bgr = np.array(image)[:, :, ::-1]

    height, width = bgr.shape[:2]
    side = max(height, width)

    # Centre the picture on a square canvas; any odd leftover pixel goes to
    # the bottom / right edge.
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    padding = ((top, extra_rows - top), (left, extra_cols - left), (0, 0))
    squared = np.pad(bgr, padding, mode="constant", constant_values=255)

    # Area interpolation when shrinking, Lanczos otherwise.
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4

    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return resized.astype(np.float32)
def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
    """Download the model files from the Hugging Face Hub.

    Args:
        repo_id: Hub repository to download from.
        model_dir: Directory passed as ``cache_dir`` for the top-level files.
        sub_dir: Repository sub-folder that holds ``sub_files``; the same name
            is appended to ``model_dir`` to form their ``cache_dir``.
        files: File names fetched from the repository root.
        sub_files: File names fetched from ``sub_dir``.

    NOTE(review): ``force_filename`` is believed to be deprecated in recent
    ``huggingface_hub`` releases — confirm against the pinned version.
    """
    # Top-level files of the repository.
    for name in files:
        hf_hub_download(
            repo_id,
            name,
            cache_dir=model_dir,
            force_download=True,
            force_filename=name,
        )

    # Files that live in the repository sub-folder.
    for name in sub_files:
        hf_hub_download(
            repo_id,
            name,
            subfolder=sub_dir,
            cache_dir=os.path.join(model_dir, sub_dir),
            force_download=True,
            force_filename=name,
        )
def load_wd14_tagger_model():
    """Load the WD14 tagger model, downloading it first if it is not present.

    The model is looked up in the ``wd14_tagger_model`` directory relative to
    the current working directory. Only the existence of that path is checked;
    its contents are not validated.

    Returns:
        A ``TFSMLayer`` wrapping the saved model, using the
        ``serving_default`` endpoint.
    """
    model_dir = "wd14_tagger_model"

    if os.path.exists(model_dir):
        print("Using existing model")
    else:
        download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)

    return TFSMLayer(model_dir, call_endpoint='serving_default')
def read_tags_from_csv(csv_path):
    """Read the tag table from a CSV file and return its data rows.

    Args:
        csv_path: Path to a UTF-8 CSV file whose header row starts with the
            columns ``tag_id``, ``name``, ``category``.

    Returns:
        A list of rows (each a list of strings) with the header row removed.

    Raises:
        AssertionError: If the file is empty or its header does not start
            with ``tag_id, name, category``.
    """
    # newline="" leaves newline handling to the csv module, as its
    # documentation recommends for file objects.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        # An empty file yields an empty header, which fails the check below
        # with a clear message instead of an IndexError.
        header = next(reader, [])
        rows = list(reader)

    if header[:3] != ["tag_id", "name", "category"]:
        # Raised explicitly rather than via ``assert`` so the validation is
        # not stripped under ``python -O``; the exception type is kept so
        # existing callers observe the same error.
        raise AssertionError(f"Unexpected CSV format: {header}")
    return rows
def generate_tags(images, model_dir, model, *, threshold=0.35, undesired_tags=None):
    """Generate a comma-separated tag string for each image in a batch.

    Args:
        images: Batch passed directly to ``model``.
            Presumably the output of ``preprocess_image`` stacked along a
            leading batch axis — TODO confirm against callers.
        model_dir: Directory that contains the tag table (``CSV_FILE``).
        model: Callable invoked as ``model(images, training=False)``; its
            result must provide a ``'predictions_sigmoid'`` entry with a
            ``.numpy()`` method yielding one score vector per image.
        threshold: Minimum score for a tag to be kept (inclusive).
        undesired_tags: Collection of tag names to drop from the output.
            Defaults to the built-in block-list below when ``None``.

    Returns:
        A list with one ``", "``-joined tag string per image, in batch order.

    Raises:
        IndexError: If a score vector has more entries (after the first four)
            than there are general plus character tags in the tag table.
    """
    if undesired_tags is None:
        undesired_tags = {'one-piece_swimsuit', 'swimsuit', 'leotard', 'saitama_(one-punch_man)', '1boy'}

    rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    # Score index i (after the skipped prefix) maps onto the general tags
    # first and the character tags after them, so build that lookup once.
    tag_names = general_tags + character_tags

    probs = model(images, training=False)['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        kept = []
        # The first four scores are skipped; presumably they are the rating
        # classes at the top of the tag table — TODO confirm.
        for i, p in enumerate(prob[4:]):
            # Indexed (not zipped) so a size mismatch still raises IndexError.
            tag = tag_names[i]
            if p >= threshold and tag not in undesired_tags:
                kept.append(tag)
        tag_text_list.append(", ".join(kept))
    return tag_text_list