import gradio as gr
import torch
import gc
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
# Model and configuration state. These module-level globals are populated
# lazily by init_model() and init_transformers().
pipe = None
device = None
elastic_transformer = None

def init_model():
    global pipe, device
    if pipe is not None:
        return pipe, device

    torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype
    ).to(torch_device)

    # Load SD concepts
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style"
    }
    for concept in concepts.values():
        pipe.load_textual_inversion(concept, mean_resizing=False)

    device = torch_device
    return pipe, device

def init_transformers(device):
    global elastic_transformer
    if elastic_transformer is not None:
        return elastic_transformer
    elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer

# Guidance losses applied to decoded images during sampling. Each branch
# returns a scalar tensor on `device`; gradients of these losses are used to
# steer the latents in generate_images().
def image_loss(images, loss_type, device, elastic_transformer):
    if loss_type == 'blue':
        # Target blue value of 0.6 (rather than 0.9) for a more subtle effect
        error = torch.abs(images[:, 2] - 0.6).mean()
        # Apply a lower scale specifically for the blue loss
        return (error * 0.3).to(device)
    elif loss_type == 'elastic':
        # Penalize the difference between the image and an elastically warped copy
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    elif loss_type == 'symmetry':
        # Compare the image against its horizontal mirror (flip along width)
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
        return error.to(device)
    elif loss_type == 'saturation':
        # Penalize the distance to a strongly saturated copy
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    else:
        return torch.tensor(0.0).to(device)
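
# A minimal, optional sanity check for image_loss (illustrative only; the app
# never calls this). `_demo_image_loss` is a hypothetical helper: small random
# CPU tensors in [0, 1] stand in for decoded images.
def _demo_image_loss():
    imgs = torch.rand(1, 3, 64, 64)  # fake image batch in [0, 1]
    et = T.ElasticTransform(alpha=550.0, sigma=5.0)
    for lt in ['none', 'blue', 'elastic', 'symmetry', 'saturation']:
        print(lt, image_loss(imgs, lt, torch.device('cpu'), et).item())
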
# Generation configuration
height, width = 384, 384   # reduced from 512x512 for faster generation
guidance_scale = 8         # slightly above the usual 7.5 for stronger prompt adherence
num_inference_steps = 45   # 45 steps as a quality/speed trade-off
loss_scale = 150           # multiplier applied to the guidance losses
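
# Note (for orientation): Stable Diffusion's VAE downsamples by a factor of 8,
# so 384x384 images are denoised as 48x48 latent grids; see the latent shape
# constructed in generate_images() below.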

def generate_images(prompt, concept):
    global pipe, device, elastic_transformer
    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    # Build the prompt text. Caution: sd-concepts-library embeddings are keyed
    # to placeholder tokens (e.g. "<cat-toy>" in the diffusers docs), so a
    # concept only takes effect if the prompt contains the exact token
    # registered by load_textual_inversion.
    prompt_text = f"{prompt} {concept}"
    all_images = []

    # Generate one image per loss type
    loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
    progress = gr.Progress()

    # Use a single random seed for every loss type so the results are comparable
    random_seed = torch.randint(1, 10000, (1,)).item()  # random seed in [1, 9999]
    print(f"\nUsing random seed {random_seed} for all images")

    for idx, loss_type in enumerate(loss_functions):
        try:
            print(f"\n[{loss_type.upper()}] Starting image generation...")
            progress(idx / len(loss_functions), f"Starting {loss_type} image generation...")

            # Free cached GPU memory before each run
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
                torch.cuda.empty_cache()
            # Initialize the scheduler and process the text first
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                num_train_timesteps=1000
            )
            scheduler.set_timesteps(num_inference_steps)
            scheduler.timesteps = scheduler.timesteps.to(device)

            # Conditional text embeddings
            text_input = pipe.tokenizer(
                [prompt_text],
                padding='max_length',
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )
            with torch.no_grad():
                text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

            # Unconditional (empty-prompt) embeddings for classifier-free guidance
            uncond_input = pipe.tokenizer(
                [""],
                padding="max_length",
                max_length=text_input.input_ids.shape[-1],
                return_tensors="pt"
            )
            with torch.no_grad():
                uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            # Generate initial latents, reusing the same seed for every loss type
            generator = torch.manual_seed(random_seed)
            latents = torch.randn(
                (1, pipe.unet.config.in_channels, height // 8, width // 8),
                generator=generator,
            )
            latents = latents.to(device=device, dtype=pipe.unet.dtype)
            latents = latents * scheduler.init_noise_sigma
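            # Note (illustrative): LMSDiscreteScheduler works in sigma space, and
            # init_noise_sigma is (roughly) the largest sigma in the schedule, so
            # the pure-noise latents start at the noise level the first timestep
            # expects.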

            # Diffusion loop
            total_steps = len(scheduler.timesteps)
            for i, t in enumerate(scheduler.timesteps):
                current_progress = (idx + (i / total_steps)) / len(loss_functions)
                progress_msg = f"[{loss_type.upper()}] Step {i+1}/{total_steps} ({(i+1)/total_steps*100:.1f}%)"
                print(progress_msg)
                progress(current_progress, progress_msg)

                # Duplicate the latents: one copy for the unconditional pass,
                # one for the conditional pass
                latent_model_input = torch.cat([latents] * 2)
                sigma = scheduler.sigmas[i]
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = latent_model_input.to(dtype=pipe.unet.dtype)

                with torch.no_grad():
                    noise_pred = pipe.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings
                    )["sample"]

                # Classifier-free guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Apply the guidance loss every 5 steps (skipped for 'none')
                if loss_type != 'none' and i % 5 == 0:
                    latents = latents.detach().requires_grad_()
                    # Estimate the denoised latents at this sigma; since latents
                    # requires grad, latents_x0 (and the decoded images) do too
                    latents_x0 = latents - sigma * noise_pred

                    # Decode to image space with gradients enabled so the loss
                    # can be backpropagated to the latents
                    with torch.set_grad_enabled(True):
                        denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
                        loss = image_loss(denoised_images, loss_type, device, elastic_transformer)

                    # Nudge the latents down the loss gradient, scaled by sigma^2
                    cond_grad = torch.autograd.grad(loss * loss_scale, latents_x0)[0]
                    latents = latents.detach() - cond_grad * sigma**2

                latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Periodically clear the CUDA cache
                if torch.cuda.is_available() and i % 10 == 0:
                    torch.cuda.empty_cache()

            # Finalize: convert the latents for this loss type into a PIL image
            progress(idx / len(loss_functions), f"Finalizing {loss_type} image...")

            latents = (1 / 0.18215) * latents
            with torch.no_grad():
                image = pipe.vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            image = (image * 255).round().astype("uint8")
            pil_image = Image.fromarray(image[0])

            # Store the image with its label
            all_images.append((pil_image, f"{loss_type.capitalize()} Loss"))
        except Exception as e:
            print(f"Error generating {loss_type} image: {e}")
            continue

    # After the loop: return the images (labels are dropped for the gallery)
    if not all_images:
        print("Error in generate_images: No images were generated successfully")
        return None
    print("\nAll images generated successfully!")
    return [img for img, _ in all_images]
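
# A minimal sketch of the gradient-guidance update used above, isolated on dummy
# tensors (illustrative only; `_demo_guidance_step` is a hypothetical helper the
# app never calls, and the quadratic loss stands in for image_loss).
def _demo_guidance_step(sigma=1.5):
    lat = torch.randn(1, 4, 48, 48).requires_grad_()  # fake latents
    noise = torch.randn_like(lat)                     # stands in for noise_pred
    x0 = lat - sigma * noise                          # estimated denoised latents
    loss = x0.square().mean()                         # stand-in guidance loss
    grad = torch.autograd.grad(loss * loss_scale, x0)[0]
    return lat.detach() - grad * sigma**2             # nudged latents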

def create_interface():
    default_prompts = [
        "A realistic image of a boy with a cowboy hat in the style of",
        "A realistic image of a rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
    ]
    concepts = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop"
    ]

    interface = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
            gr.Dropdown(choices=concepts, label="Select SD Concept")
        ],
        outputs=gr.Gallery(
            label="Generated Images (From Left to Right: Original, Blue Channel, Elastic, Symmetry, Saturation)",
            show_label=True,
            elem_id="gallery",
            columns=5,
            rows=1,
            height="auto"
        ),
        title="Stable Diffusion using Textual Inversion",
        description="""Generate images using Stable Diffusion with different style concepts. The gallery shows 5 images in this order:

1. Left-most: Original image (no loss) - base generation without modifications
2. Second: Blue channel loss - enhanced blue tones for atmospheric effects
3. Middle: Elastic loss - elastic deformation for artistic distortion
4. Fourth: Symmetry loss - enforced symmetrical features
5. Right-most: Saturation loss - modified color saturation for vibrant effects

Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
        cache_examples=False,
        max_batch_size=1,
        flagging_mode="never"
    )
    return interface

if __name__ == "__main__":
    interface = create_interface()
    interface.queue(max_size=5)  # note: queue() no longer accepts concurrency_count in Gradio 4+
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )