# Provenance (residue of the Hugging Face file page, kept as a comment so the
# module parses): author seawolf2357 · commit e146909 "Update app.py"
# (verified) · 18.5 kB
# ===========================================
# IP-Composer 🌅✚🖌️ – FULL IMPROVED UI SCRIPT
# (original functionality unchanged; UI, theme, layout and gallery enhanced)
# ===========================================
import os, json, random, gc
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio.themes import Soft # ★ NEW
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
from IP_Composer.perform_swap import (compute_dataset_embeds_svd,
get_modified_images_embeds_composition)
from IP_Composer.generate_text_embeddings import (load_descriptions,
generate_embeddings)
import spaces
# ─────────────────────────────
# 1 · Device
# ─────────────────────────────
# Use CUDA when torch can see a GPU, otherwise the CPU. Kept as a plain
# string because it is handed directly to `.to(...)`, IPAdapterXL and the
# embedding helpers below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ─────────────────────────────
# 2 · Stable-Diffusion XL
# ─────────────────────────────
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# Half-precision weights, invisible watermarker disabled. The pipeline is not
# moved to `device` here; it is passed to IPAdapterXL below together with it.
_sdxl_load_options = {"torch_dtype": torch.float16, "add_watermarker": False}
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_path, **_sdxl_load_options)
# ─────────────────────────────
# 3 · IP-Adapter
# ─────────────────────────────
image_encoder_repo = 'h94/IP-Adapter'
image_encoder_subfolder = 'models/image_encoder'
# The SDXL ViT-H adapter checkpoint is fetched from the same hub repo that
# hosts the image encoder.
ip_ckpt = hf_hub_download(
    image_encoder_repo,
    subfolder="sdxl_models",
    filename='ip-adapter_sdxl_vit-h.bin',
)
ip_model = IPAdapterXL(
    pipe,
    image_encoder_repo,
    image_encoder_subfolder,
    ip_ckpt,
    device,
)
# ─────────────────────────────
# 4 · CLIP
# ─────────────────────────────
# One hub id drives both the CLIP model/preprocessing and its tokenizer.
_CLIP_HUB_ID = 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
clip_model, _, preprocess = open_clip.create_model_and_transforms(_CLIP_HUB_ID)
clip_model.to(device)
tokenizer = open_clip.get_tokenizer(_CLIP_HUB_ID)
# ─────────────────────────────
# 5 · Concept maps
# ─────────────────────────────
# Single source of truth: (dropdown label, text-embedding file, default rank).
# Both lookup dicts are derived from it so they can never disagree.
# NOTE: the "fruit/vegtable" spelling is kept on purpose - it is a lookup key.
_CONCEPT_TABLE = (
    ("age", "age_descriptions.npy", 30),
    ("animal fur", "fur_descriptions.npy", 80),
    ("dogs", "dog_descriptions.npy", 30),
    ("emotions", "emotion_descriptions.npy", 30),
    ("flowers", "flower_descriptions.npy", 30),
    ("fruit/vegtable", "fruit_vegetable_descriptions.npy", 30),
    ("outfit type", "outfit_descriptions.npy", 30),
    ("outfit pattern (including color)", "outfit_pattern_descriptions.npy", 80),
    ("patterns", "pattern_descriptions.npy", 80),
    ("patterns (including color)", "pattern_descriptions_with_colors.npy", 80),
    ("vehicle", "vehicle_descriptions.npy", 30),
    ("daytime", "times_of_day_descriptions.npy", 30),
    ("pose", "person_poses_descriptions.npy", 30),
    ("season", "season_descriptions.npy", 30),
    ("material", "material_descriptions_with_gems.npy", 80),
)
# Label -> embedding file name under ./IP_Composer/text_embeddings/.
CONCEPTS_MAP = {label: npy_file for label, npy_file, _rank in _CONCEPT_TABLE}
# Label -> default SVD rank offered in the UI.
RANKS_MAP = {label: rank for label, _npy_file, rank in _CONCEPT_TABLE}
# Dropdown choices, in table order.
concept_options = list(CONCEPTS_MAP.keys())
# ─────────────────────────────
# 6 · Example tuples (base_img, c1_img, …)
# ─────────────────────────────
def _example_row(base_path, concept_path, concept_name, rank1):
    """Build one single-concept example row with default settings.

    Column order matches the ``gr.Examples`` inputs: base image, three
    (concept image, concept name) pairs, three ranks, prompt, scale, seed,
    number of inference steps.
    """
    return [base_path, concept_path, concept_name,
            None, None,          # concept 2: image, name (unused)
            None, None,          # concept 3: image, name (unused)
            rank1, 30, 30,       # ranks for concepts 1, 2, 3
            None, 1.0, 0, 30]    # prompt, scale, seed, steps


