import gradio as gr
import torch
import gc
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
# Model and configuration state, shared across calls and initialized lazily
pipe = None
device = None
elastic_transformer = None

def init_model():
    global pipe, device
    if pipe is not None:
        return pipe, device

    torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype
    ).to(torch_device)

    # Load SD concepts
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style"
    }
    for concept in concepts.values():
        pipe.load_textual_inversion(concept, mean_resizing=False)

    device = torch_device
    return pipe, device

def init_transformers(device):
    global elastic_transformer
    if elastic_transformer is not None:
        return elastic_transformer
    elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer

# Image-space losses computed on the generated image
def image_loss(images, loss_type, device, elastic_transformer):
    if loss_type == 'blue':
        # Mean distance of the blue channel from a high target value
        error = torch.abs(images[:, 2] - 0.9).mean()
        return error.to(device)
    elif loss_type == 'elastic':
        # Mean difference between the image and an elastically deformed copy
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    elif loss_type == 'symmetry':
        # MSE between the image and its horizontal mirror
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
        return error.to(device)
    elif loss_type == 'saturation':
        # Mean difference between the image and a heavily saturated copy
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    else:
        return torch.tensor(0.0).to(device)
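
# Generate one image per loss type for the given prompt/concept pair.
# Each non-'none' loss is applied as a simple post-hoc adjustment: the scalar
# loss value is scaled and subtracted from the decoded image's pixels (this is
# not gradient-based guidance during denoising). Results are returned as
# (image, caption) tuples for the Gradio Gallery.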
def generate_images(prompt, concept, progress=gr.Progress()):
    global pipe, device, elastic_transformer
    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    # Configuration
    height, width = 384, 384
    guidance_scale = 8
    num_inference_steps = 45
    loss_scale = 10.0

    # Create scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000
    )
    pipe.scheduler = scheduler  # Set the scheduler

    # Create prompt text
    # Note: sd-concepts-library embeddings are triggered by the placeholder
    # token defined in each concept repo (usually written as "<concept-name>");
    # appending the bare concept name may not activate the learned style.
    prompt_text = f"{prompt} {concept}"

    # Predefined seeds for each loss function
    seeds = {
        'none': 42,
        'blue': 123,
        'elastic': 456,
        'symmetry': 789,
        'saturation': 1000
    }

    loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
    images = []

    # Generate image for each loss function
    for idx, loss_type in enumerate(loss_functions):
        progress(idx / len(loss_functions), desc=f"Generating {loss_type} image...")
        generator = torch.manual_seed(seeds[loss_type])

        # Generate base image
        try:
            output = pipe(
                prompt_text,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator
            )
        except Exception as e:
            print(f"Error generating image: {e}")
            return None

        # Apply loss function if not 'none'
        if loss_type != 'none':
            try:
                # Convert PIL image to tensor and move to device
                image_tensor = T.ToTensor()(output.images[0]).unsqueeze(0).to(device)
                # Apply loss and update image
                loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
                image_tensor = image_tensor - loss_scale * loss
                # Move back to CPU and convert to PIL
                image = T.ToPILImage()(image_tensor.cpu().squeeze(0).clamp(0, 1))
            except Exception as e:
                print(f"Error applying {loss_type} loss: {e}")
                image = output.images[0]  # Use original image if loss fails
        else:
            image = output.images[0]

        # Add image with its label
        try:
            # Ensure image is in correct format (PIL.Image)
            if not isinstance(image, Image.Image):
                print(f"Warning: Converting {loss_type} image to PIL format")
                image = Image.fromarray(image)
            # Add tuple of (image, label) to list
            images.append((image, f"{loss_type.capitalize()} Loss"))
            print(f"Added {loss_type} image to gallery")  # Debug print
        except Exception as e:
            print(f"Error adding {loss_type} image to gallery: {e}")
            continue

        # Clear GPU memory after each image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

    # Return all generated images
    print(f"Returning {len(images)} images")
    if not images:
        return None
    return images

def create_interface():
    default_prompts = [
        "A realistic image of a boy with a cowboy hat in the style of",
        "A realistic image of a rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
    ]
    concepts = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop"
    ]
    interface = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
            gr.Dropdown(choices=concepts, label="Select SD Concept")
        ],
        outputs=gr.Gallery(
            label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
            show_label=True,
            elem_id="gallery",
            columns=5,
            rows=1,
            height=512,
            object_fit="contain"
        ),  # Simplified Gallery definition
        title="Stable Diffusion using Textual Inversion",
        description="""Generate images using Stable Diffusion with different style concepts. The output shows 5 images side by side:
1. Original Image (No Loss)
2. Blue Channel Loss - Enhances blue tones
3. Elastic Loss - Adds elastic deformation
4. Symmetry Loss - Enforces symmetrical features
5. Saturation Loss - Modifies color saturation
Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
        flagging_mode="never"  # Updated from allow_flagging
    )
    return interface

if __name__ == "__main__":
    interface = create_interface()
    interface.queue(max_size=5)  # Simplified queue configuration
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        max_threads=1
    )