i0switch committed on
Commit
f2192e3
·
verified ·
1 Parent(s): 942bdcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -99
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — ZeroGPU対応版
2
  import gradio as gr
3
  import spaces
4
  import torch
@@ -10,14 +10,11 @@ import traceback
10
  import base64
11
  import io
12
  from pathlib import Path
13
-
14
- # FastAPI関連(ハイブリッド構成のため維持)
15
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
16
 
17
  ##############################################################################
18
  # 0. 設定とヘルパー
19
  ##############################################################################
20
- # モデル・LoRA キャッシュを /data に置ける場合はそちらを優先
21
  PERSIST_BASE = Path("/data")
22
  CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists()
23
  and os.access(PERSIST_BASE, os.W_OK)
@@ -28,14 +25,15 @@ for d in (MODELS_DIR, LORA_DIR):
28
  d.mkdir(parents=True, exist_ok=True)
29
 
30
  def dl(url: str, dst: Path, attempts: int = 2):
31
- """冪等ダウンロード(既存ならスキップ、リトライ付き)"""
32
- if dst.exists(): return
 
33
  for i in range(1, attempts + 1):
34
  print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
35
- if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0: return
 
36
  raise RuntimeError(f"download failed → {url}")
37
 
38
- # 1. Asset download (起動時に実行)
39
  print("— Starting asset download check —")
40
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
41
  dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
@@ -45,58 +43,8 @@ LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
45
  dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
46
  print("— Asset download check finished —")
47
 
48
-
49
- # 2. パイプライン初期化関数 (GPU確保後に呼び出される)
50
- def load_pipeline():
51
- from diffusers import (
52
- StableDiffusionPipeline, ControlNetModel,
53
- DPMSolverMultistepScheduler, AutoencoderKL,
54
- )
55
- from insightface.app import FaceAnalysis
56
-
57
- print("→ Loading models to GPU …")
58
-
59
- # --- InstantID 主要モデル ---
60
- vae = AutoencoderKL.from_pretrained(
61
- "stabilityai/sd-vae-ft-mse",
62
- torch_dtype=torch.float16
63
- )
64
- base = StableDiffusionPipeline.from_single_file(
65
- str(BASE_CKPT),
66
- vae=vae,
67
- torch_dtype=torch.float16,
68
- safety_checker=None,
69
- original_config_file="v1-inference.yaml" # StableDiffusion1.x 互換
70
- )
71
- control = ControlNetModel.from_pretrained(
72
- "lllyasviel/control_v11p_sd15_openpose",
73
- torch_dtype=torch.float16
74
- )
75
- pipe = StableDiffusionPipeline(
76
- vae=vae,
77
- text_encoder=base.text_encoder,
78
- tokenizer=base.tokenizer,
79
- unet=base.unet,
80
- controlnet=control,
81
- scheduler=DPMSolverMultistepScheduler.from_config(base.scheduler.config),
82
- safety_checker=None,
83
- feature_extractor=base.feature_extractor,
84
- requires_safety_checker=False
85
- ).to("cuda", dtype=torch.float16)
86
- pipe.load_lora_weights(str(LORA_FILE))
87
- pipe.set_adapters(["ip_adapter_face"], [1.0])
88
- pipe.enable_xformers_memory_efficient_attention()
89
-
90
- # --- InsightFace ---
91
- face_analyzer = FaceAnalysis(name="antelopev2", providers=["CUDAExecutionProvider"])
92
- face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
93
-
94
- print("✓ Model loading complete.")
95
- return pipe, face_analyzer
96
-
97
-
98
  ##############################################################################
99
- # 3. Gradio UI
100
  ##############################################################################
101
  with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
102
  with gr.Row(equal_height=True):
@@ -116,62 +64,49 @@ with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
116
  btn = gr.Button("生成",variant="primary")
117
  with gr.Column():
118
  out_img = gr.Image(label="結果")
119
-
120
- # .queue() はGradioの通常機能として必要
121
  demo.queue()
122
-
123
- def generate_ui(face_img, subj, add, addneg, cfg, ipw, steps, w, h, upscale, up_factor):
124
- # 実際の推論関数(省略:ここに InstantID 推論処理を実装)
125
- return face_img # ダミー
126
 
127
- btn.click(
128
- fn=generate_ui,
129
- inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,w_sld,h_sld,up_ck,up_fac],
130
- outputs=[out_img]
131
- )
 
 
132
 
133
  ##############################################################################
134
- # 4. FastAPI エンドポイント(REST API 用)
135
  ##############################################################################
136
  app = FastAPI()
137
 
138
  @app.post("/api/predict")