examples = [
    _example_row('./IP_Composer/assets/patterns/base.jpg',
                 './IP_Composer/assets/patterns/pattern.png',
                 'patterns (including color)', 80),
    _example_row('./IP_Composer/assets/flowers/base.png',
                 './IP_Composer/assets/flowers/concept.png',
                 'flowers', 30),
    _example_row('./IP_Composer/assets/materials/base.png',
                 './IP_Composer/assets/materials/concept.jpg',
                 'material', 80),
    # More rows can be appended here in the same format.
]
# ----------------------------------------------------------
# 7 · Utility functions (unchanged except docstring tweaks)
# ----------------------------------------------------------
def generate_examples(base_image,
                      concept_image1, concept_name1,
                      concept_image2, concept_name2,
                      concept_image3, concept_name3,
                      rank1, rank2, rank3,
                      prompt, scale, seed, num_inference_steps):
    """Run one ``gr.Examples`` row through the normal generation path.

    The fourteen example columns are forwarded unchanged to
    ``process_and_display``. The custom-concept-file arguments are not
    supplied, so that function's defaults apply.
    """
    return process_and_display(
        base_image=base_image,
        concept_image1=concept_image1, concept_name1=concept_name1,
        concept_image2=concept_image2, concept_name2=concept_name2,
        concept_image3=concept_image3, concept_name3=concept_name3,
        rank1=rank1, rank2=rank2, rank3=rank3,
        prompt=prompt, scale=scale, seed=seed,
        num_inference_steps=num_inference_steps,
    )
# Largest value a signed 32-bit integer can hold; upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed for the next generation.

    If *randomize_seed* is truthy, a new seed is drawn uniformly from
    ``[0, MAX_SEED]`` (inclusive). Otherwise *seed* is returned unchanged.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
def change_rank_default(concept_name):
    """Return the default rank for *concept_name*, or 30 for unknown names.

    Used to reset the rank slider automatically when the concept changes.
    """
    try:
        return RANKS_MAP[concept_name]
    except KeyError:
        return 30
def _top_k_mean_cosine(query, candidates, k=5):
    """Mean of the *k* highest cosine similarities between *query* and each candidate row.

    *query* is flattened to one vector. *candidates* is flattened per leading
    row, which assumes one embedding per row along axis 0 (TODO confirm for
    every .npy file). With fewer than *k* rows, all of them are averaged.
    """
    q = np.asarray(query, dtype=np.float64).reshape(-1)
    q = q / np.linalg.norm(q)
    mat = np.asarray(candidates, dtype=np.float64)
    mat = mat.reshape(mat.shape[0], -1)
    mat = mat / np.linalg.norm(mat, axis=1, keepdims=True)
    # One vectorised matrix-vector product instead of a Python loop that
    # re-normalised the query for every row.
    sims = mat @ q
    return float(np.mean(np.sort(sims)[::-1][:k]))


@spaces.GPU
def match_image_to_concept(image):
    """Guess which predefined concept type best describes an uploaded image.

    The image's CLIP embedding is compared (cosine similarity) against every
    embedding file in ``CONCEPTS_MAP``. These are presumably the precomputed
    text embeddings of each concept's descriptions. Each concept is scored by
    the mean of its 5 best matches, and the top-scoring concept wins.

    Args:
        image: uploaded image as a NumPy array, or ``None``.

    Returns:
        The best-matching ``CONCEPTS_MAP`` key (also announced via
        ``gr.Info``), or ``None`` when there is no image or no concept file
        could be scored.
    """
    if image is None:
        return None
    img_pil = Image.fromarray(image).convert("RGB")
    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)

    similarities = {}
    for concept_name, concept_file in CONCEPTS_MAP.items():
        embeds_path = f"./IP_Composer/text_embeddings/{concept_file}"
        try:
            with open(embeds_path, "rb") as f:
                concept_embeds = np.load(f)
            similarities[concept_name] = _top_k_mean_cosine(img_embed, concept_embeds, k=5)
        except Exception as e:
            # Best effort: a missing/malformed file only drops that concept.
            print(f"Concept {concept_name} error: {e}")

    if not similarities:
        return None
    detected = max(similarities, key=similarities.get)
    gr.Info(f"Image automatically matched to concept: {detected}")
    return detected
