# Hugging Face Space — status: Running on ZeroGPU
# File size: 5,114 bytes
# Commits shown in the page's blame gutter: f2192e3, 92aaea0, 942bdcb (lines 1–113)
# app.py — ZeroGPU対応 + ポート自動フォールバック
import base64
import errno
import io
import os
import socket
import subprocess
import traceback
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
##############################################################################
# 0. Settings and helpers
##############################################################################
# Use the persistent /data volume when it exists and is writable; otherwise
# fall back to a cache directory under the current user's home.
PERSIST_BASE = Path("/data")
_persist_usable = PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
if _persist_usable:
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"  # base checkpoints
LORA_DIR = CACHE_ROOT / "lora"      # adapter / LoRA weights

# Ensure both cache subdirectories exist before anything downloads into them.
for _subdir in (MODELS_DIR, LORA_DIR):
    _subdir.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2) -> None:
    """Idempotently download *url* to *dst* with ``wget``, retrying on failure.

    The download is skipped when *dst* already exists and is non-empty.
    Data is first written to a sibling ``<name>.part`` file and only renamed
    onto *dst* after ``wget`` exits with status 0. An interrupted or failed
    attempt therefore never leaves a truncated *dst* behind that later runs
    would mistake for a completed download.

    Args:
        url: Source URL passed to ``wget``.
        dst: Final destination path.
        attempts: Maximum number of ``wget`` invocations.

    Raises:
        RuntimeError: if every attempt fails.
    """
    # A 0-byte file is what `wget -O` leaves after an early failure; treat it as missing.
    if dst.exists() and dst.stat().st_size > 0:
        return
    tmp = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        if subprocess.call(["wget", "-q", "-O", str(tmp), url]) == 0:
            tmp.replace(dst)  # atomic rename on the same filesystem
            return
        # `wget -O` creates/truncates the target even on failure; discard the fragment.
        tmp.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
print("— Starting asset download check —")
# Target paths for the three assets this app fetches.
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"

# (source URL, destination) pairs, downloaded in order; dl() skips files already present.
for _src, _target in (
    # Base checkpoint from Civitai (pruned, fp16 per the query string).
    (
        "https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16",
        BASE_CKPT,
    ),
    # IP-Adapter face weights from the h94 Hugging Face repositories.
    (
        "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin",
        IP_BIN_FILE,
    ),
    (
        "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors",
        LORA_FILE,
    ),
):
    dl(_src, _target)
print("— Asset download check finished —")
##############################################################################
# 1. Gradio UI
##############################################################################
# Placeholder inference: the real generation pipeline is not implemented here.
def generate_ui(*args, **kwargs):
    """Ignore every input and return a flat mid-grey 512x768 RGB image."""
    grey = (127, 127, 127)
    return Image.new("RGB", (512, 768), grey)


with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    with gr.Row(equal_height=True):
        # Left column: reference face, prompts and generation controls.
        with gr.Column():
            face_in = gr.Image(type="pil", label="顔画像 (必須)")
            subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
            add_in = gr.Textbox(
                label="追加プロンプト",
                placeholder="例: masterpiece, best quality, ...",
            )
            addneg_in = gr.Textbox(
                label="ネガティブ",
                value="(worst quality:2), lowres, bad hand, ...",
            )
            with gr.Row():
                ip_sld = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="IP Adapter Weight"
                )
                cfg_sld = gr.Slider(minimum=1, maximum=15, value=6, step=0.5, label="CFG")
                step_sld = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
            w_sld = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="幅")
            h_sld = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="高さ")
            up_ck = gr.Checkbox(value=True, label="アップスケール")
            up_fac = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        # Right column: generated result.
        with gr.Column():
            out_img = gr.Image(label="結果")
    demo.queue()

    # generate_ui receives these component values positionally, in this order.
    # NOTE(review): cfg_sld precedes ip_sld here, unlike the on-screen order —
    # confirm against the real generation function's signature.
    _ui_inputs = [
        face_in, subj_in, add_in, addneg_in,
        cfg_sld, ip_sld, step_sld,
        w_sld, h_sld, up_ck, up_fac,
    ]
    btn.click(generate_ui, inputs=_ui_inputs, outputs=[out_img])
##############################################################################
# 2. FastAPI wrapper (REST API)
##############################################################################
app = FastAPI()


@app.post("/api/predict")
async def predict(face: UploadFile = File(...)):
    """Return the uploaded face image re-encoded as a base64 PNG.

    No inference is run: the endpoint decodes the upload and echoes it back.

    Args:
        face: Multipart file field named ``face``.

    Returns:
        ``{"image_base64": <base64 of the PNG bytes, as str>}``.

    Raises:
        HTTPException: 400 if the upload cannot be identified as an image,
            500 for any other failure while decoding or re-encoding.
    """
    # NOTE(review): PIL decode/encode here runs on the event-loop thread; large
    # uploads will block other requests while they are processed.
    try:
        # `with` closes the decoded image (and its file handle) deterministically.
        with Image.open(face.file) as img:
            # PNG cannot store CMYK (e.g. print-oriented JPEGs); convert to RGB first.
            out = img.convert("RGB") if img.mode == "CMYK" else img
            buffered = io.BytesIO()
            out.save(buffered, format="PNG")
    except UnidentifiedImageError as e:
        # The client sent something that is not a decodable image: client error.
        raise HTTPException(status_code=400, detail=f"invalid image: {e}") from e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    img_b64 = base64.b64encode(buffered.getvalue()).decode()
    return {"image_base64": img_b64}
# Serve the Gradio UI from the root of the FastAPI app. The /api/predict
# route was registered on `app` before this mount.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
print("Application startup script finished. Waiting for requests.")
##############################################################################
# 3. Uvicorn 起動(ポート重複時フォールバック)
##############################################################################
if __name__ == "__main__":
import uvicorn
port_env = int(os.getenv("PORT", "7860"))
try:
uvicorn.run(app, host="0.0.0.0", port=port_env, workers=1, log_level="info")
except OSError as e:
if e.errno == 98 and port_env != 7860:
print(f"⚠️ Port {port_env} busy → falling back to 7860")
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
else:
raise
|