i0switch committed on
Commit
45eb86f
·
verified ·
1 Parent(s): a304781

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +314 -219
app.py CHANGED
@@ -1,252 +1,347 @@
1
- # app.py — ZeroGPU対応版
2
- import gradio as gr
3
- import spaces
4
- import torch
5
- import numpy as np
6
- from PIL import Image
 
 
 
 
 
 
 
7
  import os
 
 
8
  import subprocess
9
  import traceback
10
- import base64
11
- import io
12
  from pathlib import Path
 
13
 
14
- # FastAPI関連(ハイブリッド構成のため維持)
 
 
 
15
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # グローバル変数としてパイプラインを定義(初期値はNone)
18
- pipe = None
19
- face_app = None
20
- upsampler = None
21
- UPSCALE_OK = False
22
 
23
- # 0. Cache dir & helpers (起動時に実行)
24
  PERSIST_BASE = Path("/data")
25
- CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
26
- else Path.home() / ".cache" / "instantid_cache")
27
- MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR = CACHE_ROOT/"models", CACHE_ROOT/"models"/"Lora", CACHE_ROOT/"embeddings", CACHE_ROOT/"realesrgan"
 
 
 
 
 
 
 
28
 
29
- for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
30
- p.mkdir(parents=True, exist_ok=True)
31
 
32
- def dl(url: str, dst: Path, attempts: int = 2):
33
- if dst.exists(): return
 
 
34
  for i in range(1, attempts + 1):
