# seawolf2357's picture
# Update app.py
# 2608108 verified
# raw
# history blame
# 18.3 kB
# ===========================================
# IP-Composer 🌅✚🖌️ – FULL IMPROVED UI SCRIPT
# (same functionality; improved UI, theme and gallery + FileNotFoundError fix)
# ===========================================
import os, json, random, gc
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio.themes import Soft
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
from IP_Composer.perform_swap import (
compute_dataset_embeds_svd,
get_modified_images_embeds_composition,
)
from IP_Composer.generate_text_embeddings import (
load_descriptions,
generate_embeddings,
)
import spaces
# ─────────────────────────────
# 1 · Device
# ─────────────────────────────
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ─────────────────────────────
# 2 · Stable-Diffusion XL
# ─────────────────────────────
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# SDXL base pipeline loaded in float16, without the pipeline's watermarker.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    add_watermarker=False,
    torch_dtype=torch.float16,
)
# ─────────────────────────────
# 3 · IP-Adapter
# ─────────────────────────────
# The image encoder and the SDXL adapter checkpoint come from the same repo.
ip_adapter_repo = "h94/IP-Adapter"
image_encoder_repo = ip_adapter_repo
image_encoder_subfolder = "models/image_encoder"
ip_ckpt = hf_hub_download(
    ip_adapter_repo,
    subfolder="sdxl_models",
    filename="ip-adapter_sdxl_vit-h.bin",
)
ip_model = IPAdapterXL(
    pipe,
    image_encoder_repo,
    image_encoder_subfolder,
    ip_ckpt,
    device,
)
# ─────────────────────────────
# 4 · CLIP
# ─────────────────────────────
# One identifier for both the model/transforms and the matching tokenizer.
clip_model_id = "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_model_id)
clip_model.to(device)
tokenizer = open_clip.get_tokenizer(clip_model_id)
# ─────────────────────────────
# 5 · Concept maps
# ─────────────────────────────
# Concept name -> .npy file of precomputed text embeddings, read from
# ./IP_Composer/text_embeddings/ by the processing functions.
CONCEPTS_MAP = dict(
    [
        ("age", "age_descriptions.npy"),
        ("animal fur", "fur_descriptions.npy"),
        ("dogs", "dog_descriptions.npy"),
        ("emotions", "emotion_descriptions.npy"),
        ("flowers", "flower_descriptions.npy"),
        ("fruit/vegtable", "fruit_vegetable_descriptions.npy"),
        ("outfit type", "outfit_descriptions.npy"),
        ("outfit pattern (including color)", "outfit_pattern_descriptions.npy"),
        ("patterns", "pattern_descriptions.npy"),
        ("patterns (including color)", "pattern_descriptions_with_colors.npy"),
        ("vehicle", "vehicle_descriptions.npy"),
        ("daytime", "times_of_day_descriptions.npy"),
        ("pose", "person_poses_descriptions.npy"),
        ("season", "season_descriptions.npy"),
        ("material", "material_descriptions_with_gems.npy"),
    ]
)
# Concepts that get the higher default SVD rank (80); all others get 30.
_HIGH_RANK_CONCEPTS = frozenset(
    {
        "animal fur",
        "outfit pattern (including color)",
        "patterns",
        "patterns (including color)",
        "material",
    }
)
# Derived from CONCEPTS_MAP so both maps always share the same keys and order.
RANKS_MAP = {
    name: 80 if name in _HIGH_RANK_CONCEPTS else 30 for name in CONCEPTS_MAP
}
concept_options = list(CONCEPTS_MAP)
# ─────────────────────────────
# 6 · Example tuples (base_img, c1_img, …)
# ─────────────────────────────
# (base image, concept-1 image, concept-1 name, rank for concept 1)
_EXAMPLE_SPECS = (
    (
        "./IP_Composer/assets/patterns/base.jpg",
        "./IP_Composer/assets/patterns/pattern.png",
        "patterns (including color)",
        80,
    ),
    (
        "./IP_Composer/assets/flowers/base.png",
        "./IP_Composer/assets/flowers/concept.png",
        "flowers",
        30,
    ),
    (
        "./IP_Composer/assets/materials/base.png",
        "./IP_Composer/assets/materials/concept.jpg",
        "material",
        80,
    ),
)
# Row layout follows the gr.Examples inputs: base, c1 image, c1 name,
# c2 image, c2 name, c3 image, c3 name, rank1, rank2, rank3, prompt,
# scale, seed, num_inference_steps. Concepts 2/3 and the prompt are unused.
examples = [
    [base, concept, name, None, None, None, None, rank, 30, 30, None, 1.0, 0, 30]
    for base, concept, name, rank in _EXAMPLE_SPECS
]
# ----------------------------------------------------------
# 7 · Utility functions
# ----------------------------------------------------------
def generate_examples(
    base_image,
    concept_image1,
    concept_name1,
    concept_image2,
    concept_name2,
    concept_image3,
    concept_name3,
    rank1,
    rank2,
    rank3,
    prompt,
    scale,
    seed,
    num_inference_steps,
):
    """Run one ``gr.Examples`` row through the regular generation path.

    Parameters arrive in the order of the ``inputs`` list given to
    ``gr.Examples``. The uploaded-variations options are left at the defaults
    of ``process_and_display``.
    """
    return process_and_display(
        base_image=base_image,
        concept_image1=concept_image1,
        concept_name1=concept_name1,
        concept_image2=concept_image2,
        concept_name2=concept_name2,
        concept_image3=concept_image3,
        concept_name3=concept_name3,
        rank1=rank1,
        rank2=rank2,
        rank3=rank3,
        prompt=prompt,
        scale=scale,
        seed=seed,
        num_inference_steps=num_inference_steps,
    )
