i0switch commited on
Commit
dfbd440
·
verified ·
1 Parent(s): 81e8c25

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +218 -0
  2. requirements.txt +35 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — ZeroGPU対応版
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ import os
8
+ import subprocess
9
+ import traceback
10
+ import base64
11
+ import io
12
+ from pathlib import Path
13
+
14
+ # FastAPI関連(ハイブリッド構成のため維持)
15
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
16
+
17
+ # グローバル変数としてパイプラインを定義(初期値はNone)
18
+ pipe = None
19
+ face_app = None
20
+ upsampler = None
21
+ UPSCALE_OK = False
22
+
23
+ # 0. Cache dir & helpers (起動時に実行)
24
+ PERSIST_BASE = Path("/data")
25
+ CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
26
+ else Path.home() / ".cache" / "instantid_cache")
27
+ MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR = CACHE_ROOT/"models", CACHE_ROOT/"models"/"Lora", CACHE_ROOT/"embeddings", CACHE_ROOT/"realesrgan"
28
+
29
+ for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
30
+ p.mkdir(parents=True, exist_ok=True)
31
+
32
+ def dl(url: str, dst: Path, attempts: int = 2):
33
+ if dst.exists(): return
34
+ for i in range(1, attempts + 1):
35
+ print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
36
+ if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0: return
37
+ raise RuntimeError(f"download failed → {url}")
38
+
39
+ # 1. Asset download (起動時に実行)
40
+ print("— Starting asset download check —")
41
+ BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
42
+ dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
43
+ IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
44
+ dl("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE)
45
+ LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
46
+ dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
47
+ print("— Asset download check finished —")
48
+
49
+
50
+ # 2. パイプライン初期化関数 (GPU確保後に呼び出される)
51
+ def initialize_pipelines():
52
+ global pipe, face_app, upsampler, UPSCALE_OK
53
+
54
+ # torch/diffusers/onnxruntimeなどのインポートを関数内に移動
55
+ from diffusers import StableDiffusionPipeline, ControlNetModel, DPMSolverMultistepScheduler, AutoencoderKL
56
+ from insightface.app import FaceAnalysis
57
+
58
+ print("--- Initializing Pipelines (GPU is now available) ---")
59
+
60
+ device = torch.device("cuda") # ZeroGPUではGPUが保証されている
61
+ dtype = torch.float16
62
+
63
+ # FaceAnalysis
64
+ if face_app is None:
65
+ print("Initializing FaceAnalysis...")
66
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
67
+ face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
68
+ face_app.prepare(ctx_id=0, det_size=(640, 640))
69
+ print("FaceAnalysis initialized.")
70
+
71
+ # Main Pipeline
72
+ if pipe is None:
73
+ print("Loading ControlNet...")
74
+ controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
75
+
76
+ print("Loading StableDiffusionPipeline...")
77
+ pipe = StableDiffusionPipeline.from_single_file(BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2)
78
+
79
+ print("Moving pipeline to GPU...")
80
+ pipe.to(device) # .to(device)をここで呼ぶ
81
+
82
+ print("Loading VAE...")
83
+ pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
84
+ pipe.controlnet = controlnet
85
+
86
+ print("Configuring Scheduler...")
87
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
88
+
89
+ print("Loading IP-Adapter and LoRA...")
90
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
91
+ pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
92
+
93
+ pipe.set_ip_adapter_scale(0.65)
94
+ print("Main pipeline initialized.")
95
+
96
+ # Upscaler
97
+ if upsampler is None and not UPSCALE_OK: # 一度失敗したら再試行しない
98
+ print("Checking for Upscaler...")
99
+ try:
100
+ from basicsr.archs.rrdb_arch import RRDBNet
101
+ from realesrgan import RealESRGAN
102
+ rrdb = RRDBNet(3, 3, 64, 23, 32, scale=8)
103
+ upsampler = RealESRGAN(device, rrdb, scale=8)
104
+ upsampler.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
105
+ UPSCALE_OK = True
106
+ print("Upscaler initialized successfully.")
107
+ except Exception as e:
108
+ UPSCALE_OK = False # 失敗を記録
109
+ print(f"Real-ESRGAN disabled → {e}")
110
+
111
+ print("--- All pipelines ready ---")
112
+
113
+
114
+ # 4. Core generation logic
115
+ BASE_PROMPT = ("(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n""photo of {subject},\n""cinematic lighting, golden hour, rim light, shallow depth of field,\n""textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,\n""<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),\n""(aesthetic:1.1), (cinematic:0.8)")
116
+ NEG_PROMPT = ("ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream, ""(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2), ""painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d, ""monochrome, grayscale, text, logo, watermark, signature, username, ""(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs, ""missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans")
117
+
118
+
119
+ # ZeroGPUで実行される本体。durationを60秒に設定。
120
+ @spaces.GPU(duration=60)
121
+ def _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
122
+ # 初回呼び出し時にパイプラインを初期化
123
+ initialize_pipelines()
124
+
125
+ progress(0, desc="Generating image...")
126
+ prompt = BASE_PROMPT.format(subject=(subject.strip() or "a beautiful 20yo woman"))
127
+ if add_prompt: prompt += ", " + add_prompt
128
+ neg = NEG_PROMPT + (", " + add_neg if add_neg else "")
129
+ pipe.set_ip_adapter_scale(ip_scale)
130
+
131
+ result = pipe(prompt=prompt, negative_prompt=neg, ip_adapter_image=face_img, image=face_img, controlnet_conditioning_scale=0.9, num_inference_steps=int(steps) + 5, guidance_scale=cfg, width=int(w), height=int(h)).images[0]
132
+
133
+ if upscale and UPSCALE_OK:
134
+ progress(0.8, desc="Upscaling...")
135
+ up, _ = upsampler.enhance(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor)
136
+ result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
137
+
138
+ return result
139
+
140
+ # GradioのUIから呼び出されるラッパー関数
141
+ def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
142
+ if face_np is None: raise gr.Error("顔画像をアップロードしてください。")
143
+ # NumPy配列をPillow画像に変換
144
+ face_img = Image.fromarray(face_np)
145
+ return _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
146
+
147
+
148
+ # 5. Gradio UI Definition
149
+ with gr.Blocks() as demo:
150
+ gr.Markdown("# InstantID – Beautiful Realistic Asians v7 (ZeroGPU)")
151
+ with gr.Row():
152
+ with gr.Column():
153
+ face_in = gr.Image(label="顔写真",type="numpy")
154
+ subj_in = gr.Textbox(label="被写体説明",placeholder="e.g. woman in black suit, smiling")
155
+ add_in = gr.Textbox(label="追加プロンプト")
156
+ addneg_in = gr.Textbox(label="追加ネガティブ")
157
+ with gr.Accordion("詳細設定", open=False):
158
+ ip_sld = gr.Slider(0,1.5,0.65,step=0.05,label="IP‑Adapter scale")
159
+ cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
160
+ step_sld = gr.Slider(10,50,20,step=1,label="Steps")
161
+ w_sld = gr.Slider(512,1024,512,step=64,label="幅")
162
+ h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
163
+ up_ck = gr.Checkbox(label="アップスケール",value=True)
164
+ up_fac = gr.Slider(1,8,2,step=1,label="倍率")
165
+ btn = gr.Button("生成",variant="primary")
166
+ with gr.Column():
167
+ out_img = gr.Image(label="結果")
168
+
169
+ # .queue() はGradioの通常機能として必要
170
+ demo.queue()
171
+
172
+ btn.click(
173
+ fn=generate_ui,
174
+ inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
175
+ outputs=out_img
176
+ )
177
+
178
+ # 6. FastAPI Mounting
179
+ app = FastAPI()
180
+
181
+ # FastAPIのエンドポイントを定義。こちらも内部で_generate_coreを呼ぶ
182
+ @app.post("/api/predict")
183
+ async def predict_endpoint(
184
+ face_image: UploadFile = File(...),
185
+ subject: str = Form("a woman"),
186
+ add_prompt: str = Form(""),
187
+ add_neg: str = Form(""),
188
+ cfg: float = Form(6.0),
189
+ ip_scale: float = Form(0.65),
190
+ steps: int = Form(20),
191
+ w: int = Form(512),
192
+ h: int = Form(768),
193
+ upscale: bool = Form(True),
194
+ up_factor: float = Form(2.0)
195
+ ):
196
+ try:
197
+ contents = await face_image.read()
198
+ pil_image = Image.open(io.BytesIO(contents))
199
+
200
+ # FastAPI経由の呼び出しも同じコア関数を利用
201
+ result_pil_image = _generate_core(
202
+ pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
203
+ steps, w, h, upscale, up_factor
204
+ )
205
+
206
+ buffered = io.BytesIO()
207
+ result_pil_image.save(buffered, format="PNG")
208
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
209
+
210
+ return {"image_base64": img_str}
211
+ except Exception as e:
212
+ traceback.print_exc()
213
+ raise HTTPException(status_code=500, detail=str(e))
214
+
215
+ # GradioアプリをFastAPIアプリにマウント
216
+ app = gr.mount_gradio_app(app, demo, path="/")
217
+
218
+ print("Application startup script finished. Waiting for requests.")
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+
3
+ # Gradio (メインのアプリとして必須)
4
+ gradio[oauth,mcp]==5.34.1
5
+ spaces==0.37.0
6
+
7
+ # モデルおよび画像処理関連
8
+ torch==2.1.2
9
+ torchvision
10
+ diffusers
11
+ numpy==1.26.4
12
+ opencv-python-headless
13
+ Pillow
14
+ insightface
15
+ basicsr
16
+ realesrgan
17
+ onnxruntime-gpu
18
+ transformers
19
+ accelerate
20
+ peft==0.11.1
21
+
22
+ # FastAPIをGradioに間借りさせるために必要
23
+ fastapi
24
+ uvicorn[standard]
25
+ python-multipart
26
+
27
+
28
+
29
+
30
+ title: InstantID Hybrid (UI + API)
31
+ emoji: 🚀
32
+ colorFrom: green
33
+ colorTo: blue
34
+ sdk: gradio
35
+ app_file: app.py