# NOTE: Hugging Face file-viewer header (author: i0switch, commit f2192e3
# "Update app.py", 5.11 kB) — kept as a comment so the file parses as Python.
# app.py — ZeroGPU対応 + ポート自動フォールバック
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import os
import subprocess
import traceback
import base64
import io
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
##############################################################################
# 0. Configuration and helpers
##############################################################################
# Prefer the persistent /data volume (available on HF Spaces with persistent
# storage) when it exists and is writable; otherwise fall back to a per-user
# cache directory under $HOME.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = CACHE_ROOT / "lora"

# Idempotent: exist_ok avoids racing a previous run's directories.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
LORA_DIR.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2):
    """Idempotent download: skip when *dst* already exists, retry on failure.

    Args:
        url: Source URL passed straight to ``wget``.
        dst: Target path; treated as already-downloaded if it exists.
        attempts: Maximum number of download attempts.

    Raises:
        RuntimeError: when every attempt fails.

    Note: ``wget -O`` truncates/creates the target before writing, so a failed
    attempt can leave a partial file behind. Without cleanup, the next run
    would see ``dst.exists()`` and silently keep the corrupt file — therefore
    the partial output is removed after each failed attempt.
    """
    if dst.exists():
        return
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0:
            return
        # Failed attempt: drop any partial file so retries/reruns start clean.
        dst.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
print("— Starting asset download check —")

# Model assets fetched once at startup; dl() is a no-op for cached files.
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"

for _asset_url, _asset_dst in (
    ("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT),
    ("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE),
    ("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE),
):
    dl(_asset_url, _asset_dst)

print("— Asset download check finished —")
##############################################################################
# 1. Gradio UI
##############################################################################
# NOTE(review): original indentation was lost in extraction; nesting below is
# reconstructed — widgets inside Row/Column contexts, event wiring inside the
# Blocks context. Confirm against the deployed Space.
with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    with gr.Row(equal_height=True):
        with gr.Column():
            # Inputs: face photo is required; text fields steer the prompt.
            face_in = gr.Image(type="pil", label="顔画像 (必須)")
            subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
            add_in = gr.Textbox(label="追加プロンプト", placeholder="例: masterpiece, best quality, ...")
            addneg_in = gr.Textbox(label="ネガティブ", value="(worst quality:2), lowres, bad hand, ...")
            with gr.Row():
                ip_sld = gr.Slider(0.0, 1.0, 0.6, step=0.05, label="IP Adapter Weight")
                cfg_sld = gr.Slider(1, 15, 6, step=0.5, label="CFG")
            step_sld = gr.Slider(10, 50, 20, step=1, label="Steps")
            w_sld = gr.Slider(512, 1024, 512, step=64, label="幅")
            h_sld = gr.Slider(512, 1024, 768, step=64, label="高さ")
            up_ck = gr.Checkbox(label="アップスケール", value=True)
            up_fac = gr.Slider(1, 8, 2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        with gr.Column():
            out_img = gr.Image(label="結果")

    # Enable request queueing so long generations don't block the server.
    demo.queue()

    # Dummy inference (real implementation omitted).
    def generate_ui(*args, **kwargs):
        """Placeholder generator: returns a flat grey 512x768 image."""
        return Image.new("RGB", (512, 768), (127, 127, 127))

    btn.click(generate_ui,
              inputs=[face_in, subj_in, add_in, addneg_in, cfg_sld, ip_sld, step_sld,
                      w_sld, h_sld, up_ck, up_fac],
              outputs=[out_img])
##############################################################################
# 2. FastAPI wrapper (REST API)
##############################################################################
# The FastAPI instance hosts the REST endpoint; the Gradio UI is mounted
# onto it further below.
app = FastAPI()
@app.post("/api/predict")
async def predict(face: UploadFile = File(...)):
    """Decode the uploaded face image and return it as base64-encoded PNG.

    Args:
        face: Multipart-uploaded image file.

    Returns:
        ``{"image_base64": <str>}`` with the PNG re-encoding of the upload.

    Raises:
        HTTPException: 500 with the exception message for any failure
        (traceback is printed server-side for debugging).
    """
    try:
        uploaded = Image.open(face.file)
        png_buf = io.BytesIO()
        uploaded.save(png_buf, format="PNG")
        encoded = base64.b64encode(png_buf.getvalue()).decode()
        return {"image_base64": encoded}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
# Mount the Gradio UI onto the FastAPI app at the root path, so one server
# serves both the web UI (/) and the REST endpoint (/api/predict).
app = gr.mount_gradio_app(app, demo, path="/")
print("Application startup script finished. Waiting for requests.")
##############################################################################
# 3. Uvicorn startup (fall back to 7860 when the chosen port is taken)
##############################################################################
if __name__ == "__main__":
    import errno
    import uvicorn

    # $PORT is set by the hosting platform; default to Gradio's usual 7860.
    port_env = int(os.getenv("PORT", "7860"))
    try:
        uvicorn.run(app, host="0.0.0.0", port=port_env, workers=1, log_level="info")
    except OSError as e:
        # "Address already in use" is errno 98 on Linux but 48 on macOS —
        # compare against errno.EADDRINUSE instead of a magic number so the
        # fallback works on every platform. Only fall back when we weren't
        # already on 7860 (otherwise retrying is pointless: re-raise).
        if e.errno == errno.EADDRINUSE and port_env != 7860:
            print(f"⚠️ Port {port_env} busy → falling back to 7860")
            uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
        else:
            raise