i0switch committed on
Commit
fad8f4c
·
verified ·
1 Parent(s): f2192e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -54
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — ZeroGPU対応 + ポート自動フォールバック
2
  import gradio as gr
3
  import spaces
4
  import torch
@@ -10,30 +10,33 @@ import traceback
10
  import base64
11
  import io
12
  from pathlib import Path
 
 
13
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
14
 
15
- ##############################################################################
16
- # 0. 設定とヘルパー
17
- ##############################################################################
 
 
 
 
18
  PERSIST_BASE = Path("/data")
19
- CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists()
20
- and os.access(PERSIST_BASE, os.W_OK)
21
- else Path.home() / ".cache" / "instantid_cache")
22
- MODELS_DIR = CACHE_ROOT / "models"
23
- LORA_DIR = CACHE_ROOT / "lora"
24
- for d in (MODELS_DIR, LORA_DIR):
25
- d.mkdir(parents=True, exist_ok=True)
26
 
27
  def dl(url: str, dst: Path, attempts: int = 2):
28
- """冪等ダウンロード(既にあればスキップ、リトライ付き)"""
29
- if dst.exists():
30
- return
31
  for i in range(1, attempts + 1):
32
  print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
33
- if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0:
34
- return
35
  raise RuntimeError(f"download failed → {url}")
36
 
 
37
  print("— Starting asset download check —")
38
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
39
  dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
@@ -43,62 +46,176 @@ LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
43
  dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
44
  print("— Asset download check finished —")
45
 
46
- ##############################################################################
47
- # 1. Gradio UI
48
- ##############################################################################
49
- with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
50
- with gr.Row(equal_height=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  with gr.Column():
52
- face_in = gr.Image(type="pil", label="顔画像 (必須)")
53
- subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
54
- add_in = gr.Textbox(label="追加プロンプト", placeholder="例: masterpiece, best quality, ...")
55
- addneg_in = gr.Textbox(label="ネガティブ", value="(worst quality:2), lowres, bad hand, ...")
56
- with gr.Row():
57
- ip_sld = gr.Slider(0.0,1.0,0.6,step=0.05,label="IP Adapter Weight")
58
- cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
59
  step_sld = gr.Slider(10,50,20,step=1,label="Steps")
60
- w_sld = gr.Slider(512,1024,512,step=64,label="幅")
61
- h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
62
- up_ck = gr.Checkbox(label="アップスケール",value=True)
63
- up_fac = gr.Slider(1,8,2,step=1,label="倍率")
64
  btn = gr.Button("生成",variant="primary")
65
  with gr.Column():
66
  out_img = gr.Image(label="結果")
 
 
67
  demo.queue()
 
 
 
 
 
 
68
 
69
- # ダミー推論(実装は省略)
70
- def generate_ui(*args, **kwargs):
71
- return Image.new("RGB", (512,768), (127,127,127))
72
- btn.click(generate_ui,
73
- inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,
74
- w_sld,h_sld,up_ck,up_fac],
75
- outputs=[out_img])
76
-
77
- ##############################################################################
78
- # 2. FastAPI ラッパー(REST API)
79
- ##############################################################################
80
  app = FastAPI()
81
 
 
82
  @app.post("/api/predict")
83
- async def predict(face: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
 
 
 
84
  try:
85
- img = Image.open(face.file)
 
 
 
 
 
 
 
 
86
  buffered = io.BytesIO()
87
- img.save(buffered, format="PNG")
88
- img_b64 = base64.b64encode(buffered.getvalue()).decode()
89
- return {"image_base64": img_b64}
 
90
  except Exception as e:
91
  traceback.print_exc()
92
  raise HTTPException(status_code=500, detail=str(e))
93
 
94
- # GradioFastAPI にマウント
95
  app = gr.mount_gradio_app(app, demo, path="/")
96
 
97
  print("Application startup script finished. Waiting for requests.")
98
-
99
- ##############################################################################
100
- # 3. Uvicorn 起動(ポート重複時フォールバック)
101
- ##############################################################################
102
  if __name__ == "__main__":
103
  import uvicorn
104
  port_env = int(os.getenv("PORT", "7860"))
@@ -109,4 +226,4 @@ if __name__ == "__main__":
109
  print(f"⚠️ Port {port_env} busy → falling back to 7860")
110
  uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
111
  else:
112
- raise
 
1
+ # app.py — ZeroGPU対応版
2
  import gradio as gr
3
  import spaces
4
  import torch
 
10
  import base64
11
  import io
12
  from pathlib import Path
13
+
14
+ # FastAPI関連(ハイブリッド構成のため維持)
15
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
16
 
17
+ # グローバル変数としてパイプラインを定義(初期値はNone)
18
+ pipe = None
19
+ face_app = None
20
+ upsampler = None
21
+ UPSCALE_OK = False
22
+
23
+ # 0. Cache dir & helpers (起動時に実行)
24
  PERSIST_BASE = Path("/data")
25
+ CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
26
+ else Path.home() / ".cache" / "instantid_cache")
27
+ MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR = CACHE_ROOT/"models", CACHE_ROOT/"models"/"Lora", CACHE_ROOT/"embeddings", CACHE_ROOT/"realesrgan"
28
+
29
+ for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
30
+ p.mkdir(parents=True, exist_ok=True)
 