@spaces.GPU
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
    """Encode one PIL image with ``model.encode_image`` and return a NumPy array.

    *preproc* converts the image to a tensor. A batch axis of size 1 is then
    prepended, and the tensor is moved to *dev* and encoded with gradients
    disabled. The result is copied back to host memory; its shape is
    presumably ``(1, embed_dim)``.
    """
    batch = preproc(pil_image).unsqueeze(0)
    with torch.no_grad():
        features = model.encode_image(batch.to(dev))
    return features.detach().cpu().numpy()
def _concept_text_embeds(slot_no, concept_name, custom_embeds, use_custom):
    """Return the text-embedding matrix for one concept slot.

    Embeddings from an uploaded file win when their "use" flag is truthy and
    they are present. Otherwise the predefined ``.npy`` file for
    *concept_name* is loaded.

    Raises:
        gr.Error: if no custom embeddings apply and *concept_name* is not a
            known concept (e.g. the dropdown was left empty).
    """
    if use_custom and custom_embeds is not None:
        return custom_embeds
    if concept_name not in CONCEPTS_MAP:
        raise gr.Error(f"Please choose a concept type for concept image {slot_no}")
    embeds_path = f"./IP_Composer/text_embeddings/{CONCEPTS_MAP[concept_name]}"
    with open(embeds_path, "rb") as f:
        return np.load(f)


@spaces.GPU
def process_images(
    base_image,
    concept_image1, concept_name1,
    concept_image2=None, concept_name2=None,
    concept_image3=None, concept_name3=None,
    rank1=10, rank2=10, rank3=10,
    prompt=None, scale=1.0, seed=420, num_inference_steps=50,
    concpet_from_file_1=None, concpet_from_file_2=None, concpet_from_file_3=None,
    use_concpet_from_file_1=False, use_concpet_from_file_2=False, use_concpet_from_file_3=False
):
    """Compose concepts from up to three images into a base image.

    For every slot with a concept image, a projection matrix is built from
    that concept's text embeddings by ``compute_dataset_embeds_svd`` at the
    slot's rank. It is paired with the slot's CLIP image embedding. These
    pairs, the base image embedding and ``ip_model`` are passed to
    ``get_modified_images_embeds_composition``.

    Args:
        base_image: base picture as a NumPy array.
        concept_image1..3: concept pictures; slots 2 and 3 are optional.
        concept_name1..3: ``CONCEPTS_MAP`` keys choosing predefined embeddings.
        rank1..3: SVD rank per slot (cast to ``int``; sliders may send floats).
        prompt, scale, seed, num_inference_steps: forwarded to generation.
        concpet_from_file_1..3 / use_concpet_from_file_1..3: embeddings from
            an uploaded description file, and whether to use them. The
            misspelled names are kept for caller compatibility.

    Returns:
        The first generated image.

    Raises:
        gr.Error: if ``concept_image1`` is missing, or a provided concept
            image has no usable concept type.
    """
    if concept_image1 is None:
        # Previously returned a (None, message) tuple, which does not fit the
        # single-image output this function feeds.
        raise gr.Error("Please upload at least one concept image")

    slots = (
        (concept_image1, concept_name1, rank1, concpet_from_file_1, use_concpet_from_file_1),
        (concept_image2, concept_name2, rank2, concpet_from_file_2, use_concpet_from_file_2),
        (concept_image3, concept_name3, rank3, concpet_from_file_3, use_concpet_from_file_3),
    )
    # Resolve text embeddings for every filled slot first, so a missing concept
    # type is reported (instead of a bare KeyError) before any image is encoded.
    active = [
        (image, _concept_text_embeds(slot_no, name, custom, use_custom), int(rank))
        for slot_no, (image, name, rank, custom, use_custom) in enumerate(slots, start=1)
        if image is not None
    ]

    base_pil = Image.fromarray(base_image).convert("RGB")
    base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)

    projections_data = []
    for image, all_embeds, rank in active:
        img_pil = Image.fromarray(image).convert("RGB")
        projections_data.append({
            "embed": get_image_embeds(img_pil, clip_model, preprocess, device),
            "projection_matrix": compute_dataset_embeds_svd(all_embeds, rank),
        })

    modified_images = get_modified_images_embeds_composition(
        base_embed, projections_data, ip_model,
        prompt=prompt, scale=scale,
        num_samples=1, seed=seed, num_inference_steps=num_inference_steps,
    )
    return modified_images[0]
