# i0switch's picture
# Update app.py
# 45eb86f verified
# raw
# history blame
# 11.5 kB
# app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-ready, FastAPI + Gradio)
# Version: 2025-06-21
#
# ──────────────────────────────────────────────────────────────
# Main features
# • @spaces.GPU(duration=60) is applied to the public generate_core()
# • Pipelines are lazy-loaded onto the GPU at the first inference
# • Model assets are persisted under /data or ~/.cache
# • Optional Real-ESRGAN upscaling (x4 / x8)
# • Gradio UI + FastAPI REST coexist in a single process
# • No manual Uvicorn launch is needed (Spaces starts its own)
# ──────────────────────────────────────────────────────────────
import os
import io
import base64
import subprocess
import traceback
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import gradio as gr
import spaces
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from PIL import Image
from diffusers import (
StableDiffusionControlNetPipeline,
ControlNetModel,
DPMSolverMultistepScheduler,
AutoencoderKL,
)
from diffusers.loaders import AttnProcsLayers
from insightface.app import FaceAnalysis
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
# ==============================================================
# 0. Cache directories and downloader
# ==============================================================
# Prefer the persistent /data volume when it exists and is writable;
# otherwise fall back to a cache directory under the user's home.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

# One sub-directory per asset family, all created eagerly at import time.
MODELS_DIR, LORA_DIR, UPSCALE_DIR = (
    CACHE_ROOT / sub for sub in ("models", "lora", "realesrgan")
)
for _p in (MODELS_DIR, LORA_DIR, UPSCALE_DIR):
    _p.mkdir(parents=True, exist_ok=True)
def download(url: str, dst: Path, attempts: int = 2):
    """Download ``url`` to ``dst`` unless ``dst`` already exists; return ``dst``.

    ``curl`` is tried up to ``attempts`` times.  If every attempt fails, the
    basicsr ``load_file_from_url`` helper is used as a last resort.

    Args:
        url: Remote location of the file.
        dst: Local path the file should end up at.
        attempts: Number of ``curl`` attempts before falling back.

    Returns:
        ``dst``.
    """
    if dst.exists():
        return dst
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Download next to the target and rename on success, so an interrupted or
    # failed transfer never leaves a file at ``dst`` that a later call would
    # mistake for a finished download.
    partial = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        try:
            # --fail makes curl exit non-zero on HTTP errors instead of
            # saving the server's error page as if it were the model file.
            subprocess.check_call(
                ["curl", "--fail", "-L", "-o", str(partial), url]
            )
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing curl binary (FileNotFoundError).
            partial.unlink(missing_ok=True)
            print(f"[DL] Retry {i}/{attempts} failed: {url}")
        else:
            partial.replace(dst)
            return dst
    # Last resort: fall back to the basicsr downloader.
    load_file_from_url(url=url, model_dir=str(dst.parent), file_name=dst.name)
    return dst