# Largest seed handed out: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a fresh seed in ``[0, MAX_SEED]``.

    A new seed is drawn only when ``randomize_seed`` is truthy.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
def change_rank_default(concept_name):
    """Return the default rank for ``concept_name``, or 30 if it is unknown."""
    fallback_rank = 30
    if concept_name in RANKS_MAP:
        return RANKS_MAP[concept_name]
    return fallback_rank
@spaces.GPU
def match_image_to_concept(image):
    """Guess which predefined concept best describes ``image``.

    The CLIP embedding of ``image`` is compared by cosine similarity with every
    text embedding stored for each concept in ``CONCEPTS_MAP``. A concept's
    score is the mean of its 5 highest similarities. The best-scoring concept
    is announced with ``gr.Info`` and returned.

    Args:
        image: numpy image as delivered by ``gr.Image(type="numpy")``, or None.

    Returns:
        The best-matching key of ``CONCEPTS_MAP``, or None if ``image`` is None
        or no concept file could be scored.
    """
    if image is None:
        return None
    img_pil = Image.fromarray(image).convert("RGB")
    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
    # Normalise the query once instead of once per description.
    query = np.asarray(img_embed, dtype=np.float32).reshape(-1)
    query = query / np.linalg.norm(query)

    sims = {}
    for cname, cfile in CONCEPTS_MAP.items():
        try:
            with open(f"./IP_Composer/text_embeddings/{cfile}", "rb") as f:
                embeds = np.load(f)
            if len(embeds) == 0:
                # Nothing to compare against; skip instead of scoring NaN.
                continue
            # One row per description, flattened like the original per-item loop.
            rows = np.asarray(embeds, dtype=np.float32).reshape(len(embeds), -1)
            scores = (rows / np.linalg.norm(rows, axis=1, keepdims=True)) @ query
            sims[cname] = float(np.mean(np.sort(scores)[-5:]))
        except Exception as err:
            # Best-effort: a missing or malformed file only removes this concept.
            print(cname, "error:", err)
    if sims:
        best = max(sims, key=sims.get)
        gr.Info(f"Image automatically matched to concept: {best}")
        return best
    return None
@spaces.GPU
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
    """Return the CLIP image embedding of ``pil_image`` as a NumPy array.

    ``preproc`` turns the PIL image into a single-image tensor. A leading batch
    axis is added before the tensor is moved to ``dev`` and encoded with
    ``model.encode_image`` without gradient tracking.
    """
    batch = preproc(pil_image).unsqueeze(0)
    with torch.no_grad():
        features = model.encode_image(batch.to(dev))
    return features.detach().cpu().numpy()