@spaces.GPU
def get_text_embeddings(concept_file):
    """Embed the descriptions in an uploaded concept file.

    The descriptions are read with ``load_descriptions`` and encoded with the
    CLIP model and tokenizer in batches of 100. The function returns
    ``(embeddings, True)``, so one upload event fills both the embeddings
    state and its "use this file" flag.
    """
    texts = load_descriptions(concept_file)
    text_embeds = generate_embeddings(
        texts, clip_model, tokenizer, device, batch_size=100,
    )
    return text_embeds, True
def process_and_display(
    base_image,
    concept_image1, concept_name1="age",
    concept_image2=None, concept_name2=None,
    concept_image3=None, concept_name3=None,
    rank1=30, rank2=30, rank3=30,
    prompt=None, scale=1.0, seed=0, num_inference_steps=50,
    concpet_from_file_1=None, concpet_from_file_2=None, concpet_from_file_3=None,
    use_concpet_from_file_1=False, use_concpet_from_file_2=False, use_concpet_from_file_3=False
):
    """UI entry point: check the required images, then run ``process_images``.

    Raises:
        gr.Error: when the base image, or the first concept image, is missing.
    """
    required = (
        (base_image, "Please upload a base image"),
        (concept_image1, "Choose at least one concept image"),
    )
    for value, message in required:
        if value is None:
            raise gr.Error(message)

    return process_images(
        base_image, concept_image1, concept_name1,
        concept_image2=concept_image2, concept_name2=concept_name2,
        concept_image3=concept_image3, concept_name3=concept_name3,
        rank1=rank1, rank2=rank2, rank3=rank3,
        prompt=prompt, scale=scale, seed=seed,
        num_inference_steps=num_inference_steps,
        concpet_from_file_1=concpet_from_file_1,
        concpet_from_file_2=concpet_from_file_2,
        concpet_from_file_3=concpet_from_file_3,
        use_concpet_from_file_1=use_concpet_from_file_1,
        use_concpet_from_file_2=use_concpet_from_file_2,
        use_concpet_from_file_3=use_concpet_from_file_3,
    )
# ----------------------------------------------------------
# 8 · 💄 THEME & CSS UPGRADE
# ----------------------------------------------------------
# Stylesheet passed to gr.Blocks. It sets a dark gradient page background, a
# centred #header, a 1024px container, translucent ".card" columns and
# rounded/hover-highlighted image components.
css = """
body{
background:#0f0c29;
background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
}
#header{ text-align:center;
padding:24px 0 8px;
font-weight:700;
font-size:2.1rem;
color:#ffffff;}
.gradio-container{max-width:1024px !important;margin:0 auto}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #a855f7}
"""
# Gradio "Soft" theme with a purple accent and the Inter Google font.
demo_theme = Soft(
    font=[gr.themes.GoogleFont("Inter")],
    primary_hue="purple",
)
# ----------------------------------------------------------
# 9 · 🖼️ Demo UI
# ----------------------------------------------------------
# [image path, caption] entries for the "Ready-made examples" gallery.
example_gallery = [
    [f"./IP_Composer/assets/{folder}/{filename}", caption]
    for folder, filename, caption in (
        ("patterns", "base.jpg", "Patterns demo"),
        ("flowers", "base.png", "Flowers demo"),
        ("materials", "base.png", "Material demo"),
    )
]
def _concept_card(idx):
    """Build one concept column: image, concept-type dropdown, custom file.

    Must be called inside an open gr.Blocks/gr.Row context so the components
    are placed there. Returns ``(image, dropdown, file)``.
    """
    with gr.Column(elem_classes="card"):
        image = gr.Image(
            label=f"Concept Image {idx}" if idx == 1 else f"Concept {idx} (Optional)",
            type="numpy", height=400, width=400,
        )
        name = gr.Dropdown(
            concept_options, label=f"Concept {idx}",
            value=None if idx != 1 else "age",
            info="Pick concept type",
        )
        with gr.Accordion("💡 Or use a new concept 👇", open=False):
            gr.Markdown("1. Upload a file with **>100** text variations<br>"
                        "2. Tip: Ask an LLM to list variations.")
            # `label=` is required: gr.File's first positional parameter is
            # `value`, so a bare string would be treated as a file to show.
            concept_file = gr.File(label="Concept variations",
                                   file_types=["text"])
    return image, name, concept_file


