import random

import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download

# Initialize the model
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# List of concept embeddings to use
concepts = [
    "sd-concepts-library/sword-lily-flowers102",
    "sd-concepts-library/azalea-flowers102",
    "sd-concepts-library/samurai-jack",
    "sd-concepts-library/wu-shi-art",
    "sd-concepts-library/wu-shi",
]

def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # The file maps a single placeholder token (e.g. "<concept-name>") to its embedding
    token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[token]

    # Add the concept token to the tokenizer and resize the text encoder's embedding matrix
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Copy the learned embedding into the new token's slot, matching the encoder's dtype
    token_id = tokenizer.convert_tokens_to_ids(token)
    embeds = embeds.to(text_encoder.get_input_embeddings().weight.dtype)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token

def generate_images(prompt):
    images = []

    # Load the base pipeline (reloaded per call so each request starts from a clean tokenizer)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    for concept in concepts:
        # Download the concept embedding from the Hub and load it into the pipeline
        embeds_path = hf_hub_download(repo_id=concept, filename="learned_embeds.bin")
        token = load_learned_embed_in_clip(
            embeds_path,
            pipe.text_encoder,
            pipe.tokenizer,
        )

        # Use a fresh random seed for each concept
        seed = random.randint(1, 999999)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Prepend the concept's placeholder token to the user prompt
        concept_prompt = f"{token} {prompt}"

        # Generate one image for this concept
        with autocast(device):
            image = pipe(
                concept_prompt,
                num_inference_steps=50,
                generator=generator,
                guidance_scale=7.5,
            ).images[0]
        images.append(image)

        # Each concept uses its own distinct placeholder token, so the tokens added
        # above can coexist; no removal between iterations is needed (and transformers
        # tokenizers expose no remove_tokens method).

    return images

# Create Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image() for _ in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different concepts from the SD Concepts Library",
)

# Launch the app
iface.launch()
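
# Optional local smoke test, kept as a comment since iface.launch() blocks.
# This is a minimal, hypothetical sketch: it assumes the model weights and concept
# embeddings can be downloaded and, realistically, that a CUDA GPU is available.
#
#     if __name__ == "__main__":
#         for i, img in enumerate(generate_images("a watercolor landscape at dusk")):
#             img.save(f"concept_{i}.png")
#
# Run something like this in a separate script that imports generate_images, or
# temporarily in place of the launch() call above.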