# ==============================================================
# 1. Model URL definitions
# ==============================================================
# Remote locations of the downloadable model assets (all on huggingface.co).
BRA_V7_URL = "/".join(
    (
        "https://huggingface.co/i0switch-assets/Beautiful_Realistic_Asians_v7",
        "resolve/main",
        "beautiful_realistic_asians_v7_fp16.safetensors",
    )
)
IP_ADAPTER_BIN_URL = "/".join(
    (
        "https://huggingface.co/h94/IP-Adapter",
        "resolve/main",
        "ip-adapter-plus-face_sd15.bin",
    )
)
IP_ADAPTER_LORA_URL = "/".join(
    (
        "https://huggingface.co/h94/IP-Adapter-FaceID",
        "resolve/main",
        "ip-adapter-faceid-plusv2_sd15_lora.safetensors",
    )
)
REALESRGAN_URL = "/".join(
    (
        "https://huggingface.co/aimagelab/realesrgan",
        "resolve/main",
        "RealESRGAN_x4plus.pth",
    )
)
# ==============================================================
# 2. Global variables (lazy-loaded)
# ==============================================================
# Module-level inference objects.  All three start as None and are assigned
# by initialize_pipelines().
# Diffusers ControlNet pipeline.
pipe: Optional[StableDiffusionControlNetPipeline] = None
# InsightFace analyser.
face_analyser: Optional[FaceAnalysis] = None
# Real-ESRGAN upsampler.
upsampler: Optional[RealESRGANer] = None
# ==============================================================
# 3. Pipeline initialization
# ==============================================================
def initialize_pipelines():
    """Download the model assets and build the global inference objects.

    Assigns the module globals ``pipe``, ``face_analyser`` and ``upsampler``.
    The call is a no-op once ``pipe`` is set.  The globals are assigned only
    after every component has been built, so a failure part-way through
    leaves all of them ``None`` and the next call starts over instead of
    running with a half-initialised state.
    """
    global pipe, face_analyser, upsampler
    if pipe is not None:
        return  # already initialised
    print("[INIT] Downloading model assets โ€ฆ")
    # ---- 3-1. Base model & IP-Adapter ----
    # NOTE(review): the BRA v7 checkpoint is fetched here but nothing below
    # reads it; the pipeline is created from `pipe_base`.  Confirm intent.
    download(BRA_V7_URL, MODELS_DIR / "bra_v7.safetensors")
    ip_bin = download(IP_ADAPTER_BIN_URL, MODELS_DIR / "ip_adapter.bin")
    ip_lora = download(IP_ADAPTER_LORA_URL, LORA_DIR / "ip_adapter_faceid.lora")
    # ---- 3-2. ControlNet (InstantID) ----
    controlnet = ControlNetModel.from_pretrained(
        "InstantID/ControlNet-Mediapipe-Face",
        torch_dtype=torch.float16,
        cache_dir=str(MODELS_DIR),
    )
    # ---- 3-3. Diffusers pipeline ----
    pipe_local_files_only = {
        "controlnet": controlnet,
        "vae": AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
        ),
        "torch_dtype": torch.float16,
        "safety_checker": None,
    }
    pipe_base = "runwayml/stable-diffusion-v1-5"
    pipe_kwargs = dict(
        local_files_only=False,
        cache_dir=str(MODELS_DIR),
        load_safety_checker=False,
    )
    pipe_tmp = StableDiffusionControlNetPipeline.from_pretrained(
        pipe_base, **pipe_local_files_only, **pipe_kwargs
    )
    pipe_tmp.scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pipe_base, subfolder="scheduler", cache_dir=str(MODELS_DIR)
    )
    # LoRA / IP-Adapter
    # NOTE(review): a local file path is passed with weight_name=None;
    # confirm load_ip_adapter accepts this combination.
    pipe_tmp.load_ip_adapter(
        ip_bin,
        subfolder=None,
        weight_name=None,
    )
    # NOTE(review): confirm AttnProcsLayers exposes load_lora_weights();
    # diffusers documents that method on the pipeline object.
    ip_layers = AttnProcsLayers(pipe_tmp.unet.attn_processors)
    ip_layers.load_lora_weights(ip_lora, adapter_name="ip_faceid", safe_load=True)
    pipe_tmp.set_adapters(["ip_faceid"], adapter_weights=[0.6])
    pipe_tmp.to("cuda")
    # ---- 3-4. InsightFace ----
    analyser = FaceAnalysis(
        name="buffalo_l", root=str(MODELS_DIR), providers=["CUDAExecutionProvider"]
    )
    analyser.prepare(ctx_id=0, det_size=(640, 640))
    # ---- 3-5. Real-ESRGAN ----
    esrgan_ckpt = download(REALESRGAN_URL, UPSCALE_DIR / "realesrgan_x4plus.pth")
    # NOTE(review): no `model=` network is supplied; confirm RealESRGANer can
    # load this checkpoint without one.
    esrgan = RealESRGANer(
        scale=4,
        model_path=str(esrgan_ckpt),
        half=True,
        tile=512,
        tile_pad=10,
        pre_pad=0,
        gpu_id=0,
    )
    # Publish everything together; `pipe` goes last because it is the flag
    # the early-return check above looks at.
    face_analyser = analyser
    upsampler = esrgan
    pipe = pipe_tmp
    print("[INIT] Pipelines ready.")
