import asyncio
import base64
import datetime
import json
import logging
import os
import re
import shutil
import subprocess
import time
from typing import Optional
import requests
from fastapi import FastAPI
from huggingface_hub import HfApi, hf_hub_download, login
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Config:
"""設定用のクラス"""
HUGGINGFACE_API_KEY = os.environ["HUGGINGFACE_API_KEY"]
CIVITAI_API_TOKEN = os.environ["CIVITAI_API_TOKEN"]
LOG_FILE = "civitai_backup.log"
LIST_FILE = "model_list.log"
REPO_IDS = {
"log": "ttttdiva/CivitAI_log_test",
"model_list": "ttttdiva/CivitAI_model_info_test",
"current": ""
}
URLS = {
"latest": "https://civitai.com/api/v1/models?sort=Newest",
"modelPage": "https://civitai.com/models/",
"modelId": "https://civitai.com/api/v1/models/",
"modelVersionId": "https://civitai.com/api/v1/model-versions/",
"hash": "https://civitai.com/api/v1/model-versions/by-hash/"
}
JST = datetime.timezone(datetime.timedelta(hours=9))
HEADERS = {
'Authorization': f'Bearer {CIVITAI_API_TOKEN}',
'User-Agent': 'civitai-crawler/1.0',
"Content-Type": "application/json"
}
    # ===== Additional settings for rclone =====
RCLONE_CONF_BASE64 = os.environ.get("RCLONE_CONF_BASE64", "")
    # Local directory where rclone writes encrypted files (the backing store of cryptLocal:)
ENCRYPTED_DIR = "/home/user/app/encrypted"
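    # A minimal sketch of the rclone.conf this setup assumes (illustrative values,
    # not the real Space secrets): a "crypt" remote named cryptLocal backed by
    # ENCRYPTED_DIR, so `rclone copy <src> cryptLocal:<name>` writes encrypted
    # files under /home/user/app/encrypted:
    #
    #   [cryptLocal]
    #   type = crypt
    #   remote = /home/user/app/encrypted
    #   password = <obscured password>
    #   password2 = <obscured salt>
    #   filename_encryption = standard
    #   directory_name_encryption = true
    #
    # RCLONE_CONF_BASE64 would then be produced with e.g.:
    #   base64 -w0 rclone.conf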
class CivitAICrawler:
"""CivitAIからモデルをダウンロードし、Hugging Faceにアップロードするクラス"""
def __init__(self, config: Config):
self.config = config
self.api = HfApi()
self.app = FastAPI()
self.repo_ids = self.config.REPO_IDS.copy()
self.jst = self.config.JST
        # Set up rclone
self.setup_rclone_conf()
self.setup_routes()
def setup_routes(self):
"""FastAPIのルーティングを設定する。"""
@self.app.get("/")
def read_root():
now = str(datetime.datetime.now(self.jst))
            description = f"""
            A Space that periodically crawls CivitAI and backs up new models to {self.repo_ids['current']}.
            model_list.log and civitai_backup.log are uploaded unencrypted.
            Model folders and files are encrypted before upload.
            Status: {now} + currently running :D
            """
return description
@self.app.on_event("startup")
async def startup_event():
asyncio.create_task(self.crawl())
# =============================================================================
    # rclone setup and encrypted-upload handling
# =============================================================================
def setup_rclone_conf(self):
"""環境変数 RCLONE_CONF_BASE64 から rclone.conf を生成し、RCLONE_CONFIG 環境変数を設定"""
if not self.config.RCLONE_CONF_BASE64:
logger.warning("[WARN] RCLONE_CONF_BASE64 is empty. rclone may fail.")
return
os.makedirs(".rclone_config", exist_ok=True)
conf_path = os.path.join(".rclone_config", "rclone.conf")
with open(conf_path, "wb") as f:
f.write(base64.b64decode(self.config.RCLONE_CONF_BASE64))
os.environ["RCLONE_CONFIG"] = conf_path
logger.info(f"[INFO] rclone.conf created at: {conf_path}")
def encrypt_with_rclone(self, local_path: str):
"""
指定ファイル or ディレクトリを cryptLocal: にコピー。
フォルダ構造やファイル名を rclone の filename_encryption 設定に応じて暗号化する。
"""
if not os.path.exists(local_path):
raise FileNotFoundError(f"[ERROR] Local path not found: {local_path}")
        # Clean up the encryption output directory first
if os.path.isdir(self.config.ENCRYPTED_DIR):
shutil.rmtree(self.config.ENCRYPTED_DIR, ignore_errors=True)
top_level_name = os.path.basename(local_path.rstrip("/"))
if not top_level_name:
top_level_name = "unnamed"
cmd = ["rclone", "copy", local_path, f"cryptLocal:{top_level_name}", "-v"]
logger.info(f"[INFO] Running: {' '.join(cmd)}")
subprocess.run(cmd, check=True)
logger.info(f"[OK] rclone copy => cryptLocal:{top_level_name}")
if not os.path.isdir(self.config.ENCRYPTED_DIR):
raise FileNotFoundError(
f"[ERROR] {self.config.ENCRYPTED_DIR} not found. Check your rclone config."
)
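    # With a crypt remote as sketched above, the plaintext tree under local_path
    # reappears beneath ENCRYPTED_DIR with encrypted file and directory names
    # (given filename_encryption = standard); upload_encrypted_files() walks that tree.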
    # upload_encrypted_files: uploads everything under ENCRYPTED_DIR, with retry / rate-limit handling.
def upload_encrypted_files(self, repo_id: str, base_path_in_repo: str = ""):
max_retries = 5
for root, dirs, files in os.walk(self.config.ENCRYPTED_DIR):
for fn in files:
encrypted_file_path = os.path.join(root, fn)
if not os.path.isfile(encrypted_file_path):
continue
relative_path = os.path.relpath(encrypted_file_path, self.config.ENCRYPTED_DIR)
upload_path_in_repo = os.path.join(base_path_in_repo, relative_path)
attempt = 0
while attempt < max_retries:
try:
self.api.upload_file(
path_or_fileobj=encrypted_file_path,
repo_id=repo_id,
path_in_repo=upload_path_in_repo
)
logger.info(f"[OK] Uploaded => {repo_id}/{upload_path_in_repo}")
break
except Exception as e:
attempt += 1
error_message = str(e)
                        # ================================
                        # Added detection of 429 rate limits
                        # ================================
                        # Extract the wait time from a message like
                        # "You have been rate-limited; you can retry this action in 31 minutes."
                        # then wait that long plus one minute before retrying.
if "rate-limited" in error_message and "minutes" in error_message:
import re
match = re.search(r"in (\d+) minutes?", error_message)
if match:
minutes = int(match.group(1))
# +1分して待機
minutes += 1
logger.warning(f"Rate-limited. Waiting {minutes} minutes before retry...")
time.sleep(minutes * 60)
attempt -= 1 # 同じ attempt カウントで再試行
continue
                        # ================================
                        # Existing one-hour wait handling
                        # ================================
if "you can retry this action in about 1 hour" in error_message:
logger.warning("Encountered 'retry in 1 hour' error. Waiting 1 hour before retrying...")
time.sleep(3600)
                            attempt -= 1  # keep looping without consuming an attempt
continue
if "over the limit of 100000 files" in error_message:
logger.warning("Repository file limit exceeded. Creating a new repository...")
self.repo_ids['current'] = self.increment_repo_name(self.repo_ids['current'])
self.api.create_repo(repo_id=self.repo_ids['current'], private=True)
attempt = 0
repo_id = self.repo_ids['current']
continue
                        # Any other error
if attempt < max_retries:
logger.warning(
f"Failed to upload {encrypted_file_path}, retry {attempt}/{max_retries}..."
)
else:
logger.error(
f"Failed to upload after {max_retries} attempts: {encrypted_file_path}"
)
raise
@staticmethod
    def get_filename_from_cd(content_disposition: Optional[str], default_name: str) -> str:
        """Extract a filename from a Content-Disposition header, falling back to default_name."""
        if content_disposition:
            parts = content_disposition.split(';')
            for part in parts:
                if "filename=" in part:
                    return part.split("=", 1)[1].strip().strip('"')
        return default_name
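    # Example (assumed header shape):
    #   Content-Disposition: attachment; filename="model.safetensors"
    #   -> returns 'model.safetensors'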
def download_file(self, url: str, destination_folder: str, default_name: str):
try:
response = requests.get(url, headers=self.config.HEADERS, stream=True)
response.raise_for_status()
except requests.RequestException as e:
logger.error(f"Failed to download file from {url}: {e}")
return
filename = self.get_filename_from_cd(response.headers.get('content-disposition'), default_name)
file_path = os.path.join(destination_folder, filename)
with open(file_path, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
logger.info(f"Download completed: {file_path}")
    def get_model_info(self, model_id: str) -> dict:
        """Fetch model metadata from the CivitAI API; returns {} on failure."""
        try:
            response = requests.get(self.config.URLS["modelId"] + str(model_id), headers=self.config.HEADERS)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Failed to retrieve model info for ID {model_id}: {e}")
            return {}
    def download_model(self, model_versions: list, folder: str, existing_old_version_files: Optional[list] = None):
        if existing_old_version_files is None:
            existing_old_version_files = []  # avoid sharing a mutable default argument
        latest_version = model_versions[0]
latest_files = latest_version["files"]
for file_info in latest_files:
download_url = file_info["downloadUrl"]
file_name = file_info["name"]
login_detected_count = 0
while login_detected_count < 5:
try:
self.download_file(download_url, folder, file_name)
except Exception as e:
logger.error(f"Exception occurred while downloading {file_name}: {e}")
login_detected_count += 1
continue
if "login" in os.listdir(folder):
login_detected_count += 1
logger.warning(f"'login' file found. Will try again. ({login_detected_count}/5)")
os.remove(os.path.join(folder, "login"))
else:
logger.info(f"Successfully downloaded {file_name}")
break
if login_detected_count >= 5:
dummy_file_name = f"{file_name}.download_failed"
dummy_file_path = os.path.join(folder, dummy_file_name)
try:
with open(dummy_file_path, "w") as f:
f.write("Download failed after 5 attempts.")
logger.error(f"Failed to download {file_name}. Created dummy file {dummy_file_name}. URL: {download_url}")
except Exception as e:
logger.error(f"Failed to create dummy file for {file_name}: {e}")
        # Download older versions
if len(model_versions) > 1:
old_versions_folder = os.path.join(folder, "old_versions")
os.makedirs(old_versions_folder, exist_ok=True)
for version in model_versions[1:]:
for file_info in version["files"]:
file_name = file_info["name"]
if file_name in existing_old_version_files:
logger.info(f"Skipping download of existing old version file: {file_name}")
continue
download_url = file_info["downloadUrl"]
login_detected_count = 0
while login_detected_count < 5:
try:
self.download_file(download_url, old_versions_folder, file_name)
except Exception as e:
logger.error(f"Exception occurred while downloading {file_name}: {e}")
login_detected_count += 1
continue
if "login" in os.listdir(old_versions_folder):
login_detected_count += 1
logger.warning(f"'login' file found while downloading {file_name}. Will try again. ({login_detected_count}/5)")
os.remove(os.path.join(old_versions_folder, "login"))
else:
logger.info(f"Successfully downloaded {file_name}")
break
if login_detected_count >= 5:
dummy_file_name = f"{file_name}.download_failed"
dummy_file_path = os.path.join(old_versions_folder, dummy_file_name)
try:
with open(dummy_file_path, "w") as f:
f.write("Download failed after 5 attempts.")
logger.error(f"Failed to download {file_name}. Created dummy file {dummy_file_name}. URL: {download_url}")
except Exception as e:
logger.error(f"Failed to create dummy file for {file_name}: {e}")
def download_images(self, model_versions: list, folder: str):
images_folder = os.path.join(folder, "images")
os.makedirs(images_folder, exist_ok=True)
images = []
for version in model_versions:
for img in version.get("images", []):
image_url = img["url"]
images.append(image_url)
for image_url in images:
image_name = image_url.split("/")[-1]
try:
response = requests.get(image_url)
response.raise_for_status()
with open(os.path.join(images_folder, f"{image_name}.png"), "wb") as file:
file.write(response.content)
except requests.RequestException as e:
logger.error(f"Error downloading image {image_url}: {e}")
def save_html_content(self, url: str, folder: str):
try:
response = requests.get(url)
response.raise_for_status()
            html_path = os.path.join(folder, f"{os.path.basename(folder)}.html")
with open(html_path, 'w', encoding='utf-8') as file:
file.write(response.text)
except Exception as e:
logger.error(f"Error saving HTML content for URL {url}: {e}")
@staticmethod
def save_model_info(model_info: dict, folder: str):
with open(os.path.join(folder, "model_info.json"), "w") as file:
json.dump(model_info, file, indent=2)
@staticmethod
def increment_repo_name(repo_id: str) -> str:
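        # e.g. "user/CivitAI_backup_1" -> "user/CivitAI_backup_2",
        #      "user/CivitAI_backup"   -> "user/CivitAI_backup1"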
match = re.search(r'(\d+)$', repo_id)
if match:
number = int(match.group(1)) + 1
return re.sub(r'\d+$', str(number), repo_id)
else:
return f"{repo_id}1"
# =============================================================================
    # Unencrypted upload (for logs and model_list.log)
# =============================================================================
def upload_file_raw(
self,
file_path: str,
repo_id: Optional[str] = None,
path_in_repo: Optional[str] = None
):
if repo_id is None:
repo_id = self.repo_ids['current']
if path_in_repo is None:
path_in_repo = os.path.basename(file_path)
max_retries = 5
attempt = 0
while attempt < max_retries:
try:
self.api.upload_file(
path_or_fileobj=file_path,
repo_id=repo_id,
path_in_repo=path_in_repo
)
logger.info(f"[OK] Uploaded {file_path} => {repo_id}/{path_in_repo}")
return
except Exception as e:
attempt += 1
error_message = str(e)
if "over the limit of 100000 files" in error_message:
logger.warning("Repository file limit exceeded, creating a new repository.")
self.repo_ids['current'] = self.increment_repo_name(self.repo_ids['current'])
self.api.create_repo(repo_id=self.repo_ids['current'], private=True)
attempt = 0
repo_id = self.repo_ids['current']
continue
elif "you can retry this action in about 1 hour" in error_message:
logger.warning("Encountered 'retry in 1 hour' error. Waiting 1 hour before retrying...")
time.sleep(3600)
attempt -= 1
else:
if attempt < max_retries:
logger.warning(f"Failed to upload raw file {file_path}, retry {attempt}/{max_retries}...")
else:
logger.error(f"Failed to upload raw file after {max_retries} attempts: {file_path}")
raise
# =============================================================================
    # Encrypted upload (single file)
# =============================================================================
def upload_file_encrypted(
self,
file_path: str,
repo_id: Optional[str] = None,
path_in_repo: Optional[str] = None
):
if repo_id is None:
repo_id = self.repo_ids['current']
base_path = path_in_repo or ""
self.encrypt_with_rclone(file_path)
self.upload_encrypted_files(repo_id=repo_id, base_path_in_repo=base_path)
if os.path.isdir(self.config.ENCRYPTED_DIR):
shutil.rmtree(self.config.ENCRYPTED_DIR, ignore_errors=True)
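    # Usage sketch: upload_file_encrypted("model_info.json") would encrypt that file
    # into ENCRYPTED_DIR via rclone, push the ciphertext to repo_ids['current'],
    # then clear ENCRYPTED_DIR.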
# =============================================================================
    # Encrypted upload (folder)
# =============================================================================
def upload_folder_encrypted(
self,
folder_path: str,
repo_id: Optional[str] = None,
path_in_repo: Optional[str] = None
) -> str:
if repo_id is None:
repo_id = self.repo_ids['current']
base_path = path_in_repo or ""
self.encrypt_with_rclone(folder_path)
top_levels = [
d for d in os.listdir(self.config.ENCRYPTED_DIR)
if os.path.isdir(os.path.join(self.config.ENCRYPTED_DIR, d))
]
if not top_levels:
raise RuntimeError("No top-level folder found after rclone encryption.")
if len(top_levels) > 1:
logger.warning(f"Multiple top-level folders found after encryption? {top_levels}. Using the first one.")
encrypted_top_name = top_levels[0]
self.upload_encrypted_files(repo_id=repo_id, base_path_in_repo=base_path)
if os.path.isdir(self.config.ENCRYPTED_DIR):
shutil.rmtree(self.config.ENCRYPTED_DIR, ignore_errors=True)
return encrypted_top_name
# =============================================================================
    # Read and write model_list.log as "model_id: model_hf_url" lines
# =============================================================================
def read_model_list(self):
"""
model_list.log の各行を
"123456: https://huggingface.co/...encrypted_folder_name"
の形式で読み込み、 { "123456": "https://huggingface.co/..."} の dict を返す
"""
model_list = {}
try:
with open(self.config.LIST_FILE, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split(": ", 1)
if len(parts) == 2:
stored_id, stored_url = parts
model_list[stored_id] = stored_url
return model_list
except Exception as e:
logger.error(f"Failed to read model list: {e}")
return {}
def process_model(self, model_url: str):
"""指定されたモデルURLを処理する関数。"""
try:
model_id = model_url.rstrip("/").split("/")[-1]
model_info = self.get_model_info(model_id)
            model_versions = model_info.get("modelVersions", [])
            if not model_versions:
                logger.error(f"No versions found for model ID {model_id}; skipping.")
                return
            latest_version = model_versions[0]
model_file = next(
(file for file in latest_version["files"] if file.get('type') == 'Model'),
None
)
if model_file:
latest_filename = model_file['name']
folder = os.path.splitext(latest_filename)[0]
else:
first_file = latest_version["files"][0]
latest_filename = first_file['name']
folder = os.path.splitext(latest_filename)[0]
logger.warning(f"No 'Model' type file found for model ID {model_id}. Using first file's name.")
os.makedirs(folder, exist_ok=True)
            # Load model_list
model_list = self.read_model_list()
            # Example of checking whether this model (by model page name) has already
            # been uploaded. Whether to key on modelpage_name (= model_info["name"])
            # or on model_id (str) is an operational choice; adjust as needed.
            # Here the check uses modelpage_name as the key:
modelpage_name = model_info.get("name", "Unnamed Model")
if modelpage_name in model_list.values():
                # The same model page name is already uploaded; decide here whether to skip, overwrite, etc.
                logger.info(f"Model '{modelpage_name}' is already listed in model_list. Skipping re-upload.")
                # To avoid forcing a re-upload, end processing here with:
                # return
                # Alternatively, force the upload but only add new versions, and so on.
                # Processing deliberately continues here; change this as needed.
            # Download files and save images
existing_old_version_files = []
self.download_model(model_info["modelVersions"], folder, existing_old_version_files)
self.download_images(model_info["modelVersions"], folder)
self.save_html_content(model_url, folder)
self.save_model_info(model_info, folder)
            # ========== Upload the rclone-encrypted folder ==========
encrypted_top_name = self.upload_folder_encrypted(folder)
            # URL of the model uploaded this time
model_hf_url = f"https://huggingface.co/{self.repo_ids['current']}/tree/main/{encrypted_top_name}"
            # Append to model_list.log as "modelpage_name: model_hf_url"
with open(self.config.LIST_FILE, "a", encoding="utf-8") as f:
f.write(f"{modelpage_name}: {model_hf_url}\n")
            # Remove the local folder
if os.path.exists(folder):
shutil.rmtree(folder)
except Exception as e:
logger.error(f"Unexpected error processing model ({model_url}): {e}")
async def crawl(self):
"""モデルを定期的にチェックし、更新を行う。"""
while True:
try:
login(token=self.config.HUGGINGFACE_API_KEY, add_to_git_credential=True)
# model_list.log & civitai_backup.log を取得
model_list_path = hf_hub_download(repo_id=self.repo_ids['model_list'], filename=self.config.LIST_FILE)
shutil.copyfile(model_list_path, f"./{self.config.LIST_FILE}")
local_file_path = hf_hub_download(repo_id=self.repo_ids["log"], filename=self.config.LOG_FILE)
shutil.copyfile(local_file_path, f"./{self.config.LOG_FILE}")
                # Read the local log file
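                # Log layout: line 1 is a JSON array of already-processed model IDs;
                # line 2 is the name of the current backup repository.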
with open(self.config.LOG_FILE, "r", encoding="utf-8") as file:
lines = file.read().splitlines()
old_models = json.loads(lines[0]) if len(lines) > 0 else []
self.repo_ids["current"] = lines[1] if len(lines) > 1 else ""
                # Check for newly published models
response = requests.get(self.config.URLS["latest"], headers=self.config.HEADERS)
response.raise_for_status()
latest_models = response.json().get("items", [])
latest_model_ids = [item.get("id") for item in latest_models if "id" in item]
                # Diff against models we have already processed
new_models = list(set(latest_model_ids) - set(old_models))
if new_models:
logger.info(f"New models found: {new_models}")
model_id = new_models[0]
for attempt in range(1, 6):
try:
self.process_model(f"{self.config.URLS['modelId']}{model_id}")
break
except Exception as e:
logger.error(f"Failed to process model ID {model_id} (Attempt {attempt}/5): {e}")
if attempt == 5:
logger.error(f"Skipping model ID {model_id} after 5 failed attempts.")
else:
await asyncio.sleep(2)
else:
                    # No new models
with open(self.config.LOG_FILE, "w", encoding="utf-8") as f:
f.write(json.dumps(latest_model_ids) + "\n")
f.write(f"{self.repo_ids['current']}\n")
logger.info(f"Updated log file: {self.config.LOG_FILE}")
self.upload_file_raw(
file_path=self.config.LOG_FILE,
repo_id=self.repo_ids["log"],
path_in_repo=self.config.LOG_FILE
)
logger.info("Uploaded log file to repository (unencrypted).")
logger.info("No new models found.")
await asyncio.sleep(60)
continue
                # Add the processed model ID to old_models
old_models.append(model_id)
                # Update the log file
with open(self.config.LOG_FILE, "w", encoding="utf-8") as f:
f.write(json.dumps(old_models) + "\n")
f.write(f"{self.repo_ids['current']}\n")
logger.info(f"Updated log file with new model ID: {model_id}")
                # Upload the log and model_list.log
self.upload_file_raw(
file_path=self.config.LOG_FILE,
repo_id=self.repo_ids["log"],
path_in_repo=self.config.LOG_FILE
)
self.upload_file_raw(
file_path=self.config.LIST_FILE,
repo_id=self.repo_ids["model_list"],
path_in_repo=self.config.LIST_FILE
)
except Exception as e:
logger.error(f"Error during crawling: {e}")
await asyncio.sleep(300)
# Entry point
config = Config()
crawler = CivitAICrawler(config)
app = crawler.app
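# In a Hugging Face Space this app is typically served with uvicorn, e.g.:
#   uvicorn main:app --host 0.0.0.0 --port 7860
# (7860 is the usual Space port; adjust to your runtime.)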