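# Page header captured together with the file (not part of the program):
#   author avatar: satyanayak
#   commit 6df8da7 - "fixing the path of concept model's bin file"
#   links: raw | history | blame
#   size: 3.49 kB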
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
import os
# Identifier of the base Stable Diffusion checkpoint passed to from_pretrained.
model_id = "CompVis/stable-diffusion-v1-4"

# Use the GPU when torch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repositories whose learned concept embeddings are applied, one image each.
concepts = [
    "sd-concepts-library/sword-lily-flowers102",
    "sd-concepts-library/azalea-flowers102",
    "sd-concepts-library/samurai-jack",
    "sd-concepts-library/wu-shi-art",
    "sd-concepts-library/wu-shi"
]
def download_concept_embedding(concept_name):
    """Fetch ``learned_embeds.bin`` for a concept repository from the Hub.

    Args:
        concept_name: Repository id on the Hub (used as ``repo_id``).

    Returns:
        Whatever ``hf_hub_download`` returns for the file on success, or
        ``None`` when the download raised; the failure is printed, not raised.
    """
    try:
        local_file = hf_hub_download(
            repo_type="model",
            filename="learned_embeds.bin",
            repo_id=concept_name,
        )
    except Exception as err:
        # Best effort: report the problem and let the caller skip this concept.
        print(f"Error downloading {concept_name}: {str(err)}")
        return None
    else:
        return local_file
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Register a learned concept token and its embedding with a text encoder.

    The file at ``learned_embeds_path`` is read with ``torch.load`` and is
    expected to hold a mapping of token string -> embedding tensor. Only the
    first entry of that mapping is used.

    Args:
        learned_embeds_path: Path of the saved embeddings file.
        text_encoder: Object providing ``resize_token_embeddings`` and
            ``get_input_embeddings``; its embedding matrix is modified in place.
        tokenizer: Object providing ``add_tokens``, ``convert_tokens_to_ids``
            and ``len()``; the concept token is added to it.

    Returns:
        The concept token string that was registered.

    Raises:
        ValueError: If the embeddings file contains no entries.
    """
    # SECURITY NOTE(review): torch.load may unpickle arbitrary objects and this
    # file is downloaded from a third-party repository. Consider passing
    # weights_only=True once the minimum supported torch version allows it.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(
            f"No learned embeddings found in {learned_embeds_path}"
        )

    # The first key of the mapping is the concept token.
    token = next(iter(loaded_learned_embeds))

    # Add the token, then grow the embedding matrix to the new vocabulary size
    # so that the token's id has a row to write into.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Overwrite the token's embedding row with the learned vector.
    token_id = tokenizer.convert_tokens_to_ids(token)
    embedding_matrix = text_encoder.get_input_embeddings().weight.data
    embedding_matrix[token_id] = loaded_learned_embeds[token]
    return token
# Lazily created pipeline, reused by every call to generate_images.
_PIPELINE = None


def _get_pipeline():
    """Return the shared Stable Diffusion pipeline, loading it on first use."""
    global _PIPELINE
    if _PIPELINE is None:
        _PIPELINE = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
    return _PIPELINE


def generate_images(prompt):
    """Generate one image per entry in ``concepts`` for the given prompt.

    For each concept the learned embedding is downloaded and registered with
    the pipeline's tokenizer and text encoder, the concept token is prepended
    to ``prompt``, and one image is generated with a random seed.

    Args:
        prompt: Text prompt; the concept token is placed in front of it.

    Returns:
        A list with exactly ``len(concepts)`` entries, in the same order as
        ``concepts``. An entry is the generated image, or ``None`` when the
        download or generation for that concept failed (the error is printed).
    """
    # One slot per concept so the result length never depends on how many
    # concepts succeeded.
    results = [None] * len(concepts)

    pipe = _get_pipeline()

    for index, concept in enumerate(concepts):
        # Download the concept embedding; skip the concept if that failed.
        embed_path = download_concept_embedding(concept)
        if embed_path is None:
            continue

        try:
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer,
            )

            # Fresh random seed for every image.
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)

            # Add concept token to prompt
            concept_prompt = f"{token} {prompt}"

            with autocast(device):
                image = pipe(
                    concept_prompt,
                    num_inference_steps=50,
                    generator=generator,
                    guidance_scale=7.5,
                ).images[0]
            results[index] = image

            # No per-concept cleanup: the tokenizer offers no supported call
            # to remove an added token. Leaving the token registered is
            # harmless because loading the same concept again just overwrites
            # its embedding row.
        except Exception as e:
            # Best effort: report the failure and keep this concept's slot None.
            print(f"Error processing concept {concept}: {str(e)}")

    return results
# Create Gradio interface.
# One image output per entry in `concepts`, labelled with the concept name, so
# the number of outputs tracks the list instead of a hard-coded count.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=concept.split("/")[-1]) for concept in concepts],
    title="Multi-Concept Stable Diffusion Generator",
    description=(
        f"Generate images using {len(concepts)} different concepts "
        "from the SD Concepts Library"
    ),
)
# Launch the app
iface.launch()