File size: 5,114 Bytes
f2192e3
92aaea0
 
 
 
 
 
 
 
 
 
 
 
 
942bdcb
 
 
92aaea0
942bdcb
 
 
 
 
 
 
92aaea0
 
f2192e3
 
 
92aaea0
 
f2192e3
 
92aaea0
 
 
 
 
 
 
 
 
 
 
942bdcb
f2192e3
942bdcb
 
 
92aaea0
942bdcb
 
 
 
 
 
 
92aaea0
942bdcb
 
 
 
92aaea0
 
 
 
942bdcb
f2192e3
 
 
 
 
 
 
92aaea0
942bdcb
f2192e3
942bdcb
92aaea0
 
 
f2192e3
92aaea0
f2192e3
92aaea0
f2192e3
 
 
92aaea0
 
 
 
f2192e3
92aaea0
 
 
 
f2192e3
 
 
92aaea0
f2192e3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# app.py — ZeroGPU対応 + ポート自動フォールバック
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import os
import subprocess
import traceback
import base64
import io
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException

##############################################################################
# 0. Settings and helpers
##############################################################################
PERSIST_BASE = Path("/data")

# Cache location: the persistent /data volume when it exists and is writable,
# otherwise a per-user directory under ~/.cache.
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"   # base checkpoints
LORA_DIR = CACHE_ROOT / "lora"       # IP-Adapter / LoRA weights

# Create both cache directories up front; safe to repeat on every start.
for d in [MODELS_DIR, LORA_DIR]:
    d.mkdir(exist_ok=True, parents=True)

def dl(url: str, dst: Path, attempts: int = 2):
    """Idempotently download ``url`` to ``dst`` with ``wget``, retrying on failure.

    Each attempt writes to a sibling ``<name>.part`` file. That file is renamed
    onto ``dst`` only after ``wget`` exits successfully. A failed or interrupted
    attempt therefore never leaves a partial file at ``dst`` that a later run
    would mistake for a complete cached copy.

    Args:
        url: Source URL, passed verbatim to ``wget``.
        dst: Final destination path. If it already exists and is non-empty,
            nothing is downloaded.
        attempts: Maximum number of ``wget`` invocations before giving up.

    Raises:
        RuntimeError: If no attempt succeeds (every ``wget`` exit status was
            non-zero, or ``attempts`` < 1).
    """
    # An empty file at dst is most likely residue of a failed `wget -O` run
    # (wget creates/truncates its output file before fetching), so re-fetch it.
    if dst.exists() and dst.stat().st_size > 0:
        return
    part = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        if subprocess.call(["wget", "-q", "-O", str(part), url]) == 0:
            # Same directory => same filesystem, so this rename is atomic.
            part.replace(dst)
            return
        # Drop whatever wget wrote so the next attempt (or run) starts clean.
        part.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")

print("— Starting asset download check —")

# Local destinations of every pretrained asset this app needs.
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"

# (source URL, destination) pairs, fetched in this order; dl() returns early
# for files that are already present.
for _asset_url, _asset_dst in (
    ("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT),
    ("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE),
    ("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE),
):
    dl(_asset_url, _asset_dst)
del _asset_url, _asset_dst

print("— Asset download check finished —")

##############################################################################
# 1. Gradio UI
##############################################################################
# Build the Gradio Blocks app: inputs and controls in the left column, the
# generated image in the right column.
with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    with gr.Row(equal_height=True):
        # Left column: face image, prompt text fields, generation controls.
        with gr.Column():
            # Face image (required), passed to the handler as a PIL image.
            face_in   = gr.Image(type="pil", label="顔画像 (必須)")
            # Subject description / extra prompt / negative prompt.
            subj_in   = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
            add_in    = gr.Textbox(label="追加プロンプト", placeholder="例: masterpiece, best quality, ...")
            addneg_in = gr.Textbox(label="ネガティブ", value="(worst quality:2), lowres, bad hand, ...")
            # Numeric controls; Slider positional args are (minimum, maximum, default value).
            with gr.Row():
                ip_sld   = gr.Slider(0.0,1.0,0.6,step=0.05,label="IP Adapter Weight")
                cfg_sld  = gr.Slider(1,15,6,step=0.5,label="CFG")
                step_sld = gr.Slider(10,50,20,step=1,label="Steps")
                # Output width / height in pixels, multiples of 64.
                w_sld    = gr.Slider(512,1024,512,step=64,label="幅")
                h_sld    = gr.Slider(512,1024,768,step=64,label="高さ")
                # Upscale toggle and integer upscale factor.
                up_ck    = gr.Checkbox(label="アップスケール",value=True)
                up_fac   = gr.Slider(1,8,2,step=1,label="倍率")
            # "Generate" button.
            btn = gr.Button("生成",variant="primary")
        # Right column: result image.
        with gr.Column():
            out_img = gr.Image(label="結果")
    # Enable Gradio's request queue for this Blocks app.
    demo.queue()

    # Dummy inference (the real implementation is omitted).
    def generate_ui(*args, **kwargs):
        """Placeholder handler: ignores every input and returns a solid gray 512x768 RGB image."""
        return Image.new("RGB", (512,768), (127,127,127))
    # NOTE(review): `inputs` passes CFG before IP weight, which differs from the
    # on-screen order (IP weight first). A real generate_ui must accept its
    # positional arguments in the order listed here.
    btn.click(generate_ui,
              inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,
                      w_sld,h_sld,up_ck,up_fac],
              outputs=[out_img])

##############################################################################
# 2. FastAPI wrapper (REST API)
##############################################################################
app = FastAPI()

# Modes Pillow's PNG encoder writes directly. Anything else (e.g. CMYK or YCbCr
# from JPEG uploads) is converted before saving, instead of failing with a 500.
_PNG_SAVE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})


