import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
import os

# Base Stable Diffusion checkpoint and compute device.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Textual-inversion concepts (SD Concepts Library) to apply, one image each.
concepts = [
    "sd-concepts-library/sword-lily-flowers102",
    "sd-concepts-library/azalea-flowers102",
    "sd-concepts-library/samurai-jack",
    "sd-concepts-library/wu-shi-art",
    "sd-concepts-library/wu-shi"
]


def download_concept_embedding(concept_name):
    """Download a concept's learned_embeds.bin from the Hugging Face Hub.

    Args:
        concept_name: Hub repo id, e.g. "sd-concepts-library/samurai-jack".

    Returns:
        Local file path to the embedding, or None on any failure so the
        caller can skip this concept instead of crashing.
    """
    try:
        embed_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model",
        )
        return embed_path
    except Exception as e:
        print(f"Error downloading {concept_name}: {str(e)}")
        return None


def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Register a textual-inversion embedding with the pipeline's CLIP stack.

    Adds the concept's placeholder token to the tokenizer, resizes the text
    encoder's embedding table, and copies the learned vector into the new row.

    Args:
        learned_embeds_path: path to a learned_embeds.bin file, which maps a
            single placeholder token (e.g. "<samurai-jack>") to its vector.
        text_encoder: the pipeline's CLIP text encoder.
        tokenizer: the pipeline's CLIP tokenizer.

    Returns:
        The placeholder token string, to be prepended to prompts.
    """
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # The file contains exactly one entry: placeholder token -> embedding.
    token = list(loaded_learned_embeds.keys())[0]
    embedding = loaded_learned_embeds[token]

    # add_tokens() returns 0 if the token is already registered (e.g. on a
    # second call for the same concept); overwriting the row below is safe
    # either way.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(token)
    embeds = text_encoder.get_input_embeddings().weight
    # FIX: cast to the encoder's dtype — the pipeline runs in fp16 on CUDA
    # while learned_embeds.bin stores fp32; writing fp32 into an fp16 table
    # corrupts the text-encoder forward pass.
    embeds.data[token_id] = embedding.to(dtype=embeds.dtype)

    return token


def generate_images(prompt):
    """Generate one image per concept, prefixing each concept token to *prompt*.

    Returns:
        A list of exactly len(concepts) entries (PIL images, padded with
        None for concepts that failed), matching the Gradio output slots.
    """
    images = []

    # Load the base pipeline once per request; fp16 only makes sense on CUDA.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    for concept in concepts:
        embed_path = download_concept_embedding(concept)
        if embed_path is None:
            continue

        try:
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer,
            )

            # Fresh random seed per concept for variety.
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)

            concept_prompt = f"{token} {prompt}"

            call_kwargs = dict(
                num_inference_steps=50,
                generator=generator,
                guidance_scale=7.5,
            )
            # FIX: autocast is only valid/useful on the fp16 CUDA path;
            # on CPU the pipeline already runs in plain fp32.
            if device == "cuda":
                with autocast("cuda"):
                    image = pipe(concept_prompt, **call_kwargs).images[0]
            else:
                image = pipe(concept_prompt, **call_kwargs).images[0]

            images.append(image)

            # FIX: the original called pipe.tokenizer.remove_tokens([token]),
            # which does not exist on HF tokenizers and raised AttributeError
            # every iteration. Leaving tokens registered is harmless: each
            # concept uses a distinct placeholder, and re-adding an existing
            # token is a no-op.
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            continue

    # FIX: pad to the fixed number of Gradio output slots so partial
    # failures don't produce a length mismatch.
    return images + [None] * (len(concepts) - len(images))


# Gradio UI: one prompt in, one image per concept out.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image() for _ in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different concepts from the SD Concepts Library"
)

# Launch the app
iface.launch()