# app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-ready, CrucibleAI ControlNet)
# 2025-06-22
##############################################################################
# 0. Old-API → new-API compatibility shim (must come before the diffusers import)
##############################################################################
from huggingface_hub import hf_hub_download
import huggingface_hub as _hf_hub
# diffusers-0.27 still calls cached_download(); inject it so huggingface_hub v0.28+ works too
if not hasattr(_hf_hub, "cached_download"):
    _hf_hub.cached_download = hf_hub_download
##############################################################################
# 1. Standard & third-party libraries
##############################################################################
import os, io, base64, subprocess, traceback
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import gradio as gr
import spaces
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
)
from insightface.app import FaceAnalysis
from basicsr.archs.rrdbnet_arch import RRDBNet  # network definition for Real-ESRGAN
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
##############################################################################
# 2. Cache & paths
##############################################################################
PERSIST_BASE = Path("/data")
CACHE_ROOT = (
    PERSIST_BASE / "instantid_cache"
    if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
    else Path.home() / ".cache" / "instantid_cache"
)
MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = CACHE_ROOT / "lora"
UPSCALE_DIR = CACHE_ROOT / "realesrgan"
for _p in (MODELS_DIR, LORA_DIR, UPSCALE_DIR):
    _p.mkdir(parents=True, exist_ok=True)
##############################################################################
# 3. Model URLs
##############################################################################
BRA_V7_URL = (
    "https://huggingface.co/i0switch-assets/Beautiful_Realistic_Asians_v7/"
    "resolve/main/beautiful_realistic_asians_v7_fp16.safetensors"
)
IP_ADAPTER_BIN_URL = (
    "https://huggingface.co/h94/IP-Adapter/"
    "resolve/main/ip-adapter-plus-face_sd15.bin"
)
IP_ADAPTER_LORA_URL = (
    "https://huggingface.co/h94/IP-Adapter-FaceID/"
    "resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors"
)
REALESRGAN_URL = (
    "https://huggingface.co/aimagelab/realesrgan/"
    "resolve/main/RealESRGAN_x4plus.pth"
)
##############################################################################
# 4. Downloader
##############################################################################
def download(url: str, dst: Path, attempts: int = 2):
    """Download `url` to `dst`, retrying with curl and falling back to basicsr."""
    if dst.exists():
        return dst
    for i in range(1, attempts + 1):
        try:
            # -f: treat HTTP errors (e.g. 404) as failures instead of saving the error page
            subprocess.check_call(["curl", "-fL", "-o", str(dst), url])
            return dst
        except subprocess.CalledProcessError:
            print(f"[DL] Retry {i}/{attempts} failed: {url}")
    # Fall back to basicsr's downloader if curl keeps failing
    load_file_from_url(url=url, model_dir=str(dst.parent), file_name=dst.name)
    return dst
##############################################################################
# 5. Globals (lazy-loaded)
##############################################################################
pipe: Optional[StableDiffusionControlNetPipeline] = None
face_analyser: Optional[FaceAnalysis] = None
upsampler: Optional[RealESRGANer] = None
##############################################################################
# 6. Pipeline initialization
##############################################################################
def initialize_pipelines():
    global pipe, face_analyser, upsampler
    if pipe is not None:
        return
    print("[INIT] Downloading model assets …")
    # 6-1 Base model & IP-Adapter weights
    bra_ckpt = download(BRA_V7_URL, MODELS_DIR / "bra_v7.safetensors")
    ip_bin = download(IP_ADAPTER_BIN_URL, MODELS_DIR / "ip_adapter.bin")
    ip_lora = download(IP_ADAPTER_LORA_URL, LORA_DIR / "ip_adapter_faceid.safetensors")
    # 6-2 ControlNet (CrucibleAI / diffusion_sd15)
    controlnet = ControlNetModel.from_pretrained(
        "CrucibleAI/ControlNetMediaPipeFace",  # public repo
        subfolder="diffusion_sd15",            # SD-1.5 weights
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
    )
    # 6-3 Diffusers pipeline
    pipe_tmp = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        vae=AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
        ),
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
        safety_checker=None,
    )
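    # NOTE (assumption): bra_ckpt downloaded above is never applied; the base
    # weights come from runwayml/stable-diffusion-v1-5. One way to use the
    # BRA v7 checkpoint instead would be single-file loading (sketch only):
    #   pipe_tmp = StableDiffusionControlNetPipeline.from_single_file(
    #       str(bra_ckpt),
    #       controlnet=controlnet,
    #       torch_dtype=torch.float16,
    #   )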
    pipe_tmp.scheduler = DPMSolverMultistepScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler",
        cache_dir=str(MODELS_DIR),
    )
    # 6-4 Load the IP-Adapter (the API takes three positional arguments)
    # diffusers-0.27.2 requires subfolder / weight_name
    ip_dir = ip_bin.parent
    pipe_tmp.load_ip_adapter(
        str(ip_dir),   # path or repo id
        "",            # subfolder (empty string: the weight sits in the directory root)
        ip_bin.name,   # weight_name
    )
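    # NOTE (assumption): "plus" IP-Adapter variants also need a CLIP image
    # encoder. If load_ip_adapter() cannot find one next to the local .bin,
    # it can be supplied explicitly before loading (sketch only):
    #   from transformers import CLIPVisionModelWithProjection
    #   image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    #       "h94/IP-Adapter", subfolder="models/image_encoder",
    #       torch_dtype=torch.float16,
    #   )
    #   pipe_tmp.register_modules(image_encoder=image_encoder)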
    # Merge the additional FaceID LoRA into the UNet
    pipe_tmp.load_lora_weights(
        str(ip_lora.parent), weight_name=ip_lora.name, adapter_name="ip_faceid"
    )
    pipe_tmp.set_adapters(["ip_faceid"], adapter_weights=[0.6])
    pipe_tmp.to("cuda")
    pipe = pipe_tmp
    # 6-5 InsightFace
    face_analyser = FaceAnalysis(
        name="buffalo_l", root=str(MODELS_DIR), providers=["CUDAExecutionProvider"]
    )
    face_analyser.prepare(ctx_id=0, det_size=(640, 640))
    # 6-6 Real-ESRGAN
    esrgan_ckpt = download(REALESRGAN_URL, UPSCALE_DIR / "realesrgan_x4plus.pth")
    upsampler = RealESRGANer(
        scale=4,
        model_path=str(esrgan_ckpt),
        # RealESRGANer needs the network definition; x4plus uses this RRDBNet config
        model=RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=4,
        ),
        half=True,
        tile=512,
        tile_pad=10,
        pre_pad=0,
        gpu_id=0,
    )
    print("[INIT] Pipelines ready.")
##############################################################################
# 7. Prompt templates
##############################################################################
BASE_PROMPT = (
    "(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k, "
    "cinematic lighting, textured skin, "
)
NEG_PROMPT = (
    "verybadimagenegative_v1.3, ng_deepnegative_v1_75t, "
    "(worst quality:2), (low quality:2), lowres, blurry, bad anatomy, "
    "bad hands, extra digits, watermark, signature"
)
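# NOTE (assumption): the first two tokens in NEG_PROMPT are textual-inversion
# embeddings; they only take effect if their files are loaded into the
# pipeline, otherwise they are treated as plain text. Sketch only:
#   pipe.load_textual_inversion("path/to/verybadimagenegative_v1.3.pt",
#                               token="verybadimagenegative_v1.3")
#   pipe.load_textual_inversion("path/to/ng_deepnegative_v1_75t.pt",
#                               token="ng_deepnegative_v1_75t")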
##############################################################################
# 8. Generation function (runs with a GPU attached)
##############################################################################
@spaces.GPU  # ZeroGPU: request a GPU for the duration of each call
def generate_core(
    face_img: Image.Image,
    subject: str,
    add_prompt: str = "",
    add_neg: str = "",
    cfg: float = 7.5,
    ip_scale: float = 0.6,
    steps: int = 30,
    w: int = 768,
    h: int = 768,
    upscale: bool = False,
    up_factor: int = 4,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    try:
        if pipe is None:
            initialize_pipelines()
        faces = face_analyser.get(np.array(face_img))
        if len(faces) == 0:
            raise ValueError("No face detected. Please try a different image.")
        pipe.set_adapters(["ip_faceid"], adapter_weights=[ip_scale])
        prompt = BASE_PROMPT + subject + ", " + add_prompt
        negative = NEG_PROMPT + ", " + add_neg
        result = pipe(
            prompt=prompt,
            negative_prompt=negative,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            # ControlNet conditioning image; MediaPipeFace normally expects a
            # rendered face-mesh annotation, so passing the raw photo here is a
            # simplification.
            image=face_img,
            ip_adapter_image=face_img,  # face reference for the IP-Adapter
            width=int(w),
            height=int(h),
        ).images[0]
        if upscale and upsampler is not None:
            # RealESRGAN_x4plus is a 4x network; `outscale` sets the final factor
            up_img, _ = upsampler.enhance(np.array(result), outscale=int(up_factor))
            result = Image.fromarray(up_img)
        return result
    except Exception as e:
        traceback.print_exc()
        raise e
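# Illustrative direct call (assumes a local face photo on disk):
#   img = generate_core(Image.open("face.jpg").convert("RGB"),
#                       "japanese woman in her 30s, black hair")
#   img.save("result.png")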
##############################################################################
# 9. Gradio UI
##############################################################################
with gr.Blocks(title="InstantID × BRA v7 (ZeroGPU)") as demo:
    gr.Markdown("## InstantID × Beautiful Realistic Asians v7")
    with gr.Row():
        face_img = gr.Image(type="pil", label="Face ID", sources=["upload"])
        subject = gr.Textbox(
            label="Subject description (e.g. Japanese woman in her 30s, shoulder-length black hair)",
            interactive=True,
        )
    add_prompt = gr.Textbox(label="Additional prompt", interactive=True)
    add_neg = gr.Textbox(label="Additional negative prompt", interactive=True)
    with gr.Row():
        cfg = gr.Slider(1, 20, value=7.5, step=0.5, label="CFG Scale")
        ip_scale = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IP-Adapter Weight")
    with gr.Row():
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        w = gr.Slider(512, 1024, value=768, step=64, label="Width")
        h = gr.Slider(512, 1024, value=768, step=64, label="Height")
    with gr.Row():
        upscale = gr.Checkbox(label="Real-ESRGAN Upscale", value=False)
        up_factor = gr.Radio([4, 8], value=4, label="Upscale Factor")
    run_btn = gr.Button("Generate")
    output_img = gr.Image(type="pil", label="Result")
    run_btn.click(
        fn=generate_core,
        inputs=[
            face_img,
            subject,
            add_prompt,
            add_neg,
            cfg,
            ip_scale,
            steps,
            w,
            h,
            upscale,
            up_factor,
        ],
        outputs=output_img,
        show_progress="full",
    )
##############################################################################
# 10. FastAPI REST
##############################################################################
app = FastAPI()
@app.post("/generate")  # route path chosen here for illustration
async def api_generate(
    subject: str = Form(...),
    cfg: float = Form(7.5),
    steps: int = Form(30),
    ip_scale: float = Form(0.6),
    w: int = Form(768),
    h: int = Form(768),
    file: UploadFile = File(...),
):
    try:
        img_bytes = await file.read()
        pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        res = generate_core(
            face_img=pil,
            subject=subject,
            add_prompt="",
            add_neg="",
            cfg=cfg,
            ip_scale=ip_scale,
            steps=steps,
            w=w,
            h=h,
            upscale=False,
            up_factor=4,
        )
        buf = io.BytesIO()
        res.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return {"image": f"data:image/png;base64,{b64}"}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
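# Illustrative request (assumes the FastAPI app is actually served; see the
# note in section 11):
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate",
#       data={"subject": "japanese woman in her 30s"},
#       files={"file": open("face.jpg", "rb")},
#   )
#   print(r.json()["image"][:64])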
##############################################################################
# 11. Launch (Gradio starts Uvicorn automatically)
##############################################################################
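# NOTE (assumption): demo.launch() below serves only the Gradio app, so the
# FastAPI routes above are not exposed by it. One way to serve both is to
# mount the Blocks onto the FastAPI app and run Uvicorn yourself (sketch only):
#   import uvicorn
#   app = gr.mount_gradio_app(app, demo, path="/")
#   uvicorn.run(app, host="0.0.0.0", port=7860)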
demo.queue(default_concurrency_limit=2).launch(share=False)