# ==============================================================
# 4. Prompt settings
# ==============================================================
# Quality tags prepended to every prompt; the trailing ", " lets the subject
# text be appended directly.
BASE_PROMPT = ", ".join(
    [
        "(masterpiece:1.2)",
        "best quality",
        "ultra-realistic",
        "8k",
        "RAW photo",
        "cinematic lighting",
        "textured skin",
    ]
) + ", "
# Negative tags applied to every generation.
NEG_PROMPT = ", ".join(
    [
        "verybadimagenegative_v1.3",
        "ng_deepnegative_v1_75t",
        "(worst quality:2)",
        "(low quality:2)",
        "lowres",
        "blurry",
        "bad anatomy",
        "bad hands",
        "extra digits",
        "cropped",
        "watermark",
        "signature",
    ]
)
# ==============================================================
# 5. Core generation function (acquires the GPU)
# ==============================================================
@spaces.GPU(duration=60)
def generate_core(
    face_img: Image.Image,
    subject: str,
    add_prompt: str = "",
    add_neg: str = "",
    cfg: float = 7.5,
    ip_scale: float = 0.6,
    steps: int = 30,
    w: int = 768,
    h: int = 768,
    upscale: bool = False,
    up_factor: int = 4,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate one image from a face photo and a text description.

    Args:
        face_img: Face photo; checked for a detectable face and passed to the
            pipeline as ``image``.
        subject: Subject description appended to ``BASE_PROMPT``.
        add_prompt: Optional extra positive prompt text.
        add_neg: Optional extra negative prompt text.
        cfg: Guidance scale.
        ip_scale: Weight set on the ``ip_faceid`` adapter.
        steps: Number of inference steps.
        w: Output width passed to the pipeline.
        h: Output height passed to the pipeline.
        upscale: Run the Real-ESRGAN upsampler on the result.
        up_factor: 4 selects x4 upscaling; any other value selects x8.
        progress: Not referenced in the body; presumably picked up by Gradio
            for tqdm progress tracking.

    Returns:
        The generated (and optionally upscaled) ``PIL.Image.Image``.

    Raises:
        ValueError: If no image is given or no face is detected.
    """
    try:
        if face_img is None:
            raise ValueError("A face image is required.")
        if pipe is None:
            initialize_pipelines()
        # Uploads may be RGBA / palette / grayscale; normalise to 3-channel
        # RGB before handing the image to the detector and the pipeline.
        face_img = face_img.convert("RGB")
        face_np = np.array(face_img)
        face_info = face_analyser.get(face_np)
        if len(face_info) == 0:
            raise ValueError("้ก”ใŒๆคœๅ‡บใงใใพใ›ใ‚“ใงใ—ใŸใ€‚")
        pipe.set_adapters(["ip_faceid"], adapter_weights=[ip_scale])
        # Skip empty / None fragments so the prompts carry no dangling
        # separators and a missing optional field cannot raise TypeError.
        prompt = BASE_PROMPT + ", ".join(p for p in (subject, add_prompt) if p)
        negative = ", ".join(p for p in (NEG_PROMPT, add_neg) if p)
        result = pipe(
            prompt=prompt,
            negative_prompt=negative,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            image=face_img,
            control_image=None,
            width=int(w),
            height=int(h),
        ).images[0]
        if upscale and upsampler is not None:
            # The upsampler is built with scale=4 for a x4 checkpoint, so
            # that attribute is left alone and the requested factor is
            # passed through `outscale`.
            outscale = 4 if up_factor == 4 else 8
            upscaled, _ = upsampler.enhance(np.array(result), outscale=outscale)
            result = Image.fromarray(upscaled)
        return result
    except Exception:
        traceback.print_exc()
        raise
# ==============================================================
# 6. Gradio UI
# ==============================================================
# Gradio front-end: collects the generation inputs and wires the "Generate"
# button to generate_core().
# NOTE(review): the source this was recovered from had lost its indentation;
# which widgets sit inside each gr.Row() below is a best guess — confirm.
with gr.Blocks(title="InstantID ร— BRA v7 (ZeroGPU)") as demo:
    gr.Markdown("## InstantID ร— Beautiful Realistic Asians v7")
    # Face photo and subject description.
    with gr.Row():
        face_img = gr.Image(type="pil", label="Face ID", sources=["upload"])
        subject = gr.Textbox(
            label="่ขซๅ†™ไฝ“่ชฌๆ˜Ž๏ผˆไพ‹: '30ไปฃๆ—ฅๆœฌไบบๅฅณๆ€งใ€้ป’้ซชใ‚ปใƒŸใƒญใƒณใ‚ฐ'๏ผ‰", interactive=True
        )
    # Optional extra positive / negative prompt text.
    add_prompt = gr.Textbox(label="่ฟฝๅŠ ใƒ—ใƒญใƒณใƒ—ใƒˆ", interactive=True)
    add_neg = gr.Textbox(label="่ฟฝๅŠ ใƒใ‚ฌใƒ†ใ‚ฃใƒ–", interactive=True)
    # Guidance scale and adapter weight.
    with gr.Row():
        cfg = gr.Slider(1, 20, value=7.5, step=0.5, label="CFG Scale")
        ip_scale = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IP-Adapter Weight")
    # Step count and output size.
    with gr.Row():
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        w = gr.Slider(512, 1024, value=768, step=64, label="Width")
        h = gr.Slider(512, 1024, value=768, step=64, label="Height")
    # Upscaling options.
    with gr.Row():
        upscale = gr.Checkbox(label="Real-ESRGAN Upscale", value=False)
        up_factor = gr.Radio([4, 8], value=4, label="Upscale Factor")
    run_btn = gr.Button("Generate")
    output_img = gr.Image(type="pil", label="Result")
    # The inputs list follows generate_core()'s positional parameter order.
    run_btn.click(
        fn=generate_core,
        inputs=[
            face_img,
            subject,
            add_prompt,
            add_neg,
            cfg,
            ip_scale,
            steps,
            w,
            h,
            upscale,
            up_factor,
        ],
        outputs=output_img,
        show_progress=True,
    )
# ==============================================================
# 7. FastAPI endpoints
# ==============================================================
app = FastAPI()


@app.post("/api/generate")
async def api_generate(
    subject: str = Form(...),
    cfg: float = Form(7.5),
    steps: int = Form(30),
    ip_scale: float = Form(0.6),
    w: int = Form(768),
    h: int = Form(768),
    file: UploadFile = File(...),
):
    """REST wrapper around ``generate_core`` (upscaling disabled).

    Args:
        subject: Subject description for the prompt.
        cfg: Guidance scale.
        steps: Number of inference steps.
        ip_scale: Adapter weight.
        w: Output width.
        h: Output height.
        file: Uploaded face image.

    Returns:
        ``{"image": "data:image/png;base64,..."}`` with the PNG-encoded result.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            422 when generation rejects the input with ``ValueError`` (for
            example no face detected), 500 for any other failure.
    """
    img_bytes = await file.read()
    try:
        pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, ValueError) as e:
        # Bad upload is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail=f"Invalid image upload: {e}"
        ) from e
    try:
        # NOTE(review): this call runs synchronously inside the coroutine, so
        # the event loop is presumably blocked while the image is generated;
        # consider off-loading it to a worker thread.
        res = generate_core(
            face_img=pil,
            subject=subject,
            add_prompt="",
            add_neg="",
            cfg=cfg,
            ip_scale=ip_scale,
            steps=steps,
            w=w,
            h=h,
            upscale=False,
            up_factor=4,
        )
        buf = io.BytesIO()
        res.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return {"image": f"data:image/png;base64,{b64}"}
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
# ==============================================================
# 8. Launch
# ==============================================================
# Spaces starts Uvicorn on its own, so no manual Uvicorn launch is needed.
# `Blocks.queue()` no longer accepts `concurrency_count` in Gradio 4 (the
# `sources=[...]` argument used for the face image is Gradio 4 API), so the
# same limit of 2 is expressed through `default_concurrency_limit`.
# NOTE(review): nothing visible in this file attaches the FastAPI `app` to the
# server started here — confirm that /api/generate is actually reachable.
demo.queue(default_concurrency_limit=2).launch(share=False)