import argparse
import csv
import os
import json
from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path
import spaces

# Side length, in pixels, of the square image produced by preprocess_image.
IMAGE_SIZE = 448

# Default tagger repository and the files that make up the saved model.
DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
MODEL_FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
VAR_DIR = "variables"
VAR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = MODEL_FILES[-1]

# Tags that generate_tags drops from its output unless the caller overrides them.
DEFAULT_UNDESIRED_TAGS = frozenset(
    {
        'one-piece_swimsuit',
        'swimsuit',
        'leotard',
        'saitama_(one-punch_man)',
        '1boy',
    }
)


def preprocess_image(image):
    """Pad an image to a white square and resize it to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image: A PIL image, or any object that ``np.array`` turns into a
            (height, width, 3) RGB array. PIL images in another mode
            (grayscale, palette, RGBA, ...) are converted to RGB first.

    Returns:
        A float32 array of shape (IMAGE_SIZE, IMAGE_SIZE, 3) with the channels
        in BGR order. Values are only cast to float32, not rescaled.
    """
    # Without this, a grayscale image fails on the channel flip below and an
    # RGBA image would come out with four channels.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    img = np.array(image)[:, :, ::-1]  # RGB -> BGR

    # Centre the image on a white square canvas whose side is the longer edge.
    size = max(img.shape[:2])
    pad_x = size - img.shape[1]
    pad_y = size - img.shape[0]
    img = np.pad(
        img,
        (
            (pad_y // 2, pad_y - pad_y // 2),
            (pad_x // 2, pad_x - pad_x // 2),
            (0, 0),
        ),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, Lanczos when enlarging (or keeping the size).
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return img.astype(np.float32)


def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
    """Download the model files from the Hugging Face Hub.

    Args:
        repo_id: Hub repository to download from.
        model_dir: Directory passed as ``cache_dir`` for the top-level files.
        sub_dir: Repository subfolder that holds ``sub_files``; the same name
            is appended to ``model_dir`` for their ``cache_dir``.
        files: File names located at the repository root.
        sub_files: File names located inside ``sub_dir``.
    """
    # NOTE(review): `force_filename` may not be accepted by newer
    # huggingface_hub releases -- confirm against the pinned version before
    # upgrading the dependency.
    for file in files:
        hf_hub_download(
            repo_id,
            file,
            cache_dir=model_dir,
            force_download=True,
            force_filename=file,
        )
    for file in sub_files:
        hf_hub_download(
            repo_id,
            file,
            subfolder=sub_dir,
            cache_dir=os.path.join(model_dir, sub_dir),
            force_download=True,
            force_filename=file,
        )


def load_wd14_tagger_model():
    """Load the WD14 tagger, downloading it first if it is not present.

    The download is skipped whenever the ``wd14_tagger_model`` directory
    already exists; its contents are not checked.

    Returns:
        A ``TFSMLayer`` wrapping the saved model's ``serving_default`` endpoint.
    """
    model_dir = "wd14_tagger_model"
    if not os.path.exists(model_dir):
        download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)
    else:
        print("Using existing model")
    return TFSMLayer(model_dir, call_endpoint='serving_default')


def read_tags_from_csv(csv_path):
    """Read the tag rows from a tag CSV file.

    Args:
        csv_path: Path of the CSV file. Its first row must start with the
            columns ``tag_id``, ``name``, ``category``.

    Returns:
        All rows after the header, each as a list of strings.

    Raises:
        ValueError: If the file is empty or the header is not the expected one.
    """
    # newline="" lets the csv module handle line endings itself.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        tags = list(csv.reader(f))

    if not tags:
        raise ValueError(f"Unexpected CSV format: {csv_path} is empty")

    header, rows = tags[0], tags[1:]
    # Explicit raise instead of assert so the check survives `python -O`.
    if header[:3] != ["tag_id", "name", "category"]:
        raise ValueError(f"Unexpected CSV format: {header}")
    return rows


def generate_tags(images, model_dir, model, *, threshold=0.35, undesired_tags=None):
    """Generate a comma-separated tag string for each image.

    Args:
        images: Input batch handed to ``model`` unchanged (presumably a stack
            of ``preprocess_image`` outputs -- confirm against the caller).
        model_dir: Directory that contains the tag CSV (``CSV_FILE``).
        model: Callable returning a mapping whose ``'predictions_sigmoid'``
            entry holds one row of scores per image.
        threshold: Minimum score a tag needs to be kept.
        undesired_tags: Tag names to leave out of the result. ``None`` uses
            ``DEFAULT_UNDESIRED_TAGS``.

    Returns:
        A list with one ``", "``-joined tag string per row of predictions.
    """
    if undesired_tags is None:
        undesired_tags = DEFAULT_UNDESIRED_TAGS

    rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_count = len(general_tags)

    probs = model(images, training=False)['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        tags_combined = []
        # The first four scores are skipped (presumably the rating outputs --
        # TODO confirm); the rest index general tags first, then character tags.
        for i, p in enumerate(prob[4:]):
            if i < general_count:
                tag = general_tags[i]
            else:
                tag = character_tags[i - general_count]
            if p >= threshold and tag not in undesired_tags:
                tags_combined.append(tag)
        tag_text_list.append(", ".join(tags_combined))
    return tag_text_list