import os, json, random, gc
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio.themes import Soft
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
from IP_Composer.perform_swap import (
    compute_dataset_embeds_svd,
    get_modified_images_embeds_composition,
)
from IP_Composer.generate_text_embeddings import (
    load_descriptions,
    generate_embeddings,
)
import spaces

# ─────────────────────────────
# 1 · Device
# ─────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"

# ─────────────────────────────
# 2 · Stable-Diffusion XL
# ─────────────────────────────
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# ─────────────────────────────
# 3 · IP-Adapter
# ─────────────────────────────
image_encoder_repo = "h94/IP-Adapter"
image_encoder_subfolder = "models/image_encoder"
ip_ckpt = hf_hub_download(
    "h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin"
)
ip_model = IPAdapterXL(
    pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device
)

# ─────────────────────────────
# 4 · CLIP
# ─────────────────────────────
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)
clip_model.to(device)
tokenizer = open_clip.get_tokenizer(
    "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)

# ─────────────────────────────
# 5 · Concept maps
# ─────────────────────────────
CONCEPTS_MAP = {
    "age": "age_descriptions.npy",
    "animal fur": "fur_descriptions.npy",
    "dogs": "dog_descriptions.npy",
    "emotions": "emotion_descriptions.npy",
    "flowers": "flower_descriptions.npy",
    "fruit/vegtable": "fruit_vegetable_descriptions.npy",
    "outfit type": "outfit_descriptions.npy",
    "outfit pattern (including color)": "outfit_pattern_descriptions.npy",
    "patterns": "pattern_descriptions.npy",
    "patterns (including color)": "pattern_descriptions_with_colors.npy",
    "vehicle": "vehicle_descriptions.npy",
    "daytime": "times_of_day_descriptions.npy",
    "pose": "person_poses_descriptions.npy",
    "season": "season_descriptions.npy",
    "material": "material_descriptions_with_gems.npy",
}
RANKS_MAP = {
    "age": 30,
    "animal fur": 80,
    "dogs": 30,
    "emotions": 30,
    "flowers": 30,
    "fruit/vegtable": 30,
    "outfit type": 30,
    "outfit pattern (including color)": 80,
    "patterns": 80,
    "patterns (including color)": 80,
    "vehicle": 30,
    "daytime": 30,
    "pose": 30,
    "season": 30,
    "material": 80,
}
concept_options = list(CONCEPTS_MAP.keys())

# ─────────────────────────────
# 6 · Example tuples (base_img, c1_img, …)
# ─────────────────────────────
examples = [
    [
        "./IP_Composer/assets/patterns/base.jpg",
        "./IP_Composer/assets/patterns/pattern.png",
        "patterns (including color)",
        None, None, None, None,
        80, 30, 30,
        None, 1.0, 0, 30,
    ],
    [
        "./IP_Composer/assets/flowers/base.png",
        "./IP_Composer/assets/flowers/concept.png",
        "flowers",
        None, None, None, None,
        30, 30, 30,
        None, 1.0, 0, 30,
    ],
    [
        "./IP_Composer/assets/materials/base.png",
        "./IP_Composer/assets/materials/concept.jpg",
        "material",
        None, None, None, None,
        80, 30, 30,
        None, 1.0, 0, 30,
    ],
]

# ----------------------------------------------------------
# 7 · Utility functions
# ----------------------------------------------------------
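# Descriptive note (added): the flat 14-argument signature below mirrors the
# `examples` tuples above (presumably wired to gr.Examples further down) and
# simply forwards everything to process_and_display.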
def generate_examples(
    base_image,
    concept_image1,
    concept_name1,
    concept_image2,
    concept_name2,
    concept_image3,
    concept_name3,
    rank1,
    rank2,
    rank3,
    prompt,
    scale,
    seed,
    num_inference_steps,
):
    return process_and_display(
        base_image,
        concept_image1,
        concept_name1,
        concept_image2,
        concept_name2,
        concept_image3,
        concept_name3,
        rank1,
        rank2,
        rank3,
        prompt,
        scale,
        seed,
        num_inference_steps,
    )


MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize_seed else seed


def change_rank_default(concept_name):
    return RANKS_MAP.get(concept_name, 30)


# Auto-suggest a concept for an uploaded image: score it against every concept's
# precomputed text-description embeddings and return the best match
# (mean of the top-5 cosine similarities).
@spaces.GPU
def match_image_to_concept(image):
    if image is None:
        return None
    img_pil = Image.fromarray(image).convert("RGB")
    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
    sims = {}
    for cname, cfile in CONCEPTS_MAP.items():
        try:
            with open(f"./IP_Composer/text_embeddings/{cfile}", "rb") as f:
                embeds = np.load(f)
            scores = []
            for e in embeds:
                s = np.dot(
                    img_embed.flatten() / np.linalg.norm(img_embed),
                    e.flatten() / np.linalg.norm(e),
                )
                scores.append(s)
            scores.sort(reverse=True)
            sims[cname] = np.mean(scores[:5])
        except Exception as e:
            print(cname, "error:", e)
    if sims:
        best = max(sims, key=sims.get)
        gr.Info(f"Image automatically matched to concept: {best}")
        return best
    return None


@spaces.GPU
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
    image = preproc(pil_image)[np.newaxis, :, :, :]
    with torch.no_grad():
        embeds = model.encode_image(image.to(dev))
    return embeds.cpu().detach().numpy()


# Core pipeline: embed the base image and up to three concept images with CLIP,
# build an SVD projection per concept, and hand everything to the IP-Adapter.
@spaces.GPU
def process_images(
    base_image,
    concept_image1,
    concept_name1,
    concept_image2=None,
    concept_name2=None,
    concept_image3=None,
    concept_name3=None,
    rank1=10,
    rank2=10,
    rank3=10,
    prompt=None,
    scale=1.0,
    seed=420,
    num_inference_steps=50,
    concpet_from_file_1=None,
    concpet_from_file_2=None,
    concpet_from_file_3=None,
    use_concpet_from_file_1=False,
    use_concpet_from_file_2=False,
    use_concpet_from_file_3=False,
):
    base_pil = Image.fromarray(base_image).convert("RGB")
    base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)

    concept_images, concept_descs, ranks = [], [], []
    skip = [False, False, False]

    # concept 1
    if concept_image1 is None:
        return None, "Please upload at least one concept image"
    concept_images.append(concept_image1)
    if use_concpet_from_file_1 and concpet_from_file_1 is not None:
        concept_descs.append(concpet_from_file_1)
        skip[0] = True
    else:
        concept_descs.append(CONCEPTS_MAP[concept_name1])
    ranks.append(rank1)

    # concept 2
    if concept_image2 is not None:
        concept_images.append(concept_image2)
        if use_concpet_from_file_2 and concpet_from_file_2 is not None:
            concept_descs.append(concpet_from_file_2)
            skip[1] = True
        else:
            concept_descs.append(CONCEPTS_MAP[concept_name2])
        ranks.append(rank2)

    # concept 3
    if concept_image3 is not None:
        concept_images.append(concept_image3)
        if use_concpet_from_file_3 and concpet_from_file_3 is not None:
            concept_descs.append(concpet_from_file_3)
            skip[2] = True
        else:
            concept_descs.append(CONCEPTS_MAP[concept_name3])
        ranks.append(rank3)

    concept_embeds, proj_mats = [], []
    for i, concept in enumerate(concept_descs):
        img_pil = Image.fromarray(concept_images[i]).convert("RGB")
        concept_embeds.append(get_image_embeds(img_pil, clip_model, preprocess, device))
        if skip[i]:
            all_embeds = concept
        else:
            with open(f"./IP_Composer/text_embeddings/{concept}", "rb") as f:
                all_embeds = np.load(f)
        proj_mats.append(compute_dataset_embeds_svd(all_embeds, ranks[i]))

    projections_data = [
        {"embed": e, "projection_matrix": p}
        for e, p in zip(concept_embeds, proj_mats)
    ]

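    # Composition step (added note, see IP_Composer.perform_swap for details):
    # each entry pairs a concept image's CLIP embedding with a rank-limited
    # projection matrix from the SVD of that concept's text-description
    # embeddings; the helper below uses them, via the IP-Adapter, to regenerate
    # the base image with those concept directions swapped in.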
    modified = get_modified_images_embeds_composition(
        base_embed,
        projections_data,
        ip_model,
        prompt=prompt,
        scale=scale,
        num_samples=1,
        seed=seed,
        num_inference_steps=num_inference_steps,
    )
    return modified[0]


@spaces.GPU
def get_text_embeddings(concept_file):
    descs = load_descriptions(concept_file)
    embeds = generate_embeddings(descs, clip_model, tokenizer, device, batch_size=100)
    return embeds, True


def process_and_display(
    base_image,
    concept_image1,
    concept_name1="age",
    concept_image2=None,
    concept_name2=None,
    concept_image3=None,
    concept_name3=None,
    rank1=30,
    rank2=30,
    rank3=30,
    prompt=None,
    scale=1.0,
    seed=0,
    num_inference_steps=50,
    concpet_from_file_1=None,
    concpet_from_file_2=None,
    concpet_from_file_3=None,
    use_concpet_from_file_1=False,
    use_concpet_from_file_2=False,
    use_concpet_from_file_3=False,
):
    if base_image is None:
        raise gr.Error("Please upload a base image")
    if concept_image1 is None:
        raise gr.Error("Choose at least one concept image")
    return process_images(
        base_image,
        concept_image1,
        concept_name1,
        concept_image2,
        concept_name2,
        concept_image3,
        concept_name3,
        rank1,
        rank2,
        rank3,
        prompt,
        scale,
        seed,
        num_inference_steps,
        concpet_from_file_1,
        concpet_from_file_2,
        concpet_from_file_3,
        use_concpet_from_file_1,
        use_concpet_from_file_2,
        use_concpet_from_file_3,
    )


# ----------------------------------------------------------
# 8 · THEME & CSS
# ----------------------------------------------------------
demo_theme = Soft(primary_hue="purple", font=[gr.themes.GoogleFont("Inter")])

css = """
body{
  background:#0f0c29;
  background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
}
#header{
  text-align:center;
  padding:24px 0 8px;
  font-weight:700;
  font-size:2.1rem;
  color:#ffffff;
}
.gradio-container{max-width:1024px !important;margin:0 auto}
.card{
  border-radius:18px;
  background:#ffffff0d;
  padding:18px 22px;
  backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #a855f7}
"""

# ----------------------------------------------------------
# 9 · UI
# ----------------------------------------------------------
example_gallery = [
    ["./IP_Composer/assets/patterns/base.jpg", "Patterns demo"],
    ["./IP_Composer/assets/flowers/base.png", "Flowers demo"],
    ["./IP_Composer/assets/materials/base.png", "Material demo"],
]

with gr.Blocks(css=css, theme=demo_theme) as demo:
    gr.Markdown(
        "