@spaces.GPU
def process_images(
    base_image,
    concept_image1,
    concept_name1,
    concept_image2=None,
    concept_name2=None,
    concept_image3=None,
    concept_name3=None,
    rank1=10,
    rank2=10,
    rank3=10,
    prompt=None,
    scale=1.0,
    seed=420,
    num_inference_steps=50,
    concpet_from_file_1=None,
    concpet_from_file_2=None,
    concpet_from_file_3=None,
    use_concpet_from_file_1=False,
    use_concpet_from_file_2=False,
    use_concpet_from_file_3=False,
):
    """Compose ``base_image`` with up to three concept images via IP-Adapter.

    For every concept slot N whose image is provided, the image is embedded
    with CLIP. A projection matrix is then built with
    ``compute_dataset_embeds_svd`` from that concept's text embeddings at rank
    ``rankN``. The text embeddings come from ``concpet_from_file_N`` when
    ``use_concpet_from_file_N`` is truthy and those embeddings are present.
    Otherwise they come from the ``.npy`` file that ``CONCEPTS_MAP`` assigns to
    ``concept_nameN``.

    Returns:
        The first result of ``get_modified_images_embeds_composition``.

    Raises:
        gr.Error: if ``concept_image1`` is missing, or if a provided concept image
            has neither a known concept name nor uploaded variations.
    """
    if concept_image1 is None:
        # The UI binds a single output component, so report the problem via
        # gr.Error rather than returning an (image, message) tuple.
        raise gr.Error("Please upload at least one concept image")

    slots = (
        (1, concept_image1, concept_name1, rank1, concpet_from_file_1, use_concpet_from_file_1),
        (2, concept_image2, concept_name2, rank2, concpet_from_file_2, use_concpet_from_file_2),
        (3, concept_image3, concept_name3, rank3, concpet_from_file_3, use_concpet_from_file_3),
    )

    base_pil = Image.fromarray(base_image).convert("RGB")
    base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)

    # Each slot is resolved on its own. The earlier version kept a per-slot
    # `skip` list but indexed it by position among *present* concepts, which
    # mis-routed uploaded embeddings when slot 2 was empty and slot 3 was used.
    projections_data = []
    for idx, image, name, rank, file_embeds, use_file in slots:
        if image is None:
            continue
        if use_file and file_embeds is not None:
            all_embeds = file_embeds
        elif name in CONCEPTS_MAP:
            with open(f"./IP_Composer/text_embeddings/{CONCEPTS_MAP[name]}", "rb") as f:
                all_embeds = np.load(f)
        else:
            raise gr.Error(f"Please pick a concept type for concept image {idx}")
        img_pil = Image.fromarray(image).convert("RGB")
        projections_data.append(
            {
                "embed": get_image_embeds(img_pil, clip_model, preprocess, device),
                "projection_matrix": compute_dataset_embeds_svd(all_embeds, rank),
            }
        )

    modified = get_modified_images_embeds_composition(
        base_embed,
        projections_data,
        ip_model,
        prompt=prompt,
        scale=scale,
        num_samples=1,
        seed=seed,
        num_inference_steps=num_inference_steps,
    )
    return modified[0]
@spaces.GPU
def get_text_embeddings(concept_file):
    """Embed the text variations in an uploaded file with CLIP.

    Returns ``(embeddings, True)``. The UI stores the first value in the
    slot's embeddings state and the second in the slot's "use file" flag.
    """
    descriptions = load_descriptions(concept_file)
    return (
        generate_embeddings(
            descriptions, clip_model, tokenizer, device, batch_size=100
        ),
        True,
    )