def _image_to_png_b64(fileobj) -> str:
    """Decode an uploaded image and return it re-encoded as base64 PNG text.

    Args:
        fileobj: Binary file-like object holding the uploaded image bytes.

    Returns:
        ASCII base64 string of the PNG-encoded image.

    Raises:
        HTTPException: 400 if Pillow cannot identify or fully decode the data.
    """
    try:
        src = Image.open(fileobj)
        # Image.open is lazy; force the full decode here so truncated or
        # corrupt uploads surface as a client error rather than a later 500.
        src.load()
    except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(status_code=400, detail=f"invalid image: {e}") from e
    with src:
        if src.mode in _PNG_SAVE_MODES:
            img = src
        else:
            img = src.convert("RGBA" if "A" in src.getbands() else "RGB")
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/api/predict")
async def predict(face: UploadFile = File(...)):
    """Return the uploaded face image re-encoded as base64 PNG.

    NOTE(review): placeholder endpoint. It runs no model and only echoes the
    upload back as PNG.

    Returns:
        ``{"image_base64": <str>}``.

    Raises:
        HTTPException: 400 for undecodable uploads, 500 for any other failure.
    """
    # Local import keeps this block independent of extra file-level imports.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Pillow decode/encode is blocking CPU work; run it off the event loop
        # so one large upload does not stall every other request.
        img_b64 = await run_in_threadpool(_image_to_png_b64, face.file)
    except HTTPException:
        raise  # already carries the intended client-facing status code
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_base64": img_b64}

# Mount the Gradio UI onto the FastAPI app at the root path "/", and rebind
# `app` to the result so the name refers to the combined application.
# NOTE(review): /api/predict is registered before this mount; presumably that
# keeps it reachable alongside the root mount — confirm against the Gradio version in use.
app = gr.mount_gradio_app(app, demo, path="/")

print("Application startup script finished. Waiting for requests.")

##############################################################################
# 3. Uvicorn launch (fall back to port 7860 when the requested port is busy)
##############################################################################
def _port_is_free(host: str, port: int) -> bool:
    """Return whether a TCP socket could bind ``host:port`` right now.

    Args:
        host: IPv4 address to probe (e.g. ``"0.0.0.0"``).
        port: TCP port number to probe.

    Returns:
        False if the address is already in use, True otherwise.

    Raises:
        OSError: For bind failures other than "address already in use"
            (e.g. permission denied on a privileged port).
    """
    import errno
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # asyncio's create_server (used by uvicorn) sets SO_REUSEADDR on Unix.
        # Mirror it so a port lingering in TIME_WAIT is not reported as busy.
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind((host, port))
        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                return False
            raise
    return True


if __name__ == "__main__":
    import uvicorn

    FALLBACK_PORT = 7860
    port_env = int(os.getenv("PORT", "7860"))
    # uvicorn catches bind errors itself (logs them and calls sys.exit(1)), so
    # wrapping uvicorn.run in `except OSError` never triggers. Probe first
    # instead. A small race remains if another process grabs the port between
    # this probe and uvicorn's own bind.
    if port_env != FALLBACK_PORT and not _port_is_free("0.0.0.0", port_env):
        print(f"⚠️ Port {port_env} busy → falling back to 7860")
        port_env = FALLBACK_PORT
    uvicorn.run(app, host="0.0.0.0", port=port_env, workers=1, log_level="info")