139
- async def predict(
140
- face: UploadFile = File(...),
141
- subject: str = Form(...),
142
- add_prompt: str = Form(""),
143
- add_neg: str = Form(""),
144
- cfg: float = Form(6.0),
145
- ipw: float = Form(0.6),
146
- steps: int = Form(20),
147
- w: int = Form(512),
148
- h: int = Form(768),
149
- upscale: bool = Form(True),
150
- up_factor: int = Form(2)
151
- ):
152
  try:
153
- # 実際の推論ロジック(省略)
154
- result_pil_image = Image.open(face.file) # ダミー
155
-
156
  buffered = io.BytesIO()
157
- result_pil_image.save(buffered, format="PNG")
158
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
159
-
160
- return {"image_base64": img_str}
161
  except Exception as e:
162
  traceback.print_exc()
163
  raise HTTPException(status_code=500, detail=str(e))
164
 
165
- # GradioアプリをFastAPIアプリにマウント
166
  app = gr.mount_gradio_app(app, demo, path="/")
167
 
168
  print("Application startup script finished. Waiting for requests.")
169
 
170
- #------------------------------------------------------------------------
171
- # 5. Uvicorn サーバー起動(Spaces が呼び出すエントリポイント)
172
- #------------------------------------------------------------------------
173
  if __name__ == "__main__":
174
- import uvicorn, os
175
- # Hugging Face Spaces が $PORT を渡してくる場合はそれを優先
176
- port = int(os.getenv("PORT", 7860))
177
- uvicorn.run(app, host="0.0.0.0", port=port, workers=1, log_level="info")
 
 
 
 
 
 
 
1
+ # app.py — ZeroGPU対応 + ポート自動フォールバック
2
  import gradio as gr
3
  import spaces
4
  import torch
 
10
  import base64
11
  import io
12
  from pathlib import Path
 
 
13
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
14
 
15
  ##############################################################################
16
  # 0. 設定とヘルパー
17
  ##############################################################################
 
18
  PERSIST_BASE = Path("/data")
19
  CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists()
20
  and os.access(PERSIST_BASE, os.W_OK)
 
25
  d.mkdir(parents=True, exist_ok=True)
26
 
27
  def dl(url: str, dst: Path, attempts: int = 2):
28
+ """冪等ダウンロード(既にあればスキップ、リトライ付き)"""
29
+ if dst.exists():
30
+ return
31
  for i in range(1, attempts + 1):
32
  print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
33
+ if subprocess.call(["wget", "-q", "-O", str(dst), url]) == 0:
34
+ return
35
  raise RuntimeError(f"download failed → {url}")
36
 
 
37
  print("— Starting asset download check —")
38
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
39
  dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
 
43
  dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
44
  print("— Asset download check finished —")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ##############################################################################
47
+ # 1. Gradio UI
48
  ##############################################################################
49
  with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
50
  with gr.Row(equal_height=True):
 
64
  btn = gr.Button("生成",variant="primary")
65
  with gr.Column():
66
  out_img = gr.Image(label="結果")
 
 
67
  demo.queue()
 
 
 
 
68
 
69
+ # ダミー推論(実装は省略)
70
+ def generate_ui(*args, **kwargs):
71
+ return Image.new("RGB", (512,768), (127,127,127))
72
+ btn.click(generate_ui,
73
+ inputs=[face_in,subj_in,add_in,addneg_in,cfg_sld,ip_sld,step_sld,
74
+ w_sld,h_sld,up_ck,up_fac],
75
+ outputs=[out_img])
76
 
77
  ##############################################################################
78
+ # 2. FastAPI ラッパー(REST API)
79
  ##############################################################################
80
  app = FastAPI()
81
 
82
  @app.post("/api/predict")
83
+ async def predict(face: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
 
 
 
84
  try:
85
+ img = Image.open(face.file)
 
 
86
  buffered = io.BytesIO()
87
+ img.save(buffered, format="PNG")
88
+ img_b64 = base64.b64encode(buffered.getvalue()).decode()
89
+ return {"image_base64": img_b64}
 
90
  except Exception as e:
91
  traceback.print_exc()
92
  raise HTTPException(status_code=500, detail=str(e))
93
 
94
+ # Gradio を FastAPI にマウント
95
  app = gr.mount_gradio_app(app, demo, path="/")
96
 
97
  print("Application startup script finished. Waiting for requests.")
98
 
99
+ ##############################################################################
100
+ # 3. Uvicorn 起動(ポート重複時フォールバック)
101
+ ##############################################################################
102
  if __name__ == "__main__":
103
+ import uvicorn
104
+ port_env = int(os.getenv("PORT", "7860"))
105
+ try:
106
+ uvicorn.run(app, host="0.0.0.0", port=port_env, workers=1, log_level="info")
107
+ except OSError as e:
108
+ if e.errno == 98 and port_env != 7860:
109
+ print(f"⚠️ Port {port_env} busy → falling back to 7860")
110
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1, log_level="info")
111
+ else:
112
+ raise