31
 
32
  def dl(url: str, dst: Path, attempts: int = 2):
33
+ if dst.exists(): return
 
 
34
  for i in range(1, attempts + 1):
35
  print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
36
+ if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0: return
 
37
  raise RuntimeError(f"download failed → {url}")
38
 
39
+ # 1. Asset download (起動時に実行)
40
  print("— Starting asset download check —")
41
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
42
  dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
 
46
  dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
47
  print("— Asset download check finished —")
48
 
49
+
50
+ # 2. パイプライン初期化関数 (GPU確保後に呼び出される)
51
+ def initialize_pipelines():
52
+ global pipe, face_app, upsampler, UPSCALE_OK
53
+
54
+ # torch/diffusers/onnxruntimeなどのインポートを関数内に移動
55
+ from diffusers import StableDiffusionPipeline, ControlNetModel, DPMSolverMultistepScheduler, AutoencoderKL
56
+ from insightface.app import FaceAnalysis
57
+
58
+ print("--- Initializing Pipelines (GPU is now available) ---")
59
+
60
+ device = torch.device("cuda") # ZeroGPUではGPUが保証されている
61
+ dtype = torch.float16
62
+
63
+ # FaceAnalysis
64
+ if face_app is None:
65
+ print("Initializing FaceAnalysis...")
66
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
67
+ face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
68
+ face_app.prepare(ctx_id=0, det_size=(640, 640))
69
+ print("FaceAnalysis initialized.")
70
+
71
+ # Main Pipeline
72
+ if pipe is None:
73
+ print("Loading ControlNet...")
74
+ controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
75
+
76
+ print("Loading StableDiffusionPipeline...")
77
+ pipe = StableDiffusionPipeline.from_single_file(BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2)
78
+
79
+ print("Moving pipeline to GPU...")
80
+ pipe.to(device) # .to(device)をここで呼ぶ
81
+
82
+ print("Loading VAE...")
83
+ pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
84
+ pipe.controlnet = controlnet
85
+
86
+ print("Configuring Scheduler...")
87
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
88
+
89
+ print("Loading IP-Adapter and LoRA...")
90
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
91
+ pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
92
+
93
+ pipe.set_ip_adapter_scale(0.65)
94
+ print("Main pipeline initialized.")
95
+
96
+ # Upscaler
97
+ if upsampler is None and not UPSCALE_OK: # 一度失敗したら再試行しない
98
+ print("Checking for Upscaler...")
99
+ try:
100
+ from basicsr.archs.rrdb_arch import RRDBNet
101
+ from realesrgan import RealESRGAN
102
+ rrdb = RRDBNet(3, 3, 64, 23, 32, scale=8)
103
+ upsampler = RealESRGAN(device, rrdb, scale=8)
104
+ upsampler.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
105
+ UPSCALE_OK = True
106
+ print("Upscaler initialized successfully.")
107
+ except Exception as e:
108
+ UPSCALE_OK = False # 失敗を記録
109
+ print(f"Real-ESRGAN disabled → {e}")
110
+
111
+ print("--- All pipelines ready ---")
112
+
113
+
114
+ # 4. Core generation logic
115
+ BASE_PROMPT = ("(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n""photo of {subject},\n""cinematic lighting, golden hour, rim light, shallow depth of field,\n""textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,\n""<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),\n""(aesthetic:1.1), (cinematic:0.8)")
116
+ NEG_PROMPT = ("ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream, ""(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2), ""painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d, ""monochrome, grayscale, text, logo, watermark, signature, username, ""(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs, ""missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans")
117
+
118
+
119
+ # ZeroGPUで実行される本体。durationを60秒に設定。
120
+ @spaces.GPU(duration=60)
121
+ def _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
122
+ # 初回呼び出し時にパイプラインを初期化
123
+ initialize_pipelines()
124
+
125
+ progress(0, desc="Generating image...")
126
+ prompt = BASE_PROMPT.format(subject=(subject.strip() or "a beautiful 20yo woman"))
127
+ if add_prompt: prompt += ", " + add_prompt
128
+ neg = NEG_PROMPT + (", " + add_neg if add_neg else "")
129
+ pipe.set_ip_adapter_scale(ip_scale)
130
+
131
+ result = pipe(prompt=prompt, negative_prompt=neg, ip_adapter_image=face_img, image=face_img, controlnet_conditioning_scale=0.9, num_inference_steps=int(steps) + 5, guidance_scale=cfg, width=int(w), height=int(h)).images[0]
132
+
133
+ if upscale and UPSCALE_OK:
134
+ progress(0.8, desc="Upscaling...")
135
+ up, _ = upsampler.enhance(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor)
136
+ result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
137
+
138
+ return result
139
+
140
+ # GradioのUIから呼び出されるラッパー関数
141
+ def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
142
+ if face_np is None: raise gr.Error("顔画像をアップロードしてください。")
143
+ # NumPy配列をPillow画像に変換
144
+ face_img = Image.fromarray(face_np)
145
+ return _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
146
+
147
+
148
+ # 5. Gradio UI Definition
149
+ with gr.Blocks() as demo:
150
+ gr.Markdown("# InstantID – Beautiful Realistic Asians v7 (ZeroGPU)")
151
+ with gr.Row():
152
  with gr.Column():
