import os, json, random, gc
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio.themes import Soft
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
from IP_Composer.perform_swap import (
compute_dataset_embeds_svd,
get_modified_images_embeds_composition,
)
from IP_Composer.generate_text_embeddings import (
load_descriptions,
generate_embeddings,
)
import spaces
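# `spaces.GPU` marks decorated functions for GPU allocation when running on Hugging Face ZeroGPU Spaces.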
# ─────────────────────────────
# 1 · Device
# ─────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
# ─────────────────────────────
# 2 · Stable-Diffusion XL
# ─────────────────────────────
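# Load the SDXL base pipeline in fp16 with the invisible watermarker disabled;
# IPAdapterXL below is expected to move the pipeline onto `device` when it wraps it.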
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
base_model_path,
torch_dtype=torch.float16,
add_watermarker=False,
)
# ─────────────────────────────
# 3 · IP-Adapter
# ─────────────────────────────
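# Download the SDXL IP-Adapter (ViT-H) checkpoint and wrap the pipeline so that
# CLIP image embeddings can condition generation.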
image_encoder_repo = "h94/IP-Adapter"
image_encoder_subfolder = "models/image_encoder"
ip_ckpt = hf_hub_download(
"h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin"
)
ip_model = IPAdapterXL(
pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device
)
# ─────────────────────────────
# 4 · CLIP
# ─────────────────────────────
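# OpenCLIP ViT-H/14 trained on LAION-2B: used to embed uploaded images and the
# text variation files; it matches the ViT-H image encoder the IP-Adapter
# checkpoint was trained with.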
clip_model, _, preprocess = open_clip.create_model_and_transforms(
"hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)
clip_model.to(device)
tokenizer = open_clip.get_tokenizer(
"hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)
# ─────────────────────────────
# 5 · Concept maps
# ─────────────────────────────
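# Each concept name maps to a .npy file of precomputed CLIP text embeddings,
# one embedding per textual variation of that concept.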
CONCEPTS_MAP = {
"age": "age_descriptions.npy",
"animal fur": "fur_descriptions.npy",
"dogs": "dog_descriptions.npy",
"emotions": "emotion_descriptions.npy",
"flowers": "flower_descriptions.npy",
"fruit/vegtable": "fruit_vegetable_descriptions.npy",
"outfit type": "outfit_descriptions.npy",
"outfit pattern (including color)": "outfit_pattern_descriptions.npy",
"patterns": "pattern_descriptions.npy",
"patterns (including color)": "pattern_descriptions_with_colors.npy",
"vehicle": "vehicle_descriptions.npy",
"daytime": "times_of_day_descriptions.npy",
"pose": "person_poses_descriptions.npy",
"season": "season_descriptions.npy",
"material": "material_descriptions_with_gems.npy",
}
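# Default SVD rank per concept: the dimensionality of the subspace fitted to the
# concept's text embeddings. Broader concepts (patterns, fur, materials) default
# to a higher rank.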
RANKS_MAP = {
"age": 30,
"animal fur": 80,
"dogs": 30,
"emotions": 30,
"flowers": 30,
"fruit/vegtable": 30,
"outfit type": 30,
"outfit pattern (including color)": 80,
"patterns": 80,
"patterns (including color)": 80,
"vehicle": 30,
"daytime": 30,
"pose": 30,
"season": 30,
"material": 80,
}
concept_options = list(CONCEPTS_MAP.keys())
# ─────────────────────────────
# 6 · Example tuples (base_img, c1_img, …)
# ─────────────────────────────
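# Each row follows the gr.Examples input order defined below:
# base image, concept 1 image/name, concept 2 image/name, concept 3 image/name,
# rank1-3, prompt, scale, seed, num_inference_steps.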
examples = [
[
"./IP_Composer/assets/patterns/base.jpg",
"./IP_Composer/assets/patterns/pattern.png",
"patterns (including color)",
None,
None,
None,
None,
80,
30,
30,
None,
1.0,
0,
30,
],
[
"./IP_Composer/assets/flowers/base.png",
"./IP_Composer/assets/flowers/concept.png",
"flowers",
None,
None,
None,
None,
30,
30,
30,
None,
1.0,
0,
30,
],
[
"./IP_Composer/assets/materials/base.png",
"./IP_Composer/assets/materials/concept.jpg",
"material",
None,
None,
None,
None,
80,
30,
30,
None,
1.0,
0,
30,
],
]
# ----------------------------------------------------------
# 7 · Utility functions
# ----------------------------------------------------------
def generate_examples(
base_image,
concept_image1,
concept_name1,
concept_image2,
concept_name2,
concept_image3,
concept_name3,
rank1,
rank2,
rank3,
prompt,
scale,
seed,
num_inference_steps,
):
return process_and_display(
base_image,
concept_image1,
concept_name1,
concept_image2,
concept_name2,
concept_image3,
concept_name3,
rank1,
rank2,
rank3,
prompt,
scale,
seed,
num_inference_steps,
)
MAX_SEED = np.iinfo(np.int32).max
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
return random.randint(0, MAX_SEED) if randomize_seed else seed
def change_rank_default(concept_name):
return RANKS_MAP.get(concept_name, 30)
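# Suggest the closest concept for an uploaded image: cosine similarity between the
# image's CLIP embedding and every concept's text-embedding bank, averaged over
# the top-5 scoring descriptions per concept.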
@spaces.GPU
def match_image_to_concept(image):
if image is None:
return None
img_pil = Image.fromarray(image).convert("RGB")
img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
sims = {}
for cname, cfile in CONCEPTS_MAP.items():
try:
with open(f"./IP_Composer/text_embeddings/{cfile}", "rb") as f:
embeds = np.load(f)
scores = []
for e in embeds:
s = np.dot(
img_embed.flatten() / np.linalg.norm(img_embed),
e.flatten() / np.linalg.norm(e),
)
scores.append(s)
scores.sort(reverse=True)
sims[cname] = np.mean(scores[:5])
except Exception as e:
print(cname, "error:", e)
if sims:
best = max(sims, key=sims.get)
gr.Info(f"Image automatically matched to concept: {best}")
return best
return None
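# Encode a PIL image with OpenCLIP and return its embedding as a NumPy array.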
@spaces.GPU
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
image = preproc(pil_image)[np.newaxis, :, :, :]
with torch.no_grad():
embeds = model.encode_image(image.to(dev))
return embeds.cpu().detach().numpy()
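# Core composition: embed the base and concept images with CLIP, build a
# rank-limited projection matrix per concept from its text-embedding bank via SVD,
# then compose the base embedding with the concept embeddings in those subspaces
# (get_modified_images_embeds_composition) and decode with the IP-Adapter.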
@spaces.GPU
def process_images(
base_image,
concept_image1,
concept_name1,
concept_image2=None,
concept_name2=None,
concept_image3=None,
concept_name3=None,
rank1=10,
rank2=10,
rank3=10,
prompt=None,
scale=1.0,
seed=420,
num_inference_steps=50,
concpet_from_file_1=None,
concpet_from_file_2=None,
concpet_from_file_3=None,
use_concpet_from_file_1=False,
use_concpet_from_file_2=False,
use_concpet_from_file_3=False,
):
base_pil = Image.fromarray(base_image).convert("RGB")
base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)
concept_images, concept_descs, ranks = [], [], []
skip = [False, False, False]
# concept 1
if concept_image1 is None:
        raise gr.Error("Please upload at least one concept image")
concept_images.append(concept_image1)
if use_concpet_from_file_1 and concpet_from_file_1 is not None:
concept_descs.append(concpet_from_file_1)
skip[0] = True
else:
concept_descs.append(CONCEPTS_MAP[concept_name1])
ranks.append(rank1)
# concept 2
if concept_image2 is not None:
concept_images.append(concept_image2)
if use_concpet_from_file_2 and concpet_from_file_2 is not None:
concept_descs.append(concpet_from_file_2)
skip[1] = True
else:
concept_descs.append(CONCEPTS_MAP[concept_name2])
ranks.append(rank2)
# concept 3
if concept_image3 is not None:
concept_images.append(concept_image3)
if use_concpet_from_file_3 and concpet_from_file_3 is not None:
concept_descs.append(concpet_from_file_3)
skip[2] = True
else:
concept_descs.append(CONCEPTS_MAP[concept_name3])
ranks.append(rank3)
concept_embeds, proj_mats = [], []
for i, concept in enumerate(concept_descs):
img_pil = Image.fromarray(concept_images[i]).convert("RGB")
concept_embeds.append(get_image_embeds(img_pil, clip_model, preprocess, device))
if skip[i]:
all_embeds = concept
else:
with open(f"./IP_Composer/text_embeddings/{concept}", "rb") as f:
all_embeds = np.load(f)
proj_mats.append(compute_dataset_embeds_svd(all_embeds, ranks[i]))
projections_data = [
{"embed": e, "projection_matrix": p}
for e, p in zip(concept_embeds, proj_mats)
]
modified = get_modified_images_embeds_composition(
base_embed,
projections_data,
ip_model,
prompt=prompt,
scale=scale,
num_samples=1,
seed=seed,
num_inference_steps=num_inference_steps,
)
return modified[0]
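# Embed a user-uploaded file of concept descriptions with CLIP; the second return
# value flips the corresponding use_concpet_from_file_* state to True.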
@spaces.GPU
def get_text_embeddings(concept_file):
descs = load_descriptions(concept_file)
embeds = generate_embeddings(descs, clip_model, tokenizer, device, batch_size=100)
return embeds, True
def process_and_display(
base_image,
concept_image1,
concept_name1="age",
concept_image2=None,
concept_name2=None,
concept_image3=None,
concept_name3=None,
rank1=30,
rank2=30,
rank3=30,
prompt=None,
scale=1.0,
seed=0,
num_inference_steps=50,
concpet_from_file_1=None,
concpet_from_file_2=None,
concpet_from_file_3=None,
use_concpet_from_file_1=False,
use_concpet_from_file_2=False,
use_concpet_from_file_3=False,
):
if base_image is None:
raise gr.Error("Please upload a base image")
if concept_image1 is None:
raise gr.Error("Choose at least one concept image")
return process_images(
base_image,
concept_image1,
concept_name1,
concept_image2,
concept_name2,
concept_image3,
concept_name3,
rank1,
rank2,
rank3,
prompt,
scale,
seed,
num_inference_steps,
concpet_from_file_1,
concpet_from_file_2,
concpet_from_file_3,
use_concpet_from_file_1,
use_concpet_from_file_2,
use_concpet_from_file_3,
)
# ----------------------------------------------------------
# 8 · THEME & CSS
# ----------------------------------------------------------
demo_theme = Soft(primary_hue="purple", font=[gr.themes.GoogleFont("Inter")])
css = """
body{
background:#0f0c29;
background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
}
#header{
text-align:center;
padding:24px 0 8px;
font-weight:700;
font-size:2.1rem;
color:#ffffff;
}
.gradio-container{max-width:1024px !important;margin:0 auto}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #a855f7}
"""
# ----------------------------------------------------------
# 9 · UI
# ----------------------------------------------------------
example_gallery = [
["./IP_Composer/assets/patterns/base.jpg", "Patterns demo"],
["./IP_Composer/assets/flowers/base.png", "Flowers demo"],
["./IP_Composer/assets/materials/base.png", "Material demo"],
]
with gr.Blocks(css=css, theme=demo_theme) as demo:
gr.Markdown(
"<div id='header'>🌅 IP-Composer&nbsp;"
"<sup style='font-size:14px'>SDXL</sup></div>"
)
concpet_from_file_1, concpet_from_file_2, concpet_from_file_3 = (
gr.State(),
gr.State(),
gr.State(),
)
use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3 = (
gr.State(),
gr.State(),
gr.State(),
)
with gr.Row(equal_height=True):
with gr.Column(elem_classes="card"):
base_image = gr.Image(
label="Base Image (Required)", type="numpy", height=400, width=400
)
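        # The loop below assigns widgets via locals(); this works only because the
        # Blocks context executes at module scope, where locals() is globals().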
for idx in (1, 2, 3):
with gr.Column(elem_classes="card"):
locals()[f"concept_image{idx}"] = gr.Image(
label=f"Concept Image {idx}"
if idx == 1
else f"Concept {idx} (Optional)",
type="numpy",
height=400,
width=400,
)
locals()[f"concept_name{idx}"] = gr.Dropdown(
concept_options,
label=f"Concept {idx}",
value=None if idx != 1 else "age",
info="Pick concept type",
)
with gr.Accordion("💡 Or use a new concept 👇", open=False):
gr.Markdown(
"1. Upload a file with **>100** text variations<br>"
"2. Tip: Ask an LLM to list variations."
)
if idx == 1:
concept_file_1 = gr.File(
label="Concept variations", file_types=["text"]
)
elif idx == 2:
concept_file_2 = gr.File(
label="Concept variations", file_types=["text"]
)
else:
concept_file_3 = gr.File(
label="Concept variations", file_types=["text"]
)
with gr.Column(elem_classes="card"):
with gr.Accordion("⚙️ Advanced options", open=False):
prompt = gr.Textbox(
label="Guidance Prompt (Optional)",
placeholder="Optional text prompt to guide generation",
)
num_inference_steps = gr.Slider(1, 50, 30, step=1, label="Num steps")
with gr.Row():
scale = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Scale")
randomize_seed = gr.Checkbox(True, label="Randomize seed")
seed = gr.Number(value=0, label="Seed", precision=0)
gr.Markdown(
"If a concept is not showing enough, **increase rank** ⬇️"
)
with gr.Row():
rank1 = gr.Slider(1, 150, 30, step=1, label="Rank concept 1")
rank2 = gr.Slider(1, 150, 30, step=1, label="Rank concept 2")
rank3 = gr.Slider(1, 150, 30, step=1, label="Rank concept 3")
with gr.Column(elem_classes="card"):
output_image = gr.Image(show_label=False, height=480)
submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")
gr.Markdown("### 🔥 Ready-made examples")
gr.Gallery(example_gallery, label="Preview", columns=[3], height="auto")
gr.Examples(
examples,
inputs=[
base_image,
concept_image1,
concept_name1,
concept_image2,
concept_name2,
concept_image3,
concept_name3,
rank1,
rank2,
rank3,
prompt,
scale,
seed,
num_inference_steps,
],
outputs=[output_image],
fn=generate_examples,
cache_examples=False,
)
# Upload hooks
concept_file_1.upload(
get_text_embeddings,
[concept_file_1],
[concpet_from_file_1, use_concpet_from_file_1],
)
concept_file_2.upload(
get_text_embeddings,
[concept_file_2],
[concpet_from_file_2, use_concpet_from_file_2],
)
concept_file_3.upload(
get_text_embeddings,
[concept_file_3],
[concpet_from_file_3, use_concpet_from_file_3],
)
concept_file_1.delete(
lambda _: False, [concept_file_1], [use_concpet_from_file_1]
)
concept_file_2.delete(
lambda _: False, [concept_file_2], [use_concpet_from_file_2]
)
concept_file_3.delete(
lambda _: False, [concept_file_3], [use_concpet_from_file_3]
)
# Dropdown auto-rank
concept_name1.select(change_rank_default, [concept_name1], [rank1])
concept_name2.select(change_rank_default, [concept_name2], [rank2])
concept_name3.select(change_rank_default, [concept_name3], [rank3])
# Auto-match on upload
concept_image1.upload(match_image_to_concept, [concept_image1], [concept_name1])
concept_image2.upload(match_image_to_concept, [concept_image2], [concept_name2])
concept_image3.upload(match_image_to_concept, [concept_image3], [concept_name3])
# Generate chain
submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed).then(
process_and_display,
[
base_image,
concept_image1,
concept_name1,
concept_image2,
concept_name2,
concept_image3,
concept_name3,
rank1,
rank2,
rank3,
prompt,
scale,
seed,
num_inference_steps,
concpet_from_file_1,
concpet_from_file_2,
concpet_from_file_3,
use_concpet_from_file_1,
use_concpet_from_file_2,
use_concpet_from_file_3,
],
[output_image],
)
# ─────────────────────────────
# 10 · Launch
# ─────────────────────────────
if __name__ == "__main__":
demo.launch()