i0switch commited on
Commit
e0bbeb9
·
verified ·
1 Parent(s): fc0ad0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -39
app.py CHANGED
@@ -1,7 +1,7 @@
1
  # app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-friendly, persistent cache)
2
  """Persistent-cache backend for InstantID portrait generation.
3
- - Caches model assets under /data when writable, else ~/.cache
4
- - Robust download with retry + multiple fallback URLs per asset
5
  """
6
  import os, subprocess, cv2, torch, spaces, gradio as gr, numpy as np
7
  from pathlib import Path
@@ -13,22 +13,25 @@ from diffusers import (
13
  from insightface.app import FaceAnalysis
14
 
15
  ##############################################################################
16
- # 0. Cache dir & helpers
17
  ##############################################################################
18
  PERSIST_BASE = Path("/data")
19
- CACHE_ROOT = (PERSIST_BASE / "instantid_cache" if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
20
- else Path.home() / ".cache" / "instantid_cache")
 
 
 
21
  print("cache →", CACHE_ROOT)
22
 
23
  MODELS_DIR = CACHE_ROOT / "models"
24
- LORA_DIR = MODELS_DIR / "Lora"
25
  EMB_DIR = CACHE_ROOT / "embeddings"
26
  UPSCALE_DIR = CACHE_ROOT / "realesrgan"
27
  for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
28
  p.mkdir(parents=True, exist_ok=True)
29
 
30
-
31
  def dl(url: str, dst: Path, attempts: int = 2):
 
32
  if dst.exists():
33
  print("✓", dst.relative_to(CACHE_ROOT)); return
34
  for i in range(1, attempts + 1):
@@ -38,22 +41,25 @@ def dl(url: str, dst: Path, attempts: int = 2):
38
  raise RuntimeError(f"download failed → {url}")
39
 
40
  ##############################################################################
41
- # 1. Asset download
42
  ##############################################################################
43
  print("— asset check —")
44
 
45
- # 1-A. base ckpt
46
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
47
- dl("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT)
48
-
49
- # 1-B. IP-Adapter core + FaceID LoRA
50
- IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
51
- dl("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE)
52
 
 
53
  LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
54
- dl("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE)
 
 
 
55
 
56
- # 1-C. textual-inversion embeddings
57
  EMB_URLS = {
58
  "ng_deepnegative_v1_75t.pt": [
59
  "https://huggingface.co/datasets/gsdf/EasyNegative/resolve/main/ng_deepnegative_v1_75t.pt",
@@ -81,7 +87,7 @@ for fname, urls in EMB_URLS.items():
81
  if idx == len(urls): raise
82
  print(" ↳ fallback URL …")
83
 
84
- # 1-D. Real-ESRGAN weights 8×
85
  RRG_WEIGHTS = UPSCALE_DIR / "RealESRGAN_x8plus.pth"
86
  RRG_URLS = [
87
  "https://huggingface.co/NoCrypt/Superscale_RealESRGAN/resolve/main/RealESRGAN_x8plus.pth",
@@ -96,38 +102,48 @@ for idx, link in enumerate(RRG_URLS, 1):
96
  print(" ↳ fallback URL …")
97
 
98
  ##############################################################################
99
- # 2. Runtime init
100
  ##############################################################################
101
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
103
  print("device:", device, "| dtype:", dtype)
104
 
105
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if torch.cuda.is_available() else ["CPUExecutionProvider"]
 
 
 
 
106
  face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
107
  face_app.prepare(ctx_id=(0 if torch.cuda.is_available() else -1), det_size=(640, 640))
108
 
109
- controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
110
- pipe = StableDiffusionPipeline.from_single_file(BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2)
111
- pipe.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
 
 
 
 
 
 
 
112
  pipe.controlnet = controlnet
113
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
114
-
115
- # --- 修正ポイントここから --------------------------------------------------
116
- # 画像エンコーダは Lora/models/image_encoder/ に格納されている
117
- IMAGE_ENCODER_DIR = LORA_DIR / "models" / "image_encoder"
118
 
 
119
  pipe.load_ip_adapter(
120
- str(LORA_DIR), # ip_adapter.bin の親ディレクトリ
121
- subfolder="", # ip_adapter.bin は Lora/ 直下
122
- weight_name=IP_BIN_FILE.name, # LoRA 本体
123
- image_encoder_path=str(IMAGE_ENCODER_DIR) # 画像エンコーダの場所を明示
124
  )
125
- # --- 修正ポイントここまで --------------------------------------------------
126
 
127
- # FaceID LoRA (差分 LoRA)
128
  pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
129
  pipe.set_ip_adapter_scale(0.65)
130
 
 
131
  for emb in EMB_DIR.glob("*.*"):
132
  try:
133
  pipe.load_textual_inversion(emb, token=emb.stem)
@@ -138,7 +154,7 @@ pipe.to(device)
138
  print("pipeline ready ✔")
139
 
140
  ##############################################################################
141
- # 3. Upscaler
142
  ##############################################################################
143
  try:
144
  from basicsr.archs.rrdb_arch import RRDBNet
@@ -155,7 +171,7 @@ except Exception as e:
155
  UPSCALE_OK = False
156
 
157
  ##############################################################################
158
- # 4. Prompts & generation
159
  ##############################################################################
160
  BASE_PROMPT = (
161
  "(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n"
@@ -177,7 +193,7 @@ NEG_PROMPT = (
177
  @spaces.GPU(duration=90)
178
  def generate(
179
  face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor,
180
- progress=gr.Progress(track_tqdm=True)
181
  ):
182
  if face_np is None or face_np.size == 0:
183
  raise gr.Error("顔画像をアップロードしてください。")
@@ -204,11 +220,15 @@ def generate(
204
 
205
  if upscale:
206
  if UPSCALE_OK:
207
- up, _ = upsampler.enhance(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor)
 
 
208
  result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
209
  else:
210
- result = result.resize((int(result.width * up_factor), int(result.height * up_factor)), Image.LANCZOS)
211
-
 
 
212
  return result
213
 
214
  ##############################################################################
 
1
  # app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-friendly, persistent cache)
2
  """Persistent-cache backend for InstantID portrait generation.
3
+ * 依存モデルは /data が書込可ならそこへ、それ以外は ~/.cache に保存
4
+ * wget を使った簡易リトライ DL
5
  """
6
  import os, subprocess, cv2, torch, spaces, gradio as gr, numpy as np
7
  from pathlib import Path
 
13
  from insightface.app import FaceAnalysis
14
 
15
  ##############################################################################
16
+ # 0. キャッシュ用ディレクトリ
17
  ##############################################################################
18
  PERSIST_BASE = Path("/data")
19
+ CACHE_ROOT = (
20
+ PERSIST_BASE / "instantid_cache"
21
+ if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK)
22
+ else Path.home() / ".cache" / "instantid_cache"
23
+ )
24
  print("cache →", CACHE_ROOT)
25
 
26
  MODELS_DIR = CACHE_ROOT / "models"
27
+ LORA_DIR = MODELS_DIR / "Lora" # FaceID LoRA などを置く
28
  EMB_DIR = CACHE_ROOT / "embeddings"
29
  UPSCALE_DIR = CACHE_ROOT / "realesrgan"
30
  for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
31
  p.mkdir(parents=True, exist_ok=True)
32
 
 
33
  def dl(url: str, dst: Path, attempts: int = 2):
34
+ """wget + リトライの簡易ダウンローダ"""
35
  if dst.exists():
36
  print("✓", dst.relative_to(CACHE_ROOT)); return
37
  for i in range(1, attempts + 1):
 
41
  raise RuntimeError(f"download failed → {url}")
42
 
43
  ##############################################################################
44
+ # 1. 必要アセットのダウンロード
45
  ##############################################################################
46
  print("— asset check —")
47
 
48
+ # 1-A. ベース checkpoint
49
  BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
50
+ dl(
51
+ "https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16",
52
+ BASE_CKPT,
53
+ )
 
54
 
55
+ # 1-B. FaceID LoRA(Δのみ)
56
  LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
57
+ dl(
58
+ "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors",
59
+ LORA_FILE,
60
+ )
61
 
62
+ # 1-C. textual inversion Embeddings
63
  EMB_URLS = {
64
  "ng_deepnegative_v1_75t.pt": [
65
  "https://huggingface.co/datasets/gsdf/EasyNegative/resolve/main/ng_deepnegative_v1_75t.pt",
 
87
  if idx == len(urls): raise
88
  print(" ↳ fallback URL …")
89
 
90
+ # 1-D. Real-ESRGAN weights 8)
91
  RRG_WEIGHTS = UPSCALE_DIR / "RealESRGAN_x8plus.pth"
92
  RRG_URLS = [
93
  "https://huggingface.co/NoCrypt/Superscale_RealESRGAN/resolve/main/RealESRGAN_x8plus.pth",
 
102
  print(" ↳ fallback URL …")
103
 
104
  ##############################################################################
105
+ # 2. ランタイム初期化
106
  ##############################################################################
107
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
109
  print("device:", device, "| dtype:", dtype)
110
 
111
+ providers = (
112
+ ["CUDAExecutionProvider", "CPUExecutionProvider"]
113
+ if torch.cuda.is_available()
114
+ else ["CPUExecutionProvider"]
115
+ )
116
  face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
117
  face_app.prepare(ctx_id=(0 if torch.cuda.is_available() else -1), det_size=(640, 640))
118
 
119
+ # ControlNet + SD パイプライン
120
+ controlnet = ControlNetModel.from_pretrained(
121
+ "InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype
122
+ )
123
+ pipe = StableDiffusionPipeline.from_single_file(
124
+ BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2
125
+ )
126
+ pipe.vae = AutoencoderKL.from_pretrained(
127
+ "stabilityai/sd-vae-ft-mse", torch_dtype=dtype
128
+ ).to(device)
129
  pipe.controlnet = controlnet
130
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
131
+ pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
132
+ )
 
 
133
 
134
+ # --- ここが核心:画像エンコーダ込みで公式レポから直接ロード ------------------
135
  pipe.load_ip_adapter(
136
+ "h94/IP-Adapter", # Hugging Face Hub ID
137
+ subfolder="models", # ip-adapter-plus-face_sd15.bin が入っているフォルダ
138
+ weight_name="ip-adapter-plus-face_sd15.bin",
 
139
  )
140
+ # ---------------------------------------------------------------------------
141
 
142
+ # FaceID LoRA(差分 LoRA のみ)
143
  pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
144
  pipe.set_ip_adapter_scale(0.65)
145
 
146
+ # textual inversion 読み込み
147
  for emb in EMB_DIR.glob("*.*"):
148
  try:
149
  pipe.load_textual_inversion(emb, token=emb.stem)
 
154
  print("pipeline ready ✔")
155
 
156
  ##############################################################################
157
+ # 3. アップスケーラ
158
  ##############################################################################
159
  try:
160
  from basicsr.archs.rrdb_arch import RRDBNet
 
171
  UPSCALE_OK = False
172
 
173
  ##############################################################################
174
+ # 4. プロンプト & 生成関数
175
  ##############################################################################
176
  BASE_PROMPT = (
177
  "(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,\n"
 
193
  @spaces.GPU(duration=90)
194
  def generate(
195
  face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor,
196
+ progress=gr.Progress(track_tqdm=True),
197
  ):
198
  if face_np is None or face_np.size == 0:
199
  raise gr.Error("顔画像をアップロードしてください。")
 
220
 
221
  if upscale:
222
  if UPSCALE_OK:
223
+ up, _ = upsampler.enhance(
224
+ cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor
225
+ )
226
  result = Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))
227
  else:
228
+ result = result.resize(
229
+ (int(result.width * up_factor), int(result.height * up_factor)),
230
+ Image.LANCZOS,
231
+ )
232
  return result
233
 
234
  ##############################################################################