# Spaces: Running
# app.py — ZeroGPU-compatible version
import base64
import io
import os
import subprocess
import traceback
from pathlib import Path

# NOTE: `spaces` is deliberately kept ahead of `torch`, as in the original
# import order (the ZeroGPU docs ask for it to be imported before CUDA users).
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

# FastAPI pieces (kept for the hybrid Gradio + REST setup)
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
# Pipeline handles live in module globals and start out as None;
# initialize_pipelines() fills them in.
pipe = None
face_app = None
upsampler = None
UPSCALE_OK = False

# 0. Cache dir & helpers (executed at startup)
PERSIST_BASE = Path("/data")
# Use /data when it exists and is writable, otherwise fall back to the
# per-user cache directory.
_persist_usable = PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
if _persist_usable:
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = MODELS_DIR / "Lora"
EMB_DIR = CACHE_ROOT / "embeddings"
UPSCALE_DIR = CACHE_ROOT / "realesrgan"

for _cache_dir in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
    _cache_dir.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2):
    """Download *url* to *dst* with ``wget`` unless *dst* is already present.

    The payload is written to a sibling ``<name>.part`` file and only moved
    onto *dst* once ``wget`` exits with status 0, so a failed or interrupted
    transfer is never mistaken for a finished download on the next start.

    Args:
        url: Source URL handed to ``wget``.
        dst: Final location of the downloaded file.
        attempts: Maximum number of ``wget`` invocations.

    Raises:
        RuntimeError: If no attempt succeeds (including ``attempts < 1``).
    """
    # An empty file is what a failed `wget -O <dst>` can leave behind, so it
    # is treated as "not downloaded yet" rather than as a valid asset.
    if dst.exists() and dst.stat().st_size > 0:
        return

    partial = dst.with_name(dst.name + ".part")
    for attempt in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {attempt}/{attempts})")
        status = subprocess.call(["wget", "-q", "-O", str(partial), url])
        if status == 0:
            # Publish the file only after a complete transfer.
            partial.replace(dst)
            return
        partial.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
# 1. Asset download (executed at startup)
print("— Starting asset download check —")

# Local targets for the base checkpoint, the IP-Adapter weights and the LoRA.
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"

