# NOTE: this header appears to be residue from a Hugging Face Spaces file viewer.
# Space status at capture time: Sleeping
# File size: 3,897 bytes
# Commit hashes shown in the viewer's blame gutter: 7a41953, 6df8da7, 6269d98, f67fd7b
# (The viewer's 1-120 line-number gutter was rendering residue and is not part of the file.)
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
import os
# Base Stable Diffusion checkpoint that every concept embedding is loaded into.
model_id = "CompVis/stable-diffusion-v1-4"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hub repositories holding concept embeddings compatible with SD v1.4.
# Only the first entry is enabled; the rest are kept as disabled candidates.
concepts = [
    "sd-concepts-library/cat-toy",
    # "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style",
]
def download_concept_embedding(concept_name):
    """Fetch a concept's ``learned_embeds.bin`` file from the Hugging Face Hub.

    Args:
        concept_name: Hub repository id of the concept, passed to
            ``hf_hub_download`` as ``repo_id`` (with ``repo_type="model"``).

    Returns:
        The value returned by ``hf_hub_download`` for the file, or ``None``
        when the download raised any exception.
    """
    try:
        return hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model",
        )
    except Exception as e:
        # Best effort: report the failure and let the caller skip this concept.
        print(f"Error downloading {concept_name}: {str(e)}")
    return None
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Register a learned concept token and its embedding with a text encoder.

    The file at ``learned_embeds_path`` is loaded as a mapping of token string
    to embedding tensor. Only the first entry is used. The token is added to
    ``tokenizer``, the encoder's embedding table is resized to the new
    vocabulary size, and the token's row is overwritten with the learned
    embedding. If the token is already known to the tokenizer, its existing
    row is overwritten.

    Args:
        learned_embeds_path: Path to the saved embeddings file.
        text_encoder: Object exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)``.
        tokenizer: Object exposing ``add_tokens``, ``convert_tokens_to_ids``
            and ``__len__``.

    Returns:
        The concept token string that was registered.

    Raises:
        ValueError: If the file contains no embeddings, or if the embedding
            width does not match the encoder's embedding width. In both cases
            the tokenizer and encoder are left untouched.
    """
    # SECURITY NOTE(review): torch.load can execute pickled code unless
    # weights-only loading is in effect. Callers pass files downloaded from
    # Hub repositories; consider weights_only=True if the installed torch
    # version supports it.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(
            f"No learned embeddings found in {learned_embeds_path}"
        )

    # Only the first (token, embedding) pair in the file is used.
    token, embedding = next(iter(loaded_learned_embeds.items()))

    # Validate before mutating anything so a bad file cannot leave the
    # tokenizer and encoder half-updated.
    expected_dim = text_encoder.get_input_embeddings().weight.shape[-1]
    if embedding.shape[-1] != expected_dim:
        raise ValueError(
            f"Embedding for {token} has width {embedding.shape[-1]}, "
            f"but the text encoder expects {expected_dim}"
        )

    # Add the concept token and grow the embedding table to match.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Fetch the weights again after resizing: the resize may have replaced
    # the embedding module.
    token_id = tokenizer.convert_tokens_to_ids(token)
    weights = text_encoder.get_input_embeddings().weight
    weights.data[token_id] = embedding.to(
        dtype=weights.dtype, device=weights.device
    )
    return token
def generate_images(prompt):
    """Generate one image per configured concept for the given prompt.

    A fresh pipeline is loaded on every call. For each entry in ``concepts``
    the concept embedding is downloaded, registered with the pipeline's
    tokenizer and text encoder, and an image is generated from the prompt
    prefixed with the concept token, using a random seed.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        A list of exactly 5 items: the generated images in concept order,
        padded with ``None`` when fewer than 5 images were produced.
    """
    # The Gradio interface declares 5 image outputs, so always return 5 items.
    max_outputs = 5
    images = []
    failed_concepts = []

    # Load base pipeline (half precision on GPU, full precision on CPU).
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    for concept in concepts:
        try:
            # Download and load concept embedding.
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer,
            )

            # Fresh random seed per concept.
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)

            # Prefix the prompt with the concept token.
            concept_prompt = f"{token} {prompt}"

            with autocast(device):
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5,
                ).images[0]
            images.append(image)
            # No per-concept cleanup: the tokenizer has no `remove_tokens`
            # method, so the previous cleanup call raised and marked every
            # successful concept as failed. The pipeline is rebuilt on each
            # call, so added tokens do not outlive this function.
        except Exception as e:
            # Best effort: record the failure and move on to the next concept.
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)

    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")

    # Pad with None when fewer images were produced; a negative repeat count
    # yields an empty list, so surplus images are simply truncated below.
    images.extend([None] * (max_outputs - len(images)))
    return images[:max_outputs]
# Create Gradio interface: one text prompt in, five image slots out.
# The number of outputs must stay in step with the list length returned by
# generate_images.
# NOTE(review): the description advertises 5 concepts, but the `concepts`
# list currently enables only one; unused slots receive None.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different artistic concepts from the SD Concepts Library",
)

# Launch the app
iface.launch()