Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torch import autocast | |
from diffusers import StableDiffusionPipeline | |
import random | |
from huggingface_hub import hf_hub_download | |
import os | |
# Hugging Face Hub id of the base Stable Diffusion checkpoint (loaded later).
model_id = "CompVis/stable-diffusion-v1-4"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# SD Concepts Library repos; each one's `learned_embeds.bin` is fetched and
# applied as a textual-inversion concept, one generated image per entry.
concepts = [
    "sd-concepts-library/sword-lily-flowers102",
    "sd-concepts-library/azalea-flowers102",
    "sd-concepts-library/samurai-jack",
    "sd-concepts-library/wu-shi-art",
    "sd-concepts-library/wu-shi",
]
def download_concept_embedding(concept_name):
    """Fetch the ``learned_embeds.bin`` file of one concept repo from the Hub.

    Args:
        concept_name: Hub model repo id, e.g.
            ``"sd-concepts-library/samurai-jack"``.

    Returns:
        The value returned by ``hf_hub_download`` (the local path of the
        downloaded file), or ``None`` if the download raised. Errors are
        printed rather than re-raised.
    """
    try:
        return hf_hub_download(
            repo_id=concept_name,
            repo_type="model",
            filename="learned_embeds.bin",
        )
    except Exception as err:  # best-effort: report and signal failure with None
        print(f"Error downloading {concept_name}: {str(err)}")
        return None
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Register a textual-inversion concept token with a text encoder.

    Loads the ``{token: embedding}`` mapping stored in ``learned_embeds_path``,
    adds the token to ``tokenizer``, and grows the encoder's input-embedding
    table to the new vocabulary size. It then writes the learned vector into
    the token's row.

    Args:
        learned_embeds_path: Path to a ``learned_embeds.bin`` file.
        text_encoder: Text encoder whose input-embedding matrix receives the
            vector.
        tokenizer: Tokenizer paired with ``text_encoder``.

    Returns:
        The concept token string (the first key of the loaded mapping), to be
        placed in prompts.

    Raises:
        ValueError: If the file contains no embeddings.
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which is downloaded
    # from a Hub repo. Only load concepts from trusted repos, or pass
    # weights_only=True on torch versions that support it.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(f"No embeddings found in {learned_embeds_path}")

    # Only the first entry is used; concept files are presumably single-token.
    token = next(iter(loaded_learned_embeds))
    learned_embedding = loaded_learned_embeds[token]

    # Re-adding an already-registered token does not create a duplicate; its
    # existing row is simply overwritten below.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Cast explicitly to the encoder's dtype/device (e.g. fp16 on CUDA) rather
    # than relying on an implicit conversion during assignment.
    embedding_weights = text_encoder.get_input_embeddings().weight
    token_id = tokenizer.convert_tokens_to_ids(token)
    embedding_weights.data[token_id] = learned_embedding.to(
        dtype=embedding_weights.dtype, device=embedding_weights.device
    )
    return token
# Base pipeline cache, keyed by model id. Loading Stable Diffusion is
# expensive, so it is done once on the first request and reused afterwards.
_pipeline_cache = {}


def _get_pipeline():
    """Return the shared base pipeline for ``model_id``, loading it on first use."""
    pipe = _pipeline_cache.get(model_id)
    if pipe is None:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        _pipeline_cache[model_id] = pipe
    return pipe


def _generate_for_concept(pipe, concept, prompt):
    """Render one image of ``prompt`` prefixed with ``concept``'s token.

    Returns None when the concept's embedding could not be downloaded.
    """
    import contextlib

    embed_path = download_concept_embedding(concept)
    if embed_path is None:
        return None

    token = load_learned_embed_in_clip(embed_path, pipe.text_encoder, pipe.tokenizer)

    # Fresh random seed per image.
    seed = random.randint(1, 999999)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Autocast only on CUDA. On CPU the pipeline is fp32, and CPU autocast
    # would otherwise default to bfloat16.
    precision_ctx = autocast("cuda") if device == "cuda" else contextlib.nullcontext()
    with precision_ctx:
        return pipe(
            f"{token} {prompt}",
            num_inference_steps=50,
            generator=generator,
            guidance_scale=7.5,
        ).images[0]


def generate_images(prompt):
    """Generate one image per entry in ``concepts`` for the given prompt.

    Concept tokens are left registered on the shared pipeline.
    ``load_learned_embed_in_clip`` rewrites a token's embedding row each time
    that concept is used. Transformers tokenizers also have no
    ``remove_tokens`` method; the previous cleanup call raised AttributeError.

    Args:
        prompt: Text prompt; each concept's token is prepended to it.

    Returns:
        A list with exactly ``len(concepts)`` entries, in ``concepts`` order.
        Each entry is a generated image, or ``None`` if that concept failed.
        The fixed length keeps the result aligned with the UI's image outputs.
    """
    pipe = _get_pipeline()
    images = []
    for concept in concepts:
        try:
            image = _generate_for_concept(pipe, concept, prompt)
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            image = None
        images.append(image)
    return images
# Gradio UI: one text prompt in, one image per concept out. The output count
# and description are derived from `concepts` instead of hard-coding 5, so
# adding or removing a concept keeps the UI in sync.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    # Label each image with the concept repo's short name.
    outputs=[gr.Image(label=concept.split("/")[-1]) for concept in concepts],
    title="Multi-Concept Stable Diffusion Generator",
    description=(
        f"Generate images using {len(concepts)} different concepts "
        "from the SD Concepts Library"
    ),
)

# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    iface.launch()