def process_and_display(
    base_image,
    concept_image1,
    concept_name1="age",
    concept_image2=None,
    concept_name2=None,
    concept_image3=None,
    concept_name3=None,
    rank1=30,
    rank2=30,
    rank3=30,
    prompt=None,
    scale=1.0,
    seed=0,
    num_inference_steps=50,
    concpet_from_file_1=None,
    concpet_from_file_2=None,
    concpet_from_file_3=None,
    use_concpet_from_file_1=False,
    use_concpet_from_file_2=False,
    use_concpet_from_file_3=False,
):
    """UI entry point: check the required images, then call ``process_images``.

    Raises:
        gr.Error: when the base image or the first concept image is missing.
    """
    required = (
        (base_image, "Please upload a base image"),
        (concept_image1, "Choose at least one concept image"),
    )
    for value, message in required:
        if value is None:
            raise gr.Error(message)

    forwarded = (
        base_image,
        concept_image1,
        concept_name1,
        concept_image2,
        concept_name2,
        concept_image3,
        concept_name3,
        rank1,
        rank2,
        rank3,
        prompt,
        scale,
        seed,
        num_inference_steps,
        concpet_from_file_1,
        concpet_from_file_2,
        concpet_from_file_3,
        use_concpet_from_file_1,
        use_concpet_from_file_2,
        use_concpet_from_file_3,
    )
    return process_images(*forwarded)
# ----------------------------------------------------------
# 8 · THEME & CSS
# ----------------------------------------------------------
# Soft base theme with a purple primary hue and the Inter Google font.
demo_theme = Soft(
    font=[gr.themes.GoogleFont("Inter")],
    primary_hue="purple",
)
# Page styling: dark gradient background, centred #header, a max-width
# container, translucent ".card" panels (applied via elem_classes="card"),
# and rounded image/video widgets with a purple hover ring on images.
css = """
body{
background:#0f0c29;
background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
}
#header{
text-align:center;
padding:24px 0 8px;
font-weight:700;
font-size:2.1rem;
color:#ffffff;
}
.gradio-container{max-width:1024px !important;margin:0 auto}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #a855f7}
"""
# ----------------------------------------------------------
# 9 · UI
# ----------------------------------------------------------
# Static preview thumbnails: [image path, caption].
example_gallery = [
    ["./IP_Composer/assets/patterns/base.jpg", "Patterns demo"],
    ["./IP_Composer/assets/flowers/base.png", "Flowers demo"],
    ["./IP_Composer/assets/materials/base.png", "Material demo"],
]


def _build_concept_column(idx):
    """Create the widgets for concept slot ``idx`` (1, 2 or 3).

    Must be called inside an open ``gr.Blocks`` layout; the column is added to
    whichever container is currently active.

    Returns:
        ``(image, dropdown, variations_file)`` components for this slot.
    """
    with gr.Column(elem_classes="card"):
        image = gr.Image(
            label=f"Concept Image {idx}" if idx == 1 else f"Concept {idx} (Optional)",
            type="numpy",
            height=400,
            width=400,
        )
        dropdown = gr.Dropdown(
            concept_options,
            label=f"Concept {idx}",
            # Only the first (required) slot starts with a concept selected.
            value=None if idx != 1 else "age",
            info="Pick concept type",
        )
        with gr.Accordion("💡 Or use a new concept 👇", open=False):
            gr.Markdown(
                "1. Upload a file with **>100** text variations<br>"
                "2. Tip: Ask an LLM to list variations."
            )
            variations_file = gr.File(
                label="Concept variations", file_types=["text"]
            )
    return image, dropdown, variations_file


