Spaces:
Running
on
Zero
Running
on
Zero
# app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-ready, FastAPI + Gradio)
# Revision of 2025-06-21
#
# ───────────────────────────────────────────────────────────────
# Key features
#   • @spaces.GPU(duration=60) is attached to the public generate_core() function
#   • Pipelines are lazy-loaded onto the GPU at the first inference
#   • Model assets are persisted under /data or ~/.cache
#   • Optional Real-ESRGAN upscaling (x4 / x8)
#   • Gradio UI and FastAPI REST coexist in a single process
#   • No manual Uvicorn start-up is needed (Spaces launches it automatically)
# ───────────────────────────────────────────────────────────────
import base64
import io
import os
import subprocess
import traceback
from pathlib import Path
from typing import Optional

# NOTE(review): the ZeroGPU documentation recommends importing `spaces` before
# torch / CUDA-initialising libraries, hence its position here — confirm.
import spaces
import gradio as gr
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
)
from diffusers.loaders import AttnProcsLayers
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from insightface.app import FaceAnalysis
from PIL import Image
from realesrgan import RealESRGANer
# ==============================================================
# 0. Cache directories and downloads
# ==============================================================
# Prefer the persistent /data volume when it exists and is writable;
# otherwise fall back to a per-user cache directory under the home folder.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = CACHE_ROOT / "lora"
UPSCALE_DIR = CACHE_ROOT / "realesrgan"

# Create every cache sub-directory up front so later downloads can write into them.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
LORA_DIR.mkdir(parents=True, exist_ok=True)
UPSCALE_DIR.mkdir(parents=True, exist_ok=True)
def download(url: str, dst: Path, attempts: int = 2):
    """Download *url* to *dst* with simple retries (curl first, basicsr fallback).

    Args:
        url: Source URL.
        dst: Destination file path. If it already exists it is returned
            untouched and nothing is downloaded.
        attempts: Number of curl attempts before falling back to basicsr's
            ``load_file_from_url``.

    Returns:
        ``dst``.

    Raises:
        Whatever ``load_file_from_url`` raises if the fallback also fails.
    """
    if dst.exists():
        return dst

    dst.parent.mkdir(parents=True, exist_ok=True)
    # Download to a sibling temp file and rename on success, so an interrupted
    # transfer can never be mistaken for a complete file by the check above.
    partial = dst.with_name(dst.name + ".part")

    for i in range(1, attempts + 1):
        try:
            # --fail makes curl exit non-zero on HTTP errors (4xx/5xx) instead
            # of saving the error page as if it were the requested file.
            subprocess.check_call(
                ["curl", "--fail", "-L", "-o", str(partial), url]
            )
            partial.replace(dst)
            return dst
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing curl binary; both cases fall through to
            # the basicsr downloader once the attempts are exhausted.
            print(f"[DL] Retry {i}/{attempts} failed: {url}")
            if partial.exists():
                partial.unlink()

    # Last resort: fall back to the basicsr downloader.
    load_file_from_url(url=url, model_dir=str(dst.parent), file_name=dst.name)
    return dst
# ==============================================================
# 1. Model URL definitions
# ==============================================================
_HF_HOST = "https://huggingface.co"

# Base checkpoint (Beautiful Realistic Asians v7, fp16 safetensors).
BRA_V7_URL = (
    f"{_HF_HOST}/i0switch-assets/Beautiful_Realistic_Asians_v7"
    "/resolve/main/beautiful_realistic_asians_v7_fp16.safetensors"
)

# IP-Adapter weights.
IP_ADAPTER_BIN_URL = (
    f"{_HF_HOST}/h94/IP-Adapter/resolve/main/ip-adapter-plus-face_sd15.bin"
)

# IP-Adapter FaceID LoRA weights.
IP_ADAPTER_LORA_URL = (
    f"{_HF_HOST}/h94/IP-Adapter-FaceID/resolve/main"
    "/ip-adapter-faceid-plusv2_sd15_lora.safetensors"
)