153
+ face_in = gr.Image(label="顔写真",type="numpy")
154
+ subj_in = gr.Textbox(label="被写体説明",placeholder="e.g. woman in black suit, smiling")
155
+ add_in = gr.Textbox(label="追加プロンプト")
156
+ addneg_in = gr.Textbox(label="追加ネガティブ")
157
+ with gr.Accordion("詳細設定", open=False):
158
+ ip_sld = gr.Slider(0,1.5,0.65,step=0.05,label="IPAdapter scale")
159
+ cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
160
  step_sld = gr.Slider(10,50,20,step=1,label="Steps")
161
+ w_sld = gr.Slider(512,1024,512,step=64,label="幅")
162
+ h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
163
+ up_ck = gr.Checkbox(label="アップスケール",value=True)
164
+ up_fac = gr.Slider(1,8,2,step=1,label="倍率")
165
  btn = gr.Button("生成",variant="primary")
166
  with gr.Column():
167
  out_img = gr.Image(label="結果")
168
+
169
+ # .queue() はGradioの通常機能として必要
170
  demo.queue()
171
+
172
+ btn.click(
173
+ fn=generate_ui,
174
+ inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
175
+ outputs=out_img
176
+ )
177
 
178
+ # 6. FastAPI Mounting
 
 
 
 
 
 
 
 
 
 
179
  app = FastAPI()
180
 
181
+ # FastAPIのエンドポイントを定義。こちらも内部で_generate_coreを呼ぶ
182
  @app.post("/api/predict")
183
+ async def predict_endpoint(
184
+ face_image: UploadFile = File(...),
185
+ subject: str = Form("a woman"),
186
+ add_prompt: str = Form(""),
187
+ add_neg: str = Form(""),
188
+ cfg: float = Form(6.0),
189
+ ip_scale: float = Form(0.65),
190
+ steps: int = Form(20),
191
+ w: int = Form(512),
192
+ h: int = Form(768),
193
+ upscale: bool = Form(True),
194
+ up_factor: float = Form(2.0)
195
+ ):
196
  try:
197
+ contents = await face_image.read()
198
+ pil_image = Image.open(io.BytesIO(contents))
199
+
200
+ # FastAPI経由の呼び出しも同じコア関数を利用
201
+ result_pil_image = _generate_core(
202
+ pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
203
+ steps, w, h, upscale, up_factor
204
+ )
205
+
206
  buffered = io.BytesIO()
207
+ result_pil_image.save(buffered, format="PNG")
208
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
209
+
210
+ return {"image_base64": img_str}
211
  except Exception as e:
212
  traceback.print_exc()
213
  raise HTTPException(status_code=500, detail=str(e))
214
 
215
+ # GradioアプリをFastAPIアプリにマウント
216
  app = gr.mount_gradio_app(app, demo, path="/")
217
 
218
  print("Application startup script finished. Waiting for requests.")
 
 
 
 
219
  if __name__ == "__main__":
220
  import uvicorn
221
  port_env = int(os.getenv("PORT", "7860"))
 
226
  print(f"⚠️ Port {port_env} busy → falling back to 7860")
227
  uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
228
  else:
229
+ raise