35
- print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
36
- if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0: return
37
- raise RuntimeError(f"download failed → {url}")
38
-
39
- # 1. Asset download (起動時に実行)
40
- print("— Starting asset download check —")
41
- BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
42
- dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
43
- IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
44
- dl("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE)
45
- LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
46
- dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
47
- print("— Asset download check finished —")
48
-
49
-
50
- # 2. パイプライン初期化関数 (GPU確保後に呼び出される)
51
- def initialize_pipelines():
52
- global pipe, face_app, upsampler, UPSCALE_OK
53
-
54
- # torch/diffusers/onnxruntimeなどのインポートを関数内に移動
55
- from diffusers import StableDiffusionPipeline, ControlNetModel, DPMSolverMultistepScheduler, AutoencoderKL
56
- from insightface.app import FaceAnalysis
57
-
58
- print("--- Initializing Pipelines (GPU is now available) ---")
59
-
60
- device = torch.device("cuda") # ZeroGPUではGPUが保証されている
61
- dtype = torch.float16
62
-
63
- # FaceAnalysis
64
- if face_app is None:
65
- print("Initializing FaceAnalysis...")
66
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
67
- face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
68
- face_app.prepare(ctx_id=0, det_size=(640, 640))
69
- print("FaceAnalysis initialized.")
70
-
71
- # Main Pipeline
72
- if pipe is None:
73
- print("Loading ControlNet...")
74
- controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
75
-
76
- print("Loading StableDiffusionPipeline...")
77
- pipe = StableDiffusionPipeline.from_single_file(BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2)
78
-
79
- print("Moving pipeline to GPU...")
80
- pipe.to(device) # .to(device)をここで呼ぶ
81
-
82
- print("Loading VAE...")
83
- pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
84
- pipe.controlnet = controlnet
85
-
86
- print("Configuring Scheduler...")
87
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
88
-
89
- print("Loading IP-Adapter and LoRA...")
90
- pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
91
- pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
92
-
93
- pipe.set_ip_adapter_scale(0.65)
94
- print("Main pipeline initialized.")
95
-
96
- # Upscaler
97
- if upsampler is None and not UPSCALE_OK: # 一度失敗したら再試行しない
98
- print("Checking for Upscaler...")
99
  try:
100
- # cv2のインポートをここに追加
101
- import cv2
102
- from basicsr.archs.rrdb_arch import RRDBNet
103
- from realesrgan import RealESRGAN
104
- rrdb = RRDBNet(3, 3, 64, 23, 32, scale=8)
105
- upsampler = RealESRGAN(device, rrdb, scale=8)
106
- upsampler.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
107
- UPSCALE_OK = True
108
- print("Upscaler initialized successfully.")
109
- except Exception as e:
110
- UPSCALE_OK = False # 失敗を記録
111
- print(f"Real-ESRGAN disabled → {e}")
112
-
113
- print("--- All pipelines ready ---")
114
-
115
-
116
- # 4. Core generation logic
117
- BASE_PROMPT = ("(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n""photo of {subject},\n""cinematic lighting, golden hour, rim light, shallow depth of field,\n""textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,\n""<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),\n""(aesthetic:1.1), (cinematic:0.8)")
118
- NEG_PROMPT = ("ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream, ""(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2), ""painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d, ""monochrome, grayscale, text, logo, watermark, signature, username, ""(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs, ""missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans")
119
-
120
- # 【変更点①】内部的な画像生成関数。@spaces.GPUデコレータを外す
121
- def _generate_internal(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
122
- # 初回呼び出し時にパイプラインを初期化
123
- initialize_pipelines()
124
-
125
- progress(0, desc="Generating image...")
126
- prompt = BASE_PROMPT.format(subject=(subject.strip() or "a beautiful 20yo woman"))
127
- if add_prompt: prompt += ", " + add_prompt
128
- neg = NEG_PROMPT + (", " + add_neg if add_neg else "")
129
- pipe.set_ip_adapter_scale(ip_scale)
130
-
131
- result = pipe(prompt=prompt, negative_prompt=neg, ip_adapter_image=face_img, image=face_img, controlnet_conditioning_scale=0.9, num_inference_steps=int(steps) + 5, guidance_scale=cfg, width=int(w), height=int(h)).images[0]
132
-
133
- if upscale and UPSCALE_OK:
134
- # cv2のインポートをここにも追加
135
- import cv2
136
- progress(0.8, desc="Upscaling...")
137
- up, _ = upsampler.enhance(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor)
138
- result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
139
-
140
- return result
141
-
142
- # 【変更点②】@spaces.GPUデコレータを持つ新しいラッパー関数を定義
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  @spaces.GPU(duration=60)
144
- def generate_gpu_wrapper(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  """
146
- Hugging Face SpacesプラットフォームにGPUを要求するためのラッパー関数。
147
- 実際の処理は _generate_internal を呼び出して実行する。
148
  """
149
- return _generate_internal(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
 
 
 
 
 
 
 
 
 
150
 
 
 
151
 
152
- # 【変更点③】GradioのUIから新しいラッパー関数を呼び出すように変更
153
- def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
154
- if face_np is None: raise gr.Error("顔画像をアップロードしてください。")
155
- # NumPy配列をPillow画像に変換
156
- face_img = Image.fromarray(face_np)
157
- # _generate_coreの代わりにgenerate_gpu_wrapperを呼び出す
158
- return generate_gpu_wrapper(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
 
 
 
159
 
 
 
 
 
 
160
 
161
- # 5. Gradio UI Definition
162
- with gr.Blocks() as demo:
163
- gr.Markdown("# InstantID Beautiful Realistic Asians v7 (ZeroGPU)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  with gr.Row():
165
- with gr.Column():
166
- face_in = gr.Image(label="顔写真",type="numpy")
167
- subj_in = gr.Textbox(label="被写体説明",placeholder="e.g. woman in black suit, smiling")
168
- add_in = gr.Textbox(label="追加プロンプト")
169
- addneg_in = gr.Textbox(label="追加ネガティブ")
170
- with gr.Accordion("詳細設定", open=False):
171
- ip_sld = gr.Slider(0,1.5,0.65,step=0.05,label="IP‑Adapter scale")
172
- cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
173
- step_sld = gr.Slider(10,50,20,step=1,label="Steps")
174
- w_sld = gr.Slider(512,1024,512,step=64,label="幅")
175
- h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
176
- up_ck = gr.Checkbox(label="アップスケール",value=True)
177
- up_fac = gr.Slider(1,8,2,step=1,label="倍率")
178
- btn = gr.Button("生成",variant="primary")
179
- with gr.Column():
180
- out_img = gr.Image(label="結果")
181
-
182
- demo.queue()
183
-
184
- btn.click(
185
- fn=generate_ui,
186
- inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
187
- outputs=out_img
 
 
 
 
188
  )
189
 
190
- # 6. FastAPI Mounting
 
 
 
191
  app = FastAPI()
192
 
193
- # 【変更点④】FastAPIのエンドポイントも新しいラッパー関数を呼び出すように変更
194
- @app.post("/api/predict")
195
- async def predict_endpoint(
196
- face_image: UploadFile = File(...),
197
- subject: str = Form("a woman"),
198
- add_prompt: str = Form(""),
199
- add_neg: str = Form(""),
200
- cfg: float = Form(6.0),
201
- ip_scale: float = Form(0.65),
202
- steps: int = Form(20),
203
- w: int = Form(512),
204
  h: int = Form(768),
205
- upscale: bool = Form(True),
206
- up_factor: float = Form(2.0)
207
  ):
208
  try:
209
- contents = await face_image.read()
210
- pil_image = Image.open(io.BytesIO(contents))
211
-
212
- # _generate_coreの代わりにgenerate_gpu_wrapperを呼び出す
213
- result_pil_image = generate_gpu_wrapper(
214
- pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
215
- steps, w, h, upscale, up_factor
 
 
 
 
 
 
 
216
  )
217
-
218
- buffered = io.BytesIO()
219
- result_pil_image.save(buffered, format="PNG")
220
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
221
-
222
- return {"image_base64": img_str}
223
  except Exception as e:
224
  traceback.print_exc()
225
  raise HTTPException(status_code=500, detail=str(e))
226
 
227
- # GradioアプリをFastAPIアプリにマウント
228
- app = gr.mount_gradio_app(app, demo, path="/")
229
-
230
- print("Application startup script finished. Waiting for requests.")
231
- if __name__ == "__main__":
232
- import os, time, socket, uvicorn
233
-
234
- def port_is_free(port: int) -> bool:
235
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
236
- return s.connect_ex(("0.0.0.0", port)) != 0
237
-
238
- port = int(os.getenv("PORT", 7860))
239
- # ローカルでのテスト用にタイムアウトを短縮
240
- timeout_sec = 30
241
- poll_interval = 2
242
-
243
- t0 = time.time()
244
- while not port_is_free(port):
245
- waited = time.time() - t0
246
- if waited >= timeout_sec:
247
- raise RuntimeError(f"Port {port} is still busy after {timeout_sec}s")
248
- print(f"⚠️ Port {port} busy, retrying in {poll_interval}s …")
249
- time.sleep(poll_interval)
250
-
251
- # Hugging Face Spaces環境ではポートの競合は起こりにくいため、ポートチェックロジックを簡略化・無効化
252
- uvicorn.run(app, host="0.0.0.0", port=port, workers=1, log_level="info")
 
1
+ # app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-ready, FastAPI + Gradio)
2
+ # 2025-06-21
3
+ #
4
+ # ───────────────────────────────────────────────────────────────
5
+ # 主な特徴
6
+ # • @spaces.GPU(duration=60) を公開名 generate_core() に付与
7
+ # • パイプラインは lazy-load で初回推論時に GPU へロード
8
+ # • モデル資産は /data または ~/.cache に永続化
9
+ # • Real-ESRGAN アップスケール (x4 / x8) オプション
10
+ # • Gradio UI + FastAPI REST を 1 プロセスで共存
11
+ # • Uvicorn 手動起動は不要(Spaces が自前で立てる)
12
+ # ───────────────────────────────────────────────────────────────
13
+
14
  import os
15
+ import io
16
+ import base64
17
  import subprocess
18
  import traceback
 
 
19
  from pathlib import Path
20
+ from typing import Optional
21
 
22
+ import numpy as np
23
+ import torch
24
+ import gradio as gr
25
+ import spaces
26
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
27
+ from PIL import Image
28
+
29
+ from diffusers import (
30
+ StableDiffusionControlNetPipeline,
31
+ ControlNetModel,
32
+ DPMSolverMultistepScheduler,
33
+ AutoencoderKL,
34
+ )
35
+ from diffusers.loaders import AttnProcsLayers
36
+ from insightface.app import FaceAnalysis
37
+ from basicsr.utils.download_util import load_file_from_url
38
+ from realesrgan import RealESRGANer
39
 
40
+ # ==============================================================
41
+ # 0. キャッシュディレクトリとダウンローダ
42
+ # ==============================================================
 
 
43
 
 
44
  PERSIST_BASE = Path("/data")
45
+ CACHE_ROOT = (
46
+ PERSIST_BASE / "instantid_cache"
47
+ if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
48
+ else Path.home() / ".cache" / "instantid_cache"
49
+ )
50
+ MODELS_DIR = CACHE_ROOT / "models"
51
+ LORA_DIR = CACHE_ROOT / "lora"
52
+ UPSCALE_DIR = CACHE_ROOT / "realesrgan"
53
+ for _p in (MODELS_DIR, LORA_DIR, UPSCALE_DIR):
54
+ _p.mkdir(parents=True, exist_ok=True)
55
 
 
 
56
 
57
def download(url: str, dst: Path, attempts: int = 2):
    """Fetch *url* into *dst* with simple retries and return *dst*.

    The transfer is written to a sibling ``<name>.part`` file and only renamed
    onto *dst* after curl reports success, so an interrupted or failed
    download can never be mistaken for a cached model on the next start.

    Args:
        url: Source URL.
        dst: Final location of the file.
        attempts: Number of curl attempts before the basicsr fallback is used.

    Returns:
        *dst*, for call-site convenience.

    Raises:
        Whatever the basicsr fallback downloader raises when it also fails.
    """
    # A zero-byte file is what an aborted ``curl -o`` leaves behind; do not
    # treat it as a valid cache hit.
    if dst.exists() and dst.stat().st_size > 0:
        return dst

    dst.parent.mkdir(parents=True, exist_ok=True)
    partial = dst.with_name(dst.name + ".part")

    for i in range(1, attempts + 1):
        try:
            # --fail makes HTTP 4xx/5xx a non-zero exit instead of silently
            # saving the error page as the model file.
            subprocess.check_call(
                ["curl", "--fail", "--location", "--output", str(partial), url]
            )
        except (subprocess.CalledProcessError, FileNotFoundError):
            # FileNotFoundError: curl itself is not installed.
            print(f"[DL] Retry {i}/{attempts} failed: {url}")
            partial.unlink(missing_ok=True)
            continue
        partial.replace(dst)
        return dst

    # Last resort: basicsr's downloader (imported at the top of the file).
    load_file_from_url(url=url, model_dir=str(dst.parent), file_name=dst.name)
    return dst
70
+
71
+
72
+ # ==============================================================
73
+ # 1. モデル URL 定義
74
+ # ==============================================================
75
+
76
# Base checkpoint (Beautiful Realistic Asians v7, fp16).
BRA_V7_URL = (
    "https://huggingface.co/i0switch-assets/Beautiful_Realistic_Asians_v7/"
    "resolve/main/beautiful_realistic_asians_v7_fp16.safetensors"
)
# IP-Adapter weights live under the "models/" folder of the h94/IP-Adapter
# repository; without that path segment the link does not resolve.
IP_ADAPTER_BIN_URL = (
    "https://huggingface.co/h94/IP-Adapter/resolve/main/"
    "models/ip-adapter-plus-face_sd15.bin"
)
# FaceID Plus v2 LoRA for SD 1.5.
IP_ADAPTER_LORA_URL = (
    "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/"
    "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
)
# Real-ESRGAN x4 weights.
# NOTE(review): hosting repository could not be confirmed from this file — verify.
REALESRGAN_URL = (
    "https://huggingface.co/aimagelab/realesrgan/resolve/main/RealESRGAN_x4plus.pth"
)
90
+
91
+ # ==============================================================
92
+ # 2. グローバル変数(lazy-load される)
93
+ # ==============================================================
94
+
95
+ pipe: Optional[StableDiffusionControlNetPipeline] = None
96
+ face_analyser: Optional[FaceAnalysis] = None
97
+ upsampler: Optional[RealESRGANer] = None
98
+
99
+ # ==============================================================
100
+ # 3. パイプライン初期化
101
+ # ==============================================================
102
+
103
+
104
def initialize_pipelines():
    """Build the diffusion pipeline, face analyser and upscaler exactly once.

    Everything is constructed in local variables and assigned to the module
    globals ``pipe``, ``face_analyser`` and ``upsampler`` only after the
    mandatory components loaded. Because ``pipe is not None`` is the
    "already initialised" marker, publishing it early would make a failure in
    a later step permanent (the next call would return immediately with
    ``face_analyser`` still ``None``).

    The upscaler is optional: if it cannot be created, ``upsampler`` stays
    ``None`` and callers that guard on that simply skip upscaling.

    Raises:
        Any exception raised while downloading or loading the diffusion
        pipeline or the face analyser.
    """
    global pipe, face_analyser, upsampler

    if pipe is not None:
        return  # already initialised

    print("[INIT] Downloading model assets …")

    # ---- 3-1. Base checkpoint & FaceID LoRA ----
    # NOTE(review): the BRA v7 checkpoint is downloaded here but the pipeline
    # below is built from `pipe_base` (stock SD 1.5) — confirm whether the
    # checkpoint was meant to be loaded instead.
    download(BRA_V7_URL, MODELS_DIR / "bra_v7.safetensors")
    # Keep the .safetensors suffix: diffusers chooses the loader from the
    # file extension and would otherwise try to unpickle the file.
    ip_lora = download(
        IP_ADAPTER_LORA_URL,
        LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors",
    )

    # ---- 3-2. ControlNet ----
    # NOTE(review): verify that this repository id exists on the Hub.
    controlnet = ControlNetModel.from_pretrained(
        "InstantID/ControlNet-Mediapipe-Face",
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
    )

    # ---- 3-3. Diffusers pipeline ----
    pipe_base = "runwayml/stable-diffusion-v1-5"
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
    )
    pipe_tmp = StableDiffusionControlNetPipeline.from_pretrained(
        pipe_base,
        controlnet=controlnet,
        vae=vae,
        torch_dtype=torch.float16,
        safety_checker=None,
        local_files_only=False,
        cache_dir=str(MODELS_DIR),
    )
    pipe_tmp.scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pipe_base, subfolder="scheduler", cache_dir=str(MODELS_DIR)
    )

    # IP-Adapter: load through the Hub repository so that the matching image
    # encoder (stored next to the weights in that repo) is loaded as well; a
    # bare local .bin file carries no encoder.
    pipe_tmp.load_ip_adapter(
        "h94/IP-Adapter",
        subfolder="models",
        weight_name="ip-adapter-plus-face_sd15.bin",
        cache_dir=str(MODELS_DIR),
    )
    # LoRA: registered under a named adapter so its weight can be changed
    # per request with set_adapters().
    pipe_tmp.load_lora_weights(
        str(ip_lora.parent), weight_name=ip_lora.name, adapter_name="ip_faceid"
    )
    pipe_tmp.set_adapters(["ip_faceid"], adapter_weights=[0.6])
    pipe_tmp.to("cuda")

    # ---- 3-4. InsightFace ----
    analyser = FaceAnalysis(
        name="buffalo_l",
        root=str(MODELS_DIR),
        # CPU provider is the fallback when the CUDA provider is unavailable.
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    analyser.prepare(ctx_id=0, det_size=(640, 640))

    # ---- 3-5. Real-ESRGAN (optional) ----
    esrgan = None
    try:
        # Imported here because only this step needs the network definition.
        from basicsr.archs.rrdb_arch import RRDBNet

        esrgan_ckpt = download(REALESRGAN_URL, UPSCALE_DIR / "realesrgan_x4plus.pth")
        esrgan = RealESRGANer(
            scale=4,
            model_path=str(esrgan_ckpt),
            # RealESRGANer loads the checkpoint *into* this network; it does
            # not construct one on its own.
            model=RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            ),
            half=True,
            tile=512,
            tile_pad=10,
            pre_pad=0,
            gpu_id=0,
        )
    except Exception as exc:  # upscaling is a best-effort extra
        traceback.print_exc()
        print(f"[INIT] Real-ESRGAN disabled: {exc}")

    # Publish last, so a failure above leaves the module un-initialised and
    # the next request retries from scratch.
    pipe, face_analyser, upsampler = pipe_tmp, analyser, esrgan

    print("[INIT] Pipelines ready.")
177
+
178
+
179
+ # ==============================================================
180
+ # 4. プロンプト設定
181
+ # ==============================================================
182
+
183
+ BASE_PROMPT = (
184
+ "(masterpiece:1.2), best quality, ultra-realistic, 8k, RAW photo, "
185
+ "cinematic lighting, textured skin, "
186
+ )
187
+ NEG_PROMPT = (
188
+ "verybadimagenegative_v1.3, ng_deepnegative_v1_75t, "
189
+ "(worst quality:2), (low quality:2), lowres, blurry, bad anatomy, "
190
+ "bad hands, extra digits, cropped, watermark, signature"
191
+ )
192
+
193
+ # ==============================================================
194
+ # 5. 生成コア関数(GPU を掴む)
195
+ # ==============================================================
196
+
197
+
198
@spaces.GPU(duration=60)
def generate_core(
    face_img: Image.Image,
    subject: str,
    add_prompt: str = "",
    add_neg: str = "",
    cfg: float = 7.5,
    ip_scale: float = 0.6,
    steps: int = 30,
    w: int = 768,
    h: int = 768,
    upscale: bool = False,
    up_factor: int = 4,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate one image conditioned on a reference face.

    Args:
        face_img: Reference face photo.
        subject: Description of the subject, appended to ``BASE_PROMPT``.
        add_prompt: Optional extra positive prompt text.
        add_neg: Optional extra negative prompt text.
        cfg: Classifier-free guidance scale.
        ip_scale: Weight applied to the IP-Adapter and to the ``ip_faceid``
            LoRA adapter.
        steps: Number of inference steps.
        w: Output width in pixels.
        h: Output height in pixels.
        upscale: Run Real-ESRGAN on the result when an upscaler is available.
        up_factor: 4 selects x4 output; any other value selects x8.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        The generated ``PIL.Image.Image``.

    Raises:
        ValueError: No image was supplied, or no face was detected in it.
        Exception: Anything raised by model loading or inference; the
            traceback is printed before it is re-raised.
    """
    try:
        if face_img is None:
            raise ValueError("顔画像をアップロードしてください。")

        if pipe is None:
            initialize_pipelines()

        face_img = face_img.convert("RGB")
        # InsightFace follows the OpenCV convention and expects BGR arrays.
        face_bgr = np.ascontiguousarray(np.array(face_img)[:, :, ::-1])
        if not face_analyser.get(face_bgr):
            raise ValueError("顔が検出できませんでした。")

        pipe.set_adapters(["ip_faceid"], adapter_weights=[ip_scale])
        pipe.set_ip_adapter_scale(ip_scale)

        # Join only the non-empty parts so that blank optional fields do not
        # leave dangling ", " separators in the prompts.
        positive_parts = [p.strip() for p in (subject or "", add_prompt or "")]
        prompt = BASE_PROMPT + ", ".join(p for p in positive_parts if p)
        negative_parts = [NEG_PROMPT, (add_neg or "").strip()]
        negative = ", ".join(p for p in negative_parts if p)

        # NOTE(review): the raw face photo is used as the ControlNet
        # conditioning image — confirm that the chosen ControlNet expects this.
        result = pipe(
            prompt=prompt,
            negative_prompt=negative,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            image=face_img,
            # Required once an IP-Adapter is loaded into the pipeline.
            ip_adapter_image=face_img,
            width=int(w),
            height=int(h),
        ).images[0]

        if upscale and upsampler is not None:
            # The network is x4; other factors are reached by asking
            # enhance() to resize its output, not by changing
            # upsampler.scale (which must match the loaded weights).
            outscale = 4 if up_factor == 4 else 8
            result_bgr = np.ascontiguousarray(np.array(result)[:, :, ::-1])
            upscaled_bgr, _ = upsampler.enhance(result_bgr, outscale=outscale)
            result = Image.fromarray(np.ascontiguousarray(upscaled_bgr[:, :, ::-1]))

        return result

    except Exception:
        traceback.print_exc()
        raise
252
+
253
+
254
+ # ==============================================================
255
+ # 6. Gradio UI
256
+ # ==============================================================
257
+
258
+ with gr.Blocks(title="InstantID × BRA v7 (ZeroGPU)") as demo:
259
+ gr.Markdown("## InstantID × Beautiful Realistic Asians v7")
260
+ with gr.Row():
261
+ face_img = gr.Image(type="pil", label="Face ID", sources=["upload"])
262
+ subject = gr.Textbox(
263
+ label="被写体説明(例: '30代日本人女性、黒髪セミロング')", interactive=True
264
+ )
265
+ add_prompt = gr.Textbox(label="追加プロンプト", interactive=True)
266
+ add_neg = gr.Textbox(label="追加ネガティブ", interactive=True)
267
+ with gr.Row():
268
+ cfg = gr.Slider(1, 20, value=7.5, step=0.5, label="CFG Scale")
269
+ ip_scale = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IP-Adapter Weight")
270
  with gr.Row():
271
+ steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
272
+ w = gr.Slider(512, 1024, value=768, step=64, label="Width")
273
+ h = gr.Slider(512, 1024, value=768, step=64, label="Height")
274
+ with gr.Row():
275
+ upscale = gr.Checkbox(label="Real-ESRGAN Upscale", value=False)
276
+ up_factor = gr.Radio([4, 8], value=4, label="Upscale Factor")
277
+ run_btn = gr.Button("Generate")
278
+
279
+ output_img = gr.Image(type="pil", label="Result")
280
+
281
+ run_btn.click(
282
+ fn=generate_core,
283
+ inputs=[
284
+ face_img,
285
+ subject,
286
+ add_prompt,
287
+ add_neg,
288
+ cfg,
289
+ ip_scale,
290
+ steps,
291
+ w,
292
+ h,
293
+ upscale,
294
+ up_factor,
295
+ ],
296
+ outputs=output_img,
297
+ show_progress=True,
298
  )
299
 
300
+ # ==============================================================
301
+ # 7. FastAPI エンドポイント
302
+ # ==============================================================
303
+
304
  app = FastAPI()
305
 
306
+
307
@app.post("/api/generate")
async def api_generate(
    subject: str = Form(...),
    cfg: float = Form(7.5),
    steps: int = Form(30),
    ip_scale: float = Form(0.6),
    w: int = Form(768),
    h: int = Form(768),
    file: UploadFile = File(...),
    add_prompt: str = Form(""),
    add_neg: str = Form(""),
    upscale: bool = Form(False),
    up_factor: int = Form(4),
):
    """REST counterpart of the Gradio UI: generate an image from a face upload.

    Accepts multipart form data. ``add_prompt``, ``add_neg``, ``upscale`` and
    ``up_factor`` are optional and default to the values that were previously
    hard-coded, so existing clients keep working unchanged.

    Returns:
        ``{"image": "data:image/png;base64,..."}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 when generation fails.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of the file.
    from fastapi.concurrency import run_in_threadpool

    img_bytes = await file.read()
    try:
        pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Bad input is the client's fault, not a server error.
        raise HTTPException(
            status_code=400, detail=f"invalid image upload: {exc}"
        ) from exc

    try:
        # generate_core blocks for the whole inference; running it on a worker
        # thread keeps the event loop (and the Gradio UI served by the same
        # process) responsive.
        res = await run_in_threadpool(
            generate_core,
            face_img=pil,
            subject=subject,
            add_prompt=add_prompt,
            add_neg=add_neg,
            cfg=cfg,
            ip_scale=ip_scale,
            steps=steps,
            w=w,
            h=h,
            upscale=upscale,
            up_factor=up_factor,
        )
        buf = io.BytesIO()
        res.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return {"image": f"data:image/png;base64,{b64}"}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
340
 
341
+
342
+ # ==============================================================
343
+ # 8. Launch
344
+ # ==============================================================
345
+
346
# Serve the Gradio UI and the REST route from a single ASGI application.
# Calling demo.launch() alone would start Gradio's own server and leave the
# FastAPI `app` (and therefore /api/generate) unreachable.
# `concurrency_count` is not accepted by queue() in Gradio 4, which this file
# targets (gr.Image(sources=[...])); the per-event default limit replaces it.
demo.queue(default_concurrency_limit=2)
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # uvicorn is installed as a dependency of gradio / fastapi.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))