# Real-ESRGAN x4 upscaler weights.
REALESRGAN_URL = (
    f"{_HF_HOST}/aimagelab/realesrgan/resolve/main/RealESRGAN_x4plus.pth"
)
# ==============================================================
# 2. Global variables (lazy-loaded)
# ==============================================================
# All three start as None and are filled in by initialize_pipelines().
upsampler: Optional[RealESRGANer] = None
face_analyser: Optional[FaceAnalysis] = None
pipe: Optional[StableDiffusionControlNetPipeline] = None
# ==============================================================
# 3. Pipeline initialisation
# ==============================================================
def initialize_pipelines():
    """Build the diffusion pipeline, face analyser and upscaler exactly once.

    Populates the module-level ``pipe``, ``face_analyser`` and ``upsampler``.
    The globals are only assigned after every component has been created
    successfully, so a failure part-way through leaves all three as ``None``
    and the next call retries from scratch instead of returning early with a
    half-initialised state.

    Returns:
        None. Calling it again after a successful run is a no-op.
    """
    global pipe, face_analyser, upsampler
    if pipe is not None:
        return  # already initialised

    print("[INIT] Downloading model assets …")

    # ---- 3-1. Base model & LoRA assets ----
    bra_ckpt = download(BRA_V7_URL, MODELS_DIR / "bra_v7.safetensors")
    # Keep the ".safetensors" suffix on the cached file: diffusers selects the
    # safetensors loader from the weight file name.
    lora_name = "ip_adapter_faceid.safetensors"
    download(IP_ADAPTER_LORA_URL, LORA_DIR / lora_name)

    # ---- 3-2. ControlNet ----
    # NOTE(review): confirm this repository id exists and targets SD 1.5.
    controlnet = ControlNetModel.from_pretrained(
        "InstantID/ControlNet-Mediapipe-Face",
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
    )
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse",
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
    )

    # ---- 3-3. Diffusers pipeline ----
    # Load the downloaded BRA v7 checkpoint itself; previously it was fetched
    # but the pipeline was built from the stock SD 1.5 repository instead.
    pipe_tmp = StableDiffusionControlNetPipeline.from_single_file(
        str(bra_ckpt),
        controlnet=controlnet,
        vae=vae,
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
        load_safety_checker=False,
    )
    # Swap the sampler while keeping the checkpoint's scheduler configuration.
    pipe_tmp.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe_tmp.scheduler.config
    )

    # IP-Adapter: load from the hub repository layout so that the matching
    # image encoder is resolved alongside the adapter weights.
    # NOTE(review): the file lives under "models/" in h94/IP-Adapter — confirm.
    pipe_tmp.load_ip_adapter(
        "h94/IP-Adapter",
        subfolder="models",
        weight_name="ip-adapter-plus-face_sd15.bin",
        cache_dir=str(MODELS_DIR),
    )
    # LoRA: loaded through the pipeline API under the adapter name that
    # set_adapters() refers to.
    pipe_tmp.load_lora_weights(
        str(LORA_DIR), weight_name=lora_name, adapter_name="ip_faceid"
    )
    pipe_tmp.set_adapters(["ip_faceid"], adapter_weights=[0.6])
    pipe_tmp.to("cuda")

    # ---- 3-4. InsightFace ----
    analyser = FaceAnalysis(
        name="buffalo_l", root=str(MODELS_DIR), providers=["CUDAExecutionProvider"]
    )
    analyser.prepare(ctx_id=0, det_size=(640, 640))

    # ---- 3-5. Real-ESRGAN ----
    esrgan_ckpt = download(REALESRGAN_URL, UPSCALE_DIR / "realesrgan_x4plus.pth")
    # RealESRGANer loads the checkpoint into the network it is given, so the
    # RRDBNet matching RealESRGAN_x4plus must be constructed here.
    esrgan_net = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    esrgan = RealESRGANer(
        scale=4,
        model_path=str(esrgan_ckpt),
        model=esrgan_net,
        half=True,
        tile=512,
        tile_pad=10,
        pre_pad=0,
        gpu_id=0,
    )

    # Publish everything together; `pipe` goes last because it is the flag the
    # early-return check above looks at.
    face_analyser = analyser
    upsampler = esrgan
    pipe = pipe_tmp
    print("[INIT] Pipelines ready.")
