# app.py — ZeroGPU support + automatic port fallback
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import os
import subprocess
import traceback
import base64
import io
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException

##############################################################################
# 0. Configuration and helpers
##############################################################################
# Prefer the persistent /data volume (HF Spaces persistent storage) when it is
# writable; otherwise fall back to a per-user cache directory.
PERSIST_BASE = Path("/data")
CACHE_ROOT = (PERSIST_BASE / "instantid_cache"
              if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
              else Path.home() / ".cache" / "instantid_cache")
MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = CACHE_ROOT / "lora"
for d in (MODELS_DIR, LORA_DIR):
    d.mkdir(parents=True, exist_ok=True)


def dl(url: str, dst: Path, attempts: int = 2) -> None:
    """Idempotently download *url* to *dst* (skip if present, with retries).

    Raises:
        RuntimeError: if all ``attempts`` downloads fail.

    Fix: a failed ``wget -O`` still creates an empty/partial file at *dst*,
    which the previous ``dst.exists()`` check would treat as a completed
    download on every later run — permanently caching a corrupt asset.
    We now ignore zero-byte files and remove the leftover file after each
    failed attempt so a subsequent run can retry cleanly.
    """
    if dst.exists() and dst.stat().st_size > 0:
        return
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0:
            return
        # wget leaves a zero-byte/partial file behind on failure — drop it so
        # the idempotence check above cannot mistake it for a finished file.
        dst.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")


print("— Starting asset download check —")
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16",
   BASE_CKPT)
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
dl("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin",
   IP_BIN_FILE)
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors",
   LORA_FILE)
print("— Asset download check finished —")

##############################################################################
# 1.
# Gradio UI
##############################################################################
with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    with gr.Row(equal_height=True):
        with gr.Column():
            # Inputs: a face photo plus free-form prompt fields.
            face_in = gr.Image(type="pil", label="顔画像 (必須)")
            subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
            add_in = gr.Textbox(label="追加プロンプト", placeholder="例: masterpiece, best quality, ...")
            addneg_in = gr.Textbox(label="ネガティブ", value="(worst quality:2), lowres, bad hand, ...")
            with gr.Row():
                ip_sld = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="IP Adapter Weight")
                cfg_sld = gr.Slider(minimum=1, maximum=15, value=6, step=0.5, label="CFG")
            step_sld = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
            w_sld = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="幅")
            h_sld = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="高さ")
            up_ck = gr.Checkbox(label="アップスケール", value=True)
            up_fac = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        with gr.Column():
            out_img = gr.Image(label="結果")

    demo.queue()

    # Dummy inference (real implementation omitted): ignores every input and
    # always returns a flat gray 512x768 canvas.
    def generate_ui(*args, **kwargs):
        return Image.new("RGB", (512, 768), (127, 127, 127))

    btn.click(
        generate_ui,
        inputs=[face_in, subj_in, add_in, addneg_in, cfg_sld, ip_sld,
                step_sld, w_sld, h_sld, up_ck, up_fac],
        outputs=[out_img],
    )

##############################################################################
# 2. FastAPI wrapper (REST API)
##############################################################################
app = FastAPI()


@app.post("/api/predict")
async def predict(face: UploadFile = File(...)):
    """Accept an uploaded face image and return it as a base64-encoded PNG."""
    try:
        img = Image.open(face.file)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()
        return {"image_base64": img_b64}
    except Exception as e:
        # Surface any failure as a 500 with the original message attached.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


# Mount the Gradio app onto the FastAPI instance at the web root.
app = gr.mount_gradio_app(app, demo, path="/")
print("Application startup script finished. Waiting for requests.")

##############################################################################
# 3.
# Uvicorn startup (fall back when the requested port is already in use)
##############################################################################
if __name__ == "__main__":
    import errno
    import uvicorn

    # Honour the platform-provided PORT, defaulting to Gradio's usual 7860.
    port_env = int(os.getenv("PORT", "7860"))
    try:
        uvicorn.run(app, host="0.0.0.0", port=port_env, workers=1, log_level="info")
    except OSError as e:
        # Fix: compare against errno.EADDRINUSE instead of the hard-coded
        # Linux value 98 — on macOS/BSD "address already in use" is 48, so
        # the fallback never fired there. Only retry when we were asked for
        # a non-default port; otherwise re-raise unchanged.
        if e.errno == errno.EADDRINUSE and port_env != 7860:
            print(f"⚠️ Port {port_env} busy → falling back to 7860")
            uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
        else:
            raise