# Hugging Face Space status: Running on Zero (ZeroGPU)
# app.py — ZeroGPU対応 + ポート自動フォールバック (ZeroGPU support + automatic port fallback)
import base64
import errno
import io
import os
import socket
import subprocess
import traceback
from pathlib import Path

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
##############################################################################
# 0. Settings and helpers
##############################################################################
# Cache location: use the persistent /data volume when it exists and the
# process may write to it; otherwise fall back to ~/.cache/instantid_cache.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"  # base checkpoints
LORA_DIR = CACHE_ROOT / "lora"      # adapter / LoRA weights

# Ensure both cache sub-directories exist (idempotent thanks to exist_ok).
for _cache_dir in (MODELS_DIR, LORA_DIR):
    _cache_dir.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2) -> None:
    """Download *url* to *dst* idempotently, retrying up to *attempts* times.

    A non-empty existing *dst* is treated as already cached and left alone.
    The payload is fetched with ``wget`` into a sibling ``<name>.part`` file
    and only moved onto *dst* once ``wget`` exits successfully. This matters
    because ``wget -O`` creates its output file even when the transfer fails.
    Writing straight to *dst* would leave an empty or truncated file behind,
    and the existence check would then skip it on every later run.

    Args:
        url: Source URL passed verbatim to ``wget``.
        dst: Final destination path.
        attempts: Maximum number of ``wget`` invocations.

    Raises:
        RuntimeError: if every attempt fails or ``wget`` cannot be executed.
    """
    # An empty file is most likely debris from an earlier failed download
    # (older versions of this helper wrote to dst directly), so refetch it.
    if dst.exists() and dst.stat().st_size > 0:
        return
    tmp = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        try:
            rc = subprocess.call(["wget", "-q", "-O", str(tmp), url])
        except OSError as err:  # e.g. wget is not installed; retrying won't help
            tmp.unlink(missing_ok=True)
            raise RuntimeError(f"download failed → {url}") from err
        if rc == 0:
            tmp.replace(dst)  # atomic on POSIX: dst is either absent or complete
            return
    tmp.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
print("— Starting asset download check —")
# Local targets for the model checkpoint and adapter weights.
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
# Fetch each asset in a fixed order at import time; dl() skips files that
# are already cached.
for _asset_path, _asset_url in (
    (BASE_CKPT, "https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16"),
    (IP_BIN_FILE, "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin"),
    (LORA_FILE, "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors"),
):
    dl(_asset_url, _asset_path)
print("— Asset download check finished —")
##############################################################################
# 1. Gradio UI
##############################################################################


def generate_ui(*args, **kwargs):
    """Placeholder inference callback (the real implementation is omitted).

    Every argument supplied by the click handler is ignored; the function
    always returns a new 512x768 mid-grey RGB image.

    NOTE(review): ``spaces`` is imported at the top of the file but nothing
    visible here is decorated with ``@spaces.GPU``; the real inference
    function presumably needs that on ZeroGPU hardware — confirm.
    """
    return Image.new("RGB", (512, 768), (127, 127, 127))


with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    with gr.Row(equal_height=True):
        # Left column: inputs and generation settings.
        with gr.Column():
            face_in = gr.Image(type="pil", label="顔画像 (必須)")
            subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
            add_in = gr.Textbox(
                label="追加プロンプト",
                placeholder="例: masterpiece, best quality, ...",
            )
            addneg_in = gr.Textbox(
                label="ネガティブ",
                value="(worst quality:2), lowres, bad hand, ...",
            )
            with gr.Row():
                ip_sld = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05,
                                   label="IP Adapter Weight")
                cfg_sld = gr.Slider(minimum=1, maximum=15, value=6, step=0.5, label="CFG")
                step_sld = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
                w_sld = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="幅")
                h_sld = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="高さ")
                up_ck = gr.Checkbox(label="アップスケール", value=True)
                up_fac = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        # Right column: generated result.
        with gr.Column():
            out_img = gr.Image(label="結果")

    # Gradio only accepts event listeners registered inside the Blocks
    # context, so the wiring lives here rather than after the `with` block.
    # NOTE: the input order (CFG before IP weight) differs from the on-screen
    # slider order; keep it in sync with the real inference function.
    btn.click(
        generate_ui,
        inputs=[face_in, subj_in, add_in, addneg_in, cfg_sld, ip_sld, step_sld,
                w_sld, h_sld, up_ck, up_fac],
        outputs=[out_img],
    )

demo.queue()
##############################################################################
# 2. FastAPI wrapper (REST API)
##############################################################################
app = FastAPI()


# Without this decorator the handler is never registered and the REST
# endpoint does not exist.
@app.post("/predict")
async def predict(face: UploadFile = File(...)):
    """Return the uploaded image re-encoded as a base64 PNG.

    Placeholder endpoint: it decodes the multipart upload with Pillow and
    re-encodes it as PNG; no generation happens here.

    Args:
        face: Multipart file upload expected to contain an image.

    Returns:
        ``{"image_base64": <base64-encoded PNG bytes as str>}``.

    Raises:
        HTTPException: 400 if Pillow cannot identify the upload as an image;
            500 for any other failure (traceback printed to stderr).
    """
    from PIL import UnidentifiedImageError

    try:
        data = await face.read()
        # Context manager closes the decoder/file handle PIL keeps open.
        with Image.open(io.BytesIO(data)) as img:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
    except UnidentifiedImageError as e:
        # Client error: the payload is not a decodable image.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    img_b64 = base64.b64encode(buffered.getvalue()).decode()
    return {"image_base64": img_b64}


# Mount the Gradio UI at the root. Routes added to the app before this mount
# (e.g. /predict) are matched first, so the catch-all mount does not shadow them.
app = gr.mount_gradio_app(app, demo, path="/")
print("Application startup script finished. Waiting for requests.")
##############################################################################
# 3. Uvicorn launch (fall back to port 7860 when $PORT is already in use)
##############################################################################
_FALLBACK_PORT = 7860


def port_is_free(port: int, host: str = "0.0.0.0") -> bool:
    """Return True if a TCP socket could bind *host*:*port* right now.

    On POSIX the probe enables SO_REUSEADDR, as asyncio's ``create_server``
    does by default there. This way a port lingering in TIME_WAIT is still
    reported as free, matching what uvicorn itself could bind. The result is
    best-effort: another process may grab the port between probe and bind.

    Raises:
        OSError: for bind failures other than "address in use"
            (e.g. permission denied on a privileged port).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        if os.name == "posix":
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((host, port))
        except OSError as err:
            # errno.EADDRINUSE is portable; the literal 98 is Linux-only.
            if err.errno == errno.EADDRINUSE:
                return False
            raise
    return True


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", str(_FALLBACK_PORT)))
    # uvicorn.run() does not propagate OSError on a bind failure (it logs the
    # error and exits the process), so a try/except around it can never fall
    # back. Check availability up front instead.
    if port != _FALLBACK_PORT and not port_is_free(port):
        print(f"⚠️ Port {port} busy → falling back to {_FALLBACK_PORT}")
        port = _FALLBACK_PORT
    uvicorn.run(app, host="0.0.0.0", port=port, workers=1, log_level="info")