with gr.Blocks(css=css, theme=demo_theme) as demo:
    # ─── Header
    gr.Markdown("<div id='header'>🌅 IP-Composer&nbsp;"
                "<sup style='font-size:14px'>SDXL</sup></div>")

    # ─── States for custom concepts: embeddings from an uploaded file, and a
    #     flag saying whether that slot should use them.
    concpet_from_file_1 = gr.State()
    concpet_from_file_2 = gr.State()
    concpet_from_file_3 = gr.State()
    use_concpet_from_file_1 = gr.State()
    use_concpet_from_file_2 = gr.State()
    use_concpet_from_file_3 = gr.State()

    # ─── Main layout: base image card + three concept cards.
    #     Explicit assignments replace the old `locals()[...]` indirection.
    with gr.Row(equal_height=True):
        with gr.Column(elem_classes="card"):
            base_image = gr.Image(label="Base Image (Required)",
                                  type="numpy", height=400, width=400)
        concept_image1, concept_name1, concept_file_1 = _concept_card(1)
        concept_image2, concept_name2, concept_file_2 = _concept_card(2)
        concept_image3, concept_name3, concept_file_3 = _concept_card(3)

    # ─── Advanced options card (full width)
    with gr.Column(elem_classes="card"):
        with gr.Accordion("⚙️ Advanced options", open=False):
            prompt = gr.Textbox(label="Guidance Prompt (Optional)",
                                placeholder="Optional text prompt to guide generation")
            num_inference_steps = gr.Slider(1, 50, value=30, step=1,
                                            label="Num steps")
            with gr.Row():
                scale = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Scale")
                randomize_seed = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Number(value=0, label="Seed", precision=0)
            gr.Markdown("If a concept is not showing enough, **increase rank** ⬇️")
            with gr.Row():
                rank1 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 1")
                rank2 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 2")
                rank3 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 3")

    # ─── Output & Generate button
    with gr.Column(elem_classes="card"):
        output_image = gr.Image(show_label=False, height=480)
        submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")

    # ─── Ready-made Gallery
    gr.Markdown("### 🔥 Ready-made examples")
    gr.Gallery(example_gallery, label="클릭해서 미리보기",
               columns=[3], height="auto")

    # ─── Example usage (kept for quick test)
    gr.Examples(
        examples,
        inputs=[base_image, concept_image1, concept_name1,
                concept_image2, concept_name2,
                concept_image3, concept_name3,
                rank1, rank2, rank3,
                prompt, scale, seed, num_inference_steps],
        outputs=[output_image],
        fn=generate_examples,
        cache_examples=False,
    )

    # ─── Per-slot event wiring (identical for the three concept slots)
    _concept_slots = (
        (concept_image1, concept_name1, concept_file_1, rank1,
         concpet_from_file_1, use_concpet_from_file_1),
        (concept_image2, concept_name2, concept_file_2, rank2,
         concpet_from_file_2, use_concpet_from_file_2),
        (concept_image3, concept_name3, concept_file_3, rank3,
         concpet_from_file_3, use_concpet_from_file_3),
    )
    for c_image, c_name, c_file, c_rank, c_embeds, c_use in _concept_slots:
        # Uploading a description file stores its embeddings and enables them.
        c_file.upload(get_text_embeddings, [c_file], [c_embeds, c_use])
        # Removing the file falls back to the dropdown concept. `.clear` is
        # wired too because a single-file component is emptied through its
        # clear button, which `.delete` alone may not cover.
        c_file.delete(lambda x: False, [c_file], [c_use])
        c_file.clear(lambda: False, None, [c_use])
        # A manual dropdown pick resets the rank to that concept's default.
        c_name.select(change_rank_default, [c_name], [c_rank])
        # Auto-match the concept on image upload, then refresh the rank.
        # A programmatic dropdown change does not fire `.select`.
        c_image.upload(match_image_to_concept, [c_image], [c_name]).then(
            change_rank_default, [c_name], [c_rank]
        )

    # ─── Generate click chain: pick the seed first, then generate with it.
    submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed).then(
        process_and_display,
        [base_image, concept_image1, concept_name1,
         concept_image2, concept_name2,
         concept_image3, concept_name3,
         rank1, rank2, rank3,
         prompt, scale, seed, num_inference_steps,
         concpet_from_file_1, concpet_from_file_2, concpet_from_file_3,
         use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3],
        [output_image],
    )
# ─────────────────────────────
# 10 · Launch
# ─────────────────────────────
def main():
    """Start the Gradio app `demo` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()