with gr.Blocks(css=css, theme=demo_theme) as demo:
    gr.Markdown(
        "<div id='header'>🌅 IP-Composer&nbsp;"
        "<sup style='font-size:14px'>SDXL</sup></div>"
    )
    # Per-slot session state set by get_text_embeddings on file upload: the
    # embeddings themselves, and a flag telling process_images to use them.
    concpet_from_file_1, concpet_from_file_2, concpet_from_file_3 = (
        gr.State(),
        gr.State(),
        gr.State(),
    )
    use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3 = (
        gr.State(),
        gr.State(),
        gr.State(),
    )
    with gr.Row(equal_height=True):
        with gr.Column(elem_classes="card"):
            base_image = gr.Image(
                label="Base Image (Required)", type="numpy", height=400, width=400
            )
        # Explicit bindings: writing component names through locals() only
        # works by accident at module scope and breaks inside a function.
        concept_image1, concept_name1, concept_file_1 = _build_concept_column(1)
        concept_image2, concept_name2, concept_file_2 = _build_concept_column(2)
        concept_image3, concept_name3, concept_file_3 = _build_concept_column(3)
    with gr.Column(elem_classes="card"):
        with gr.Accordion("⚙️ Advanced options", open=False):
            prompt = gr.Textbox(
                label="Guidance Prompt (Optional)",
                placeholder="Optional text prompt to guide generation",
            )
            num_inference_steps = gr.Slider(1, 50, 30, step=1, label="Num steps")
            with gr.Row():
                scale = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Scale")
                randomize_seed = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Number(value=0, label="Seed", precision=0)
            gr.Markdown(
                "If a concept is not showing enough, **increase rank** ⬇️"
            )
            with gr.Row():
                rank1 = gr.Slider(1, 150, 30, step=1, label="Rank concept 1")
                rank2 = gr.Slider(1, 150, 30, step=1, label="Rank concept 2")
                rank3 = gr.Slider(1, 150, 30, step=1, label="Rank concept 3")
    with gr.Column(elem_classes="card"):
        output_image = gr.Image(show_label=False, height=480)
        submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")
    gr.Markdown("### 🔥 Ready-made examples")
    # The gallery label is Korean for "Click to preview".
    gr.Gallery(example_gallery, label="클릭해서 미리보기", columns=[3], height="auto")
    gr.Examples(
        examples,
        inputs=[
            base_image,
            concept_image1,
            concept_name1,
            concept_image2,
            concept_name2,
            concept_image3,
            concept_name3,
            rank1,
            rank2,
            rank3,
            prompt,
            scale,
            seed,
            num_inference_steps,
        ],
        outputs=[output_image],
        fn=generate_examples,
        cache_examples=False,
    )

    # Listeners are registered one event type at a time (all uploads, then all
    # deletes, ...). NOTE(review): Gradio may derive suffixed endpoint names
    # from registration order, so keep this order stable.
    variation_slots = (
        (concept_file_1, concpet_from_file_1, use_concpet_from_file_1),
        (concept_file_2, concpet_from_file_2, use_concpet_from_file_2),
        (concept_file_3, concpet_from_file_3, use_concpet_from_file_3),
    )
    # Upload hooks: embed the variations and switch the slot to "use file".
    for c_file, c_embeds, c_use in variation_slots:
        c_file.upload(get_text_embeddings, [c_file], [c_embeds, c_use])
    # Removing the file switches the slot back to the dropdown concept.
    for c_file, _, c_use in variation_slots:
        c_file.delete(lambda _: False, [c_file], [c_use])
    # Dropdown auto-rank
    for c_name, c_rank in (
        (concept_name1, rank1),
        (concept_name2, rank2),
        (concept_name3, rank3),
    ):
        c_name.select(change_rank_default, [c_name], [c_rank])
    # Auto-match on upload
    for c_image, c_name in (
        (concept_image1, concept_name1),
        (concept_image2, concept_name2),
        (concept_image3, concept_name3),
    ):
        c_image.upload(match_image_to_concept, [c_image], [c_name])

    # Generate chain: pick the seed first, then run generation with it.
    submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed).then(
        process_and_display,
        [
            base_image,
            concept_image1,
            concept_name1,
            concept_image2,
            concept_name2,
            concept_image3,
            concept_name3,
            rank1,
            rank2,
            rank3,
            prompt,
            scale,
            seed,
            num_inference_steps,
            concpet_from_file_1,
            concpet_from_file_2,
            concpet_from_file_3,
            use_concpet_from_file_1,
            use_concpet_from_file_2,
            use_concpet_from_file_3,
        ],
        [output_image],
    )
# ─────────────────────────────
# 10 · Launch
# ─────────────────────────────
def _main():
    """Start the Gradio server for the IP-Composer demo."""
    demo.launch()


if __name__ == "__main__":
    _main()