# Fetched in the same order as before; dl() skips files that already exist.
for _asset_url, _asset_path in (
    ("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT),
    ("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE),
    ("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE),
):
    dl(_asset_url, _asset_path)

print("— Asset download check finished —")
# 2. Pipeline initialisation (called once a GPU has been acquired)

# Set on the first Real-ESRGAN setup attempt, successful or not, so that a
# failed setup is not repeated on every generation request.
_UPSCALE_ATTEMPTED = False


def initialize_pipelines():
    """Build the face analyser, the diffusion pipeline and the upscaler once.

    Each component is created only if its module-level handle (``face_app``,
    ``pipe``, ``upsampler``) is still unset, so repeated calls are cheap.
    A component is assembled in a local variable and published to its global
    only after every setup step succeeded; a failure part-way through
    therefore triggers a clean rebuild on the next call instead of leaving a
    half-configured object behind.

    Upscaler setup is best-effort: any exception disables upscaling
    (``UPSCALE_OK`` stays False) and the attempt is not repeated.

    Raises:
        Exception: Whatever FaceAnalysis / diffusers raise while loading the
            face analyser or the main pipeline.
    """
    global pipe, face_app, upsampler, UPSCALE_OK, _UPSCALE_ATTEMPTED

    # Heavy third-party imports are kept local to this function, as before.
    from diffusers import StableDiffusionPipeline, ControlNetModel, DPMSolverMultistepScheduler, AutoencoderKL
    from insightface.app import FaceAnalysis

    print("--- Initializing Pipelines (GPU is now available) ---")
    device = torch.device("cuda")  # original author's note: a GPU is guaranteed under ZeroGPU
    dtype = torch.float16

    # FaceAnalysis
    if face_app is None:
        print("Initializing FaceAnalysis...")
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        analyser = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
        analyser.prepare(ctx_id=0, det_size=(640, 640))
        face_app = analyser  # publish only once prepare() succeeded
        print("FaceAnalysis initialized.")

    # Main pipeline
    if pipe is None:
        print("Loading ControlNet...")
        controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
        print("Loading StableDiffusionPipeline...")
        # The path is passed as str so the call does not depend on PathLike support.
        new_pipe = StableDiffusionPipeline.from_single_file(
            str(BASE_CKPT), torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2
        )
        print("Moving pipeline to GPU...")
        new_pipe.to(device)
        print("Loading VAE...")
        new_pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
        # NOTE(review): this only attaches the ControlNet as an attribute;
        # confirm that the pipeline class actually consumes it.
        new_pipe.controlnet = controlnet
        print("Configuring Scheduler...")
        new_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            new_pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        )
        print("Loading IP-Adapter and LoRA...")
        new_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
        new_pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
        new_pipe.set_ip_adapter_scale(0.65)
        pipe = new_pipe  # publish only the fully configured pipeline
        print("Main pipeline initialized.")

    # Upscaler (optional; attempted once, never retried after a failure)
    if not _UPSCALE_ATTEMPTED:
        _UPSCALE_ATTEMPTED = True
        print("Checking for Upscaler...")
        try:
            # cv2 is imported here so a missing OpenCV disables upscaling
            # up front; the generation code needs it for colour conversion.
            import cv2  # noqa: F401
            from basicsr.archs.rrdb_arch import RRDBNet
            from realesrgan import RealESRGAN

            rrdb = RRDBNet(3, 3, 64, 23, 32, scale=8)
            new_upsampler = RealESRGAN(device, rrdb, scale=8)
            new_upsampler.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
            upsampler = new_upsampler
            UPSCALE_OK = True
            print("Upscaler initialized successfully.")
        except Exception as e:  # best-effort: any failure just disables upscaling
            upsampler = None
            UPSCALE_OK = False  # record the failure
            print(f"Real-ESRGAN disabled → {e}")

    print("--- All pipelines ready ---")
# 4. Core generation logic

# Positive prompt template; "{subject}" is filled in per request.
BASE_PROMPT = "\n".join(
    [
        "(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,",
        "photo of {subject},",
        "cinematic lighting, golden hour, rim light, shallow depth of field,",
        "textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,",
        "<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),",
        "(aesthetic:1.1), (cinematic:0.8)",
    ]
)

# Negative prompt shared by every request (single line, comma separated).
NEG_PROMPT = ", ".join(
    [
        "ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream",
        "(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2)",
        "painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d",
        "monochrome, grayscale, text, logo, watermark, signature, username",
        "(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs",
        "missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans",
    ]
)
# [Change 1] Internal image-generation function; it carries no @spaces.GPU decorator.
def _generate_internal(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a face picture and prompt settings.

    Args:
        face_img: Face image handed to the pipeline as both
            ``ip_adapter_image`` and ``image``.
        subject: Text substituted into ``BASE_PROMPT``; a built-in default is
            used when it is blank.
        add_prompt: Optional text appended to the positive prompt.
        add_neg: Optional text appended to ``NEG_PROMPT``.
        cfg: Guidance scale.
        ip_scale: IP-Adapter scale applied before the pipeline call.
        steps: Step count; the pipeline is run with ``int(steps) + 5``.
        w, h: Output width and height, cast to ``int``.
        upscale: Whether to upscale; only honoured when ``UPSCALE_OK`` is set.
        up_factor: ``outscale`` value handed to the upscaler.
        progress: Progress callback invoked at the start and before upscaling.

    Returns:
        The first image of the pipeline output, or its upscaled version as a
        PIL image.
    """
    # Builds the pipelines on the first call; later calls reuse them.
    initialize_pipelines()
    progress(0, desc="Generating image...")

    # Assemble the positive and negative prompts.
    subject_text = subject.strip() or "a beautiful 20yo woman"
    prompt_parts = [BASE_PROMPT.format(subject=subject_text)]
    if add_prompt:
        prompt_parts.append(add_prompt)
    prompt = ", ".join(prompt_parts)
    neg = f"{NEG_PROMPT}, {add_neg}" if add_neg else NEG_PROMPT

    pipe.set_ip_adapter_scale(ip_scale)
    # NOTE(review): the reason for the +5 step offset is not visible here.
    # NOTE(review): confirm the pipeline class built in initialize_pipelines
    # accepts `image` / `controlnet_conditioning_scale`.
    output = pipe(
        prompt=prompt,
        negative_prompt=neg,
        ip_adapter_image=face_img,
        image=face_img,
        controlnet_conditioning_scale=0.9,
        num_inference_steps=int(steps) + 5,
        guidance_scale=cfg,
        width=int(w),
        height=int(h),
    )
    result = output.images[0]

    if not (upscale and UPSCALE_OK):
        return result

    # cv2 is imported lazily; it is only needed on the upscaling path.
    import cv2
    progress(0.8, desc="Upscaling...")
    bgr_pixels = cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
    upscaled, _ = upsampler.enhance(bgr_pixels, outscale=up_factor)
    return Image.fromarray(cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGB))
# [Change 2] Wrapper function that carries the @spaces.GPU decorator.
# The decorator is what requests a GPU for the duration of the call; without
# it this wrapper is a plain pass-through.
@spaces.GPU
def generate_gpu_wrapper(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
    """Request a GPU from Hugging Face Spaces and run the generation on it.

    All arguments are forwarded unchanged, in the same order, to
    ``_generate_internal``, which does the actual work.

    Returns:
        Whatever ``_generate_internal`` returns (the generated image).
    """
    return _generate_internal(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
# [Change 3] The Gradio UI calls the GPU wrapper.
def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
    """Gradio click handler: validate the upload and delegate to the GPU wrapper.

    Args:
        face_np: Uploaded face image as a NumPy array, or None when nothing
            was uploaded.
        subject .. progress: Forwarded unchanged to ``generate_gpu_wrapper``.

    Returns:
        The image returned by ``generate_gpu_wrapper``.

    Raises:
        gr.Error: If no face image was provided.
    """
    if face_np is None:
        raise gr.Error("顔画像をアップロードしてください。")

    # The generation code works on a Pillow image, not on the raw array.
    generation_args = (
        Image.fromarray(face_np),
        subject, add_prompt, add_neg, cfg, ip_scale,
        steps, w, h, upscale, up_factor, progress,
    )
    # Goes through generate_gpu_wrapper (rather than the former _generate_core).
    return generate_gpu_wrapper(*generation_args)
# 5. Gradio UI Definition
with gr.Blocks() as demo:
    gr.Markdown("# InstantID – Beautiful Realistic Asians v7 (ZeroGPU)")
    with gr.Row():
        # Left column: inputs and the generate button.
        with gr.Column():
            face_in = gr.Image(label="顔写真", type="numpy")
            subj_in = gr.Textbox(label="被写体説明", placeholder="e.g. woman in black suit, smiling")
            add_in = gr.Textbox(label="追加プロンプト")
            addneg_in = gr.Textbox(label="追加ネガティブ")
            with gr.Accordion("詳細設定", open=False):
                ip_sld = gr.Slider(minimum=0, maximum=1.5, value=0.65, step=0.05, label="IP‑Adapter scale")
                cfg_sld = gr.Slider(minimum=1, maximum=15, value=6, step=0.5, label="CFG")
                step_sld = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
                w_sld = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="幅")
                h_sld = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="高さ")
                up_ck = gr.Checkbox(label="アップスケール", value=True)
                up_fac = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        # Right column: the generated image.
        with gr.Column():
            out_img = gr.Image(label="結果")

    demo.queue()

    # Order matches generate_ui's positional parameters (cfg before ip_scale).
    generation_inputs = [
        face_in, subj_in, add_in, addneg_in,
        cfg_sld, ip_sld, step_sld, w_sld, h_sld,
        up_ck, up_fac,
    ]
    btn.click(fn=generate_ui, inputs=generation_inputs, outputs=out_img)
# 6. FastAPI Mounting
app = FastAPI()


# [Change 4] The FastAPI endpoint also goes through the GPU wrapper.
# NOTE(review): no route registration was visible for this handler, which
# left it unreachable. The "/api/predict" path is an assumption — confirm it
# against the clients that call this API.
@app.post("/api/predict")
async def predict_endpoint(
    face_image: UploadFile = File(...),
    subject: str = Form("a woman"),
    add_prompt: str = Form(""),
    add_neg: str = Form(""),
    cfg: float = Form(6.0),
    ip_scale: float = Form(0.65),
    steps: int = Form(20),
    w: int = Form(512),
    h: int = Form(768),
    upscale: bool = Form(True),
    up_factor: float = Form(2.0)
):
    """Generate an image from an uploaded face picture and form parameters.

    Args:
        face_image: Uploaded face picture (multipart file).
        subject .. up_factor: Generation settings, forwarded unchanged to
            ``generate_gpu_wrapper``.

    Returns:
        ``{"image_base64": <PNG bytes, base64-encoded as UTF-8 text>}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 (with the error text as detail) for any other failure.
    """
    try:
        contents = await face_image.read()
        try:
            # convert() forces a full decode, so broken uploads fail here and
            # the generator always receives an RGB image.
            pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"invalid face_image: {e}") from e

        # Generation blocks for a long time; run it in a worker thread so the
        # event loop stays responsive.
        result_pil_image = await run_in_threadpool(
            generate_gpu_wrapper,
            pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
            steps, w, h, upscale, up_factor,
        )

        buffered = io.BytesIO()
        result_pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image_base64": img_str}
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) as they are.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


# Mount the Gradio app onto the FastAPI app; the REST route above is
# registered before this mount at "/".
app = gr.mount_gradio_app(app, demo, path="/")
print("Application startup script finished. Waiting for requests.")
if __name__ == "__main__":
    import os
    import time
    import socket
    import uvicorn

    def port_is_free(port: int) -> bool:
        """Return True when a TCP connect to ("0.0.0.0", port) does not succeed."""
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            return probe.connect_ex(("0.0.0.0", port)) != 0

    port = int(os.environ.get("PORT", 7860))
    # Timeout shortened for local testing.
    timeout_sec = 30
    poll_interval = 2

    # Poll until the port is free, giving up once timeout_sec has elapsed.
    started_at = time.time()
    while True:
        if port_is_free(port):
            break
        if time.time() - started_at >= timeout_sec:
            raise RuntimeError(f"Port {port} is still busy after {timeout_sec}s")
        print(f"⚠️ Port {port} busy, retrying in {poll_interval}s …")
        time.sleep(poll_interval)

    # NOTE(review): the original comment here says the port check was
    # simplified/disabled because conflicts are unlikely on Hugging Face
    # Spaces, but the wait loop above still runs.
    uvicorn.run(app, host="0.0.0.0", port=port, workers=1, log_level="info")