i0switch committed
Commit 942bdcb · verified · 1 Parent(s): 92aaea0

Update app.py

Files changed (1)
  1. app.py +95 -142
app.py CHANGED
@@ -14,22 +14,21 @@ from pathlib import Path
  # FastAPI-related (kept for the hybrid setup)
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException

- # Define the pipelines as global variables (initially None)
- pipe = None
- face_app = None
- upsampler = None
- UPSCALE_OK = False
-
- # 0. Cache dir & helpers (run at startup)
  PERSIST_BASE = Path("/data")
- CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
-               else Path.home() / ".cache" / "instantid_cache")
- MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR = CACHE_ROOT/"models", CACHE_ROOT/"models"/"Lora", CACHE_ROOT/"embeddings", CACHE_ROOT/"realesrgan"
-
- for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
-     p.mkdir(parents=True, exist_ok=True)

  def dl(url: str, dst: Path, attempts: int = 2):
      if dst.exists(): return
      for i in range(1, attempts + 1):
          print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
@@ -48,120 +47,72 @@ print("— Asset download check finished —")

  # 2. Pipeline initialization function (called after the GPU has been acquired)
- def initialize_pipelines():
-     global pipe, face_app, upsampler, UPSCALE_OK
-
-     # torch/diffusers/onnxruntime imports are moved inside the function
-     from diffusers import StableDiffusionPipeline, ControlNetModel, DPMSolverMultistepScheduler, AutoencoderKL
      from insightface.app import FaceAnalysis

-     print("--- Initializing Pipelines (GPU is now available) ---")
-
-     device = torch.device("cuda")  # ZeroGPU guarantees a GPU here
-     dtype = torch.float16
-
-     # FaceAnalysis
-     if face_app is None:
-         print("Initializing FaceAnalysis...")
-         providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-         face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
-         face_app.prepare(ctx_id=0, det_size=(640, 640))
-         print("FaceAnalysis initialized.")
-
-     # Main pipeline
-     if pipe is None:
-         print("Loading ControlNet...")
-         controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
-
-         print("Loading StableDiffusionPipeline...")
-         pipe = StableDiffusionPipeline.from_single_file(BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2)
-
-         print("Moving pipeline to GPU...")
-         pipe.to(device)  # call .to(device) here
-
-         print("Loading VAE...")
-         pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
-         pipe.controlnet = controlnet
-
-         print("Configuring Scheduler...")
-         pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
-
-         print("Loading IP-Adapter and LoRA...")
-         pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
-         pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
-
-         pipe.set_ip_adapter_scale(0.65)
-         print("Main pipeline initialized.")
-
-     # Upscaler
-     if upsampler is None and not UPSCALE_OK:  # do not retry after a failure
-         print("Checking for Upscaler...")
-         try:
-             from basicsr.archs.rrdb_arch import RRDBNet
-             from realesrgan import RealESRGAN
-             rrdb = RRDBNet(3, 3, 64, 23, 32, scale=8)
-             upsampler = RealESRGAN(device, rrdb, scale=8)
-             upsampler.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
-             UPSCALE_OK = True
-             print("Upscaler initialized successfully.")
-         except Exception as e:
-             UPSCALE_OK = False  # record the failure
-             print(f"Real-ESRGAN disabled → {e}")
-
-     print("--- All pipelines ready ---")
-
-
- # 4. Core generation logic
- BASE_PROMPT = ("(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n""photo of {subject},\n""cinematic lighting, golden hour, rim light, shallow depth of field,\n""textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,\n""<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),\n""(aesthetic:1.1), (cinematic:0.8)")
- NEG_PROMPT = ("ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream, ""(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2), ""painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d, ""monochrome, grayscale, text, logo, watermark, signature, username, ""(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs, ""missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans")

-
- # The body executed on ZeroGPU; duration is set to 60 seconds.
- @spaces.GPU(duration=60)
- def _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
-     # Initialize the pipelines on the first call
-     initialize_pipelines()
-
-     progress(0, desc="Generating image...")
-     prompt = BASE_PROMPT.format(subject=(subject.strip() or "a beautiful 20yo woman"))
-     if add_prompt: prompt += ", " + add_prompt
-     neg = NEG_PROMPT + (", " + add_neg if add_neg else "")
-     pipe.set_ip_adapter_scale(ip_scale)
-
-     result = pipe(prompt=prompt, negative_prompt=neg, ip_adapter_image=face_img, image=face_img, controlnet_conditioning_scale=0.9, num_inference_steps=int(steps) + 5, guidance_scale=cfg, width=int(w), height=int(h)).images[0]
-
-     if upscale and UPSCALE_OK:
-         progress(0.8, desc="Upscaling...")
-         up, _ = upsampler.enhance(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor)
-         result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
-
-     return result
-
- # Wrapper function called from the Gradio UI
- def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
-     if face_np is None: raise gr.Error("顔画像をアップロードしてください。")
-     # Convert the NumPy array to a Pillow image
-     face_img = Image.fromarray(face_np)
-     return _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress)
-
-
- # 5. Gradio UI Definition
- with gr.Blocks() as demo:
-     gr.Markdown("# InstantID – Beautiful Realistic Asians v7 (ZeroGPU)")
-     with gr.Row():
          with gr.Column():
-             face_in = gr.Image(label="顔写真",type="numpy")
-             subj_in = gr.Textbox(label="被写体説明",placeholder="e.g. woman in black suit, smiling")
-             add_in = gr.Textbox(label="追加プロンプト")
-             addneg_in = gr.Textbox(label="追加ネガティブ")
-             with gr.Accordion("詳細設定", open=False):
-                 ip_sld = gr.Slider(0,1.5,0.65,step=0.05,label="IPAdapter scale")
-                 cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
                  step_sld = gr.Slider(10,50,20,step=1,label="Steps")
-                 w_sld = gr.Slider(512,1024,512,step=64,label="幅")
-                 h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
-                 up_ck = gr.Checkbox(label="アップスケール",value=True)
-                 up_fac = gr.Slider(1,8,2,step=1,label="倍率")
              btn = gr.Button("生成",variant="primary")
          with gr.Column():
              out_img = gr.Image(label="結果")
@@ -169,39 +120,38 @@ with gr.Blocks() as demo:
      # .queue() is still required as normal Gradio functionality
      demo.queue()

      btn.click(
          fn=generate_ui,
          inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
-         outputs=out_img
      )

- # 6. FastAPI Mounting
  app = FastAPI()

- # Define the FastAPI endpoint; it too calls _generate_core internally
  @app.post("/api/predict")
- async def predict_endpoint(
-     face_image: UploadFile = File(...),
-     subject: str = Form("a woman"),
      add_prompt: str = Form(""),
      add_neg: str = Form(""),
      cfg: float = Form(6.0),
-     ip_scale: float = Form(0.65),
      steps: int = Form(20),
      w: int = Form(512),
      h: int = Form(768),
      upscale: bool = Form(True),
-     up_factor: float = Form(2.0)
  ):
      try:
-         contents = await face_image.read()
-         pil_image = Image.open(io.BytesIO(contents))
-
-         # Calls coming in via FastAPI also use the same core function
-         result_pil_image = _generate_core(
-             pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
-             steps, w, h, upscale, up_factor
-         )

          buffered = io.BytesIO()
          result_pil_image.save(buffered, format="PNG")
@@ -216,9 +166,12 @@ async def predict_endpoint(
  app = gr.mount_gradio_app(app, demo, path="/")

  print("Application startup script finished. Waiting for requests.")
- # Appended at the end of app.py

  if __name__ == "__main__":
-     import uvicorn
-     # 7860 is the standard port when running a Gradio app on Spaces
-     uvicorn.run(app, host="0.0.0.0", port=7860)

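For context, the removed code deferred all heavy model loading to ZeroGPU: nothing is loaded at import time, and the first call to the @spaces.GPU-decorated function initializes the global pipeline. A minimal sketch of that pattern follows; load_models() is a hypothetical placeholder, not the Space's actual loader.

import spaces  # ZeroGPU decorator available on Hugging Face Spaces

pipe = None  # heavy objects stay unloaded at import time (no GPU yet)

def load_models():
    # Hypothetical placeholder for the real loading code (diffusers pipeline, InsightFace, ...)
    return object()

def initialize_pipelines():
    """Populate the global on the first GPU-backed call."""
    global pipe
    if pipe is None:
        pipe = load_models()

@spaces.GPU(duration=60)  # a GPU is only guaranteed inside this call
def _generate_core(*args, **kwargs):
    initialize_pipelines()
    # ... run inference with `pipe` here ...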
  # FastAPI-related (kept for the hybrid setup)
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException

+ ##############################################################################
+ # 0. Settings and helpers
+ ##############################################################################
+ # Prefer /data for the model / LoRA cache when it is writable
  PERSIST_BASE = Path("/data")
+ CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists()
+               and os.access(PERSIST_BASE, os.W_OK)
+               else Path.home() / ".cache" / "instantid_cache")
+ MODELS_DIR = CACHE_ROOT / "models"
+ LORA_DIR = CACHE_ROOT / "lora"
+ for d in (MODELS_DIR, LORA_DIR):
+     d.mkdir(parents=True, exist_ok=True)

  def dl(url: str, dst: Path, attempts: int = 2):
+     """Idempotent download (skip if the file already exists, retry on failure)."""
      if dst.exists(): return
      for i in range(1, attempts + 1):
          print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")

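The new dl() helper is only partially visible in this hunk. Purely as an illustration of how such an idempotent, retrying downloader can be completed (urllib-based; a sketch, not the file's actual body):

import time
import urllib.request
from pathlib import Path

def dl_sketch(url: str, dst: Path, attempts: int = 2):
    """Skip if the target exists; otherwise retry the download a few times."""
    if dst.exists():
        return
    dst.parent.mkdir(parents=True, exist_ok=True)
    for i in range(1, attempts + 1):
        print(f"Downloading {dst.name} (try {i}/{attempts})")
        try:
            urllib.request.urlretrieve(url, dst)
            return
        except Exception as e:
            print(f"Attempt {i} failed: {e}")
            time.sleep(2)
    raise RuntimeError(f"Could not download {url}")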

  # 2. Pipeline initialization function (called after the GPU has been acquired)
+ def load_pipeline():
+     from diffusers import (
+         StableDiffusionPipeline, ControlNetModel,
+         DPMSolverMultistepScheduler, AutoencoderKL,
+     )
      from insightface.app import FaceAnalysis

+     print(" Loading models to GPU ")
+
+     # --- InstantID core models ---
+     vae = AutoencoderKL.from_pretrained(
+         "stabilityai/sd-vae-ft-mse",
+         torch_dtype=torch.float16
+     )
+     base = StableDiffusionPipeline.from_single_file(
+         str(BASE_CKPT),
+         vae=vae,
+         torch_dtype=torch.float16,
+         safety_checker=None,
+         original_config_file="v1-inference.yaml"  # StableDiffusion 1.x compatible
+     )
+     control = ControlNetModel.from_pretrained(
+         "lllyasviel/control_v11p_sd15_openpose",
+         torch_dtype=torch.float16
+     )
+     pipe = StableDiffusionPipeline(
+         vae=vae,
+         text_encoder=base.text_encoder,
+         tokenizer=base.tokenizer,
+         unet=base.unet,
+         controlnet=control,
+         scheduler=DPMSolverMultistepScheduler.from_config(base.scheduler.config),
+         safety_checker=None,
+         feature_extractor=base.feature_extractor,
+         requires_safety_checker=False
+     ).to("cuda", dtype=torch.float16)
+     pipe.load_lora_weights(str(LORA_FILE))
+     pipe.set_adapters(["ip_adapter_face"], [1.0])
+     pipe.enable_xformers_memory_efficient_attention()
+
+     # --- InsightFace ---
+     face_analyzer = FaceAnalysis(name="antelopev2", providers=["CUDAExecutionProvider"])
+     face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
+
+     print("✓ Model loading complete.")
+     return pipe, face_analyzer
+
+
+ ##############################################################################
+ # 3. Gradio UI
+ ##############################################################################
+ with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
+     with gr.Row(equal_height=True):
          with gr.Column():
+             face_in = gr.Image(type="pil", label="顔画像 (必須)")
+             subj_in = gr.Textbox(label="被写体説明", placeholder="例: 20代日本人女性")
+             add_in = gr.Textbox(label="追加プロンプト", placeholder="例: masterpiece, best quality, ...")
+             addneg_in = gr.Textbox(label="ネガティブ", value="(worst quality:2), lowres, bad hand, ...")
+             with gr.Row():
+                 ip_sld = gr.Slider(0.0,1.0,0.6,step=0.05,label="IP Adapter Weight")
+                 cfg_sld = gr.Slider(1,15,6,step=0.5,label="CFG")
                  step_sld = gr.Slider(10,50,20,step=1,label="Steps")
+                 w_sld = gr.Slider(512,1024,512,step=64,label="幅")
+                 h_sld = gr.Slider(512,1024,768,step=64,label="高さ")
+                 up_ck = gr.Checkbox(label="アップスケール",value=True)
+                 up_fac = gr.Slider(1,8,2,step=1,label="倍率")
              btn = gr.Button("生成",variant="primary")
          with gr.Column():
              out_img = gr.Image(label="結果")

      # .queue() is still required as normal Gradio functionality
      demo.queue()

+     def generate_ui(face_img, subj, add, addneg, cfg, ipw, steps, w, h, upscale, up_factor):
+         # Actual inference function (omitted: the InstantID inference goes here)
+         return face_img  # dummy
+
      btn.click(
          fn=generate_ui,
          inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
+         outputs=[out_img]
      )

+ ##############################################################################
+ # 4. FastAPI endpoint (for the REST API)
+ ##############################################################################
  app = FastAPI()

  @app.post("/api/predict")
+ async def predict(
+     face: UploadFile = File(...),
+     subject: str = Form(...),
      add_prompt: str = Form(""),
      add_neg: str = Form(""),
      cfg: float = Form(6.0),
+     ipw: float = Form(0.6),
      steps: int = Form(20),
      w: int = Form(512),
      h: int = Form(768),
      upscale: bool = Form(True),
+     up_factor: int = Form(2)
  ):
      try:
+         # Actual inference logic (omitted)
+         result_pil_image = Image.open(face.file)  # dummy

          buffered = io.BytesIO()
          result_pil_image.save(buffered, format="PNG")

  app = gr.mount_gradio_app(app, demo, path="/")

  print("Application startup script finished. Waiting for requests.")
 
+ #------------------------------------------------------------------------
+ # 5. Uvicorn server startup (the entry point Spaces invokes)
+ #------------------------------------------------------------------------
  if __name__ == "__main__":
+     import uvicorn, os
+     # Prefer $PORT if Hugging Face Spaces provides it
+     port = int(os.getenv("PORT", 7860))
+     uvicorn.run(app, host="0.0.0.0", port=port, workers=1, log_level="info")
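
Since the commit also reshapes the REST endpoint's form fields, here is a minimal client sketch against the new /api/predict signature. Field names are taken from the Form(...) parameters above; the URL and input file are placeholders, and because the response format is not visible in this diff, only the status is printed.

import requests

url = "http://localhost:7860/api/predict"  # placeholder; use the Space's URL in practice
with open("face.jpg", "rb") as f:          # placeholder input image
    resp = requests.post(
        url,
        files={"face": f},
        data={"subject": "a 20-something woman", "cfg": 6.0, "ipw": 0.6,
              "steps": 20, "w": 512, "h": 768, "upscale": "true", "up_factor": 2},
        timeout=600,
    )
print(resp.status_code, resp.headers.get("content-type"))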