# ==============================================================
# 4. Prompt settings
# ==============================================================
_QUALITY_TAGS = (
    "(masterpiece:1.2)",
    "best quality",
    "ultra-realistic",
    "8k",
    "RAW photo",
    "cinematic lighting",
    "textured skin",
)
# Ends with ", " so the user's subject text can be appended directly.
BASE_PROMPT = ", ".join(_QUALITY_TAGS) + ", "

_NEGATIVE_TAGS = (
    "verybadimagenegative_v1.3",
    "ng_deepnegative_v1_75t",
    "(worst quality:2)",
    "(low quality:2)",
    "lowres",
    "blurry",
    "bad anatomy",
    "bad hands",
    "extra digits",
    "cropped",
    "watermark",
    "signature",
)
NEG_PROMPT = ", ".join(_NEGATIVE_TAGS)
# ==============================================================
# 5. Core generation function (acquires the GPU)
# ==============================================================
@spaces.GPU(duration=60)
def generate_core(
    face_img: Image.Image,
    subject: str,
    add_prompt: str = "",
    add_neg: str = "",
    cfg: float = 7.5,
    ip_scale: float = 0.6,
    steps: int = 30,
    w: int = 768,
    h: int = 768,
    upscale: bool = False,
    up_factor: int = 4,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate one image conditioned on a reference face.

    Args:
        face_img: Reference face image.
        subject: Subject description appended to ``BASE_PROMPT``.
        add_prompt: Optional extra positive prompt text.
        add_neg: Optional extra negative prompt text appended to ``NEG_PROMPT``.
        cfg: Classifier-free guidance scale.
        ip_scale: Weight applied to the ``ip_faceid`` adapter.
        steps: Number of inference steps.
        w: Output width in pixels.
        h: Output height in pixels.
        upscale: When true, run Real-ESRGAN on the result.
        up_factor: Final upscale factor; 4 selects x4, any other value x8.
        progress: Gradio progress tracker (tqdm-tracking by default).

    Returns:
        The generated (and optionally upscaled) ``PIL.Image.Image``.

    Raises:
        ValueError: If no image was supplied or no face is detected in it.
    """
    try:
        if face_img is None:
            # Gradio passes None when nothing was uploaded.
            raise ValueError("顔が検出できませんでした。")
        if pipe is None:
            initialize_pipelines()

        face_rgb = face_img.convert("RGB")
        # InsightFace follows the OpenCV convention (BGR channel order).
        face_bgr = np.ascontiguousarray(np.array(face_rgb)[:, :, ::-1])
        if not face_analyser.get(face_bgr):
            raise ValueError("顔が検出できませんでした。")

        pipe.set_adapters(["ip_faceid"], adapter_weights=[ip_scale])

        # Join only the non-empty parts so empty optional fields do not leave
        # dangling ", " separators in the prompts.
        extras = [s.strip() for s in (subject, add_prompt) if s and s.strip()]
        prompt = BASE_PROMPT + ", ".join(extras)
        negative = ", ".join(
            s.strip() for s in (NEG_PROMPT, add_neg) if s and s.strip()
        )

        # NOTE(review): the raw face photo is used as the ControlNet
        # conditioning image — confirm that is what this ControlNet expects.
        call_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            image=face_rgb,
            width=int(w),
            height=int(h),
        )
        # Pass the IP-Adapter reference only when an image encoder is loaded;
        # a pipeline with an IP-Adapter needs it, one without cannot use it.
        if getattr(pipe, "image_encoder", None) is not None:
            pipe.set_ip_adapter_scale(float(ip_scale))
            call_kwargs["ip_adapter_image"] = face_rgb
        result = pipe(**call_kwargs).images[0]

        if upscale and upsampler is not None:
            outscale = 4 if up_factor == 4 else 8
            # enhance() works on BGR arrays; the network itself is x4, so the
            # requested factor goes in via `outscale` rather than by mutating
            # the shared upsampler's native `scale`.
            bgr = np.ascontiguousarray(np.array(result.convert("RGB"))[:, :, ::-1])
            enhanced, _ = upsampler.enhance(bgr, outscale=outscale)
            result = Image.fromarray(np.ascontiguousarray(enhanced[:, :, ::-1]))
        return result
    except Exception:
        traceback.print_exc()
        raise
# ==============================================================
# 6. Gradio UI
# ==============================================================
# NOTE(review): the source's indentation was lost, so the nesting of widgets
# inside each gr.Row() below is a best-effort reconstruction — confirm layout.
with gr.Blocks(title="InstantID × BRA v7 (ZeroGPU)") as demo:
    gr.Markdown("## InstantID × Beautiful Realistic Asians v7")

    # Reference face and prompt text inputs.
    with gr.Row():
        face_img = gr.Image(type="pil", label="Face ID", sources=["upload"])
        subject = gr.Textbox(
            label="被写体説明（例: '30代日本人女性、黒髪セミロング'）",
            interactive=True,
        )
        add_prompt = gr.Textbox(label="追加プロンプト", interactive=True)
        add_neg = gr.Textbox(label="追加ネガティブ", interactive=True)

    # Guidance and adapter strength.
    with gr.Row():
        cfg = gr.Slider(1, 20, value=7.5, step=0.5, label="CFG Scale")
        ip_scale = gr.Slider(
            0.1, 1.0, value=0.6, step=0.05, label="IP-Adapter Weight"
        )

    # Sampling steps and output resolution.
    with gr.Row():
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        w = gr.Slider(512, 1024, value=768, step=64, label="Width")
        h = gr.Slider(512, 1024, value=768, step=64, label="Height")

    # Optional Real-ESRGAN upscaling.
    with gr.Row():
        upscale = gr.Checkbox(label="Real-ESRGAN Upscale", value=False)
        up_factor = gr.Radio([4, 8], value=4, label="Upscale Factor")

    run_btn = gr.Button("Generate")
    output_img = gr.Image(type="pil", label="Result")

    # Input order must match generate_core's positional parameters.
    run_btn.click(
        fn=generate_core,
        inputs=[
            face_img,
            subject,
            add_prompt,
            add_neg,
            cfg,
            ip_scale,
            steps,
            w,
            h,
            upscale,
            up_factor,
        ],
        outputs=output_img,
        show_progress=True,
    )
# ==============================================================
# 7. FastAPI endpoint
# ==============================================================
app = FastAPI()


# NOTE(review): the handler was not registered on `app` in the source; the
# route path below is inferred from the function name — confirm with clients.
@app.post("/api/generate")
async def api_generate(
    subject: str = Form(...),
    cfg: float = Form(7.5),
    steps: int = Form(30),
    ip_scale: float = Form(0.6),
    w: int = Form(768),
    h: int = Form(768),
    file: UploadFile = File(...),
):
    """REST wrapper around ``generate_core`` for multipart form uploads.

    Args:
        subject: Subject description for the prompt.
        cfg: Guidance scale.
        steps: Number of inference steps.
        ip_scale: Adapter weight.
        w: Output width in pixels.
        h: Output height in pixels.
        file: Uploaded reference face image.

    Returns:
        ``{"image": "data:image/png;base64,..."}`` holding the PNG result.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 when generation fails.
    """
    try:
        img_bytes = await file.read()
        pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, ValueError) as e:
        # PIL raises an OSError subclass for undecodable data: a client error.
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=str(e))

    try:
        # generate_core is synchronous and long-running; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        res = await run_in_threadpool(
            generate_core,
            face_img=pil,
            subject=subject,
            add_prompt="",
            add_neg="",
            cfg=cfg,
            ip_scale=ip_scale,
            steps=steps,
            w=w,
            h=h,
            upscale=False,
            up_factor=4,
        )
        buf = io.BytesIO()
        res.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return {"image": f"data:image/png;base64,{b64}"}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
# ==============================================================
# 8. Launch
# ==============================================================
# Spaces starts the server automatically, so no manual Uvicorn start-up is needed.
# `queue(concurrency_count=...)` was removed in Gradio 4 (the version whose
# `gr.Image(sources=...)` API this file uses); `default_concurrency_limit`
# is its replacement for capping concurrent runs per event.
# NOTE(review): demo.launch() serves only the Gradio app; the FastAPI `app`
# object defined in this file is not mounted by this call — confirm how the
# REST endpoint is meant to be exposed.
demo.queue(default_concurrency_limit=2).launch(share=False)