import gc

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from diffusers import DiffusionPipeline, LMSDiscreteScheduler

# Model, device, and transform singletons, initialized lazily on first use
pipe = None
device = None
elastic_transformer = None


def init_model():
    global pipe, device
    if pipe is not None:
        return pipe, device

    torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype
    ).to(torch_device)

    # Load textual-inversion concepts from the SD concepts library
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style"
    }
    for concept in concepts.values():
        pipe.load_textual_inversion(concept, mean_resizing=False)

    device = torch_device
    return pipe, device


def init_transformers(device):
    global elastic_transformer
    if elastic_transformer is not None:
        return elastic_transformer
    elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer


def image_loss(images, loss_type, device, elastic_transformer):
    """Compute the guidance loss for a batch of decoded images."""
    if loss_type == 'blue':
        # Target a moderate blue level (0.6 rather than 0.9) for a subtler effect
        error = torch.abs(images[:, 2] - 0.6).mean()
        # Scale the blue loss down so it doesn't dominate the image
        return (error * 0.3).to(device)
    elif loss_type == 'elastic':
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    elif loss_type == 'symmetry':
        # Penalize the difference between the image and its horizontal mirror
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
        return error.to(device)
    elif loss_type == 'saturation':
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    else:
        return torch.tensor(0.0).to(device)


# Generation configuration
height, width = 384, 384   # 384x384 instead of 512x512 for faster generation
guidance_scale = 8         # Slightly above the usual 7.5 for stronger prompt adherence
num_inference_steps = 45   # 45 steps as a quality/speed trade-off
loss_scale = 150           # Weight applied to the guidance loss gradient


def generate_images(prompt, concept, progress=gr.Progress()):
    global pipe, device, elastic_transformer
    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    prompt_text = f"{prompt} {concept}"
    all_images = []

    # Generate one image per loss type
    loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']

    # Use one random seed for every loss type so the only difference
    # between gallery entries is the guidance loss itself
    random_seed = torch.randint(1, 10000, (1,)).item()
    print(f"\nUsing random seed {random_seed} for all images")

    for idx, loss_type in enumerate(loss_functions):
        try:
            print(f"\n[{loss_type.upper()}] Starting image generation...")
            progress(idx / len(loss_functions), f"Starting {loss_type} image generation...")

            # Free as much GPU memory as possible between runs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
                torch.cuda.empty_cache()
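            # The steps below re-implement the pipeline's denoising loop by
            # hand instead of calling pipe(...) directly: encode the prompt,
            # sample initial latents, then iteratively denoise while
            # periodically nudging the latents along the gradient of
            # image_loss. Running the loop manually is what allows a
            # different guidance loss to be injected for each gallery entry.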
            # Set up the scheduler
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                num_train_timesteps=1000
            )
            scheduler.set_timesteps(num_inference_steps)
            scheduler.timesteps = scheduler.timesteps.to(device)

            # Encode the prompt, plus an empty prompt for classifier-free guidance
            text_input = pipe.tokenizer(
                [prompt_text],
                padding='max_length',
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )
            with torch.no_grad():
                text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

            uncond_input = pipe.tokenizer(
                [""],
                padding="max_length",
                max_length=text_input.input_ids.shape[-1],
                return_tensors="pt"
            )
            with torch.no_grad():
                uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            # Sample initial latents, seeded identically for every loss type
            generator = torch.manual_seed(random_seed)
            latents = torch.randn(
                (1, pipe.unet.config.in_channels, height // 8, width // 8),
                generator=generator,
            )
            latents = latents.to(device=device, dtype=pipe.unet.dtype)
            latents = latents * scheduler.init_noise_sigma

            # Denoising loop
            total_steps = len(scheduler.timesteps)
            for i, t in enumerate(scheduler.timesteps):
                current_progress = (idx + (i / total_steps)) / len(loss_functions)
                progress_msg = f"[{loss_type.upper()}] Step {i+1}/{total_steps} ({(i+1)/total_steps*100:.1f}%)"
                print(progress_msg)
                progress(current_progress, progress_msg)

                latent_model_input = torch.cat([latents] * 2)
                sigma = scheduler.sigmas[i]
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = latent_model_input.to(dtype=pipe.unet.dtype)

                with torch.no_grad():
                    noise_pred = pipe.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings
                    )["sample"]

                # Classifier-free guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Apply the guidance loss every 5 steps (skipped for 'none')
                if loss_type != 'none' and i % 5 == 0:
                    latents = latents.detach().requires_grad_()
                    # Predict the fully denoised latents, decode them to image
                    # space with gradients enabled, and compute the loss there
                    latents_x0 = latents - sigma * noise_pred
                    with torch.set_grad_enabled(True):
                        denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
                        loss = image_loss(denoised_images, loss_type, device, elastic_transformer)
                    # Nudge the noisy latents along the loss gradient
                    cond_grad = torch.autograd.grad(loss * loss_scale, latents)[0]
                    latents = latents.detach() - cond_grad * sigma**2

                latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Periodically release cached GPU memory
                if torch.cuda.is_available() and i % 10 == 0:
                    torch.cuda.empty_cache()

            progress(idx / len(loss_functions), f"Finalizing {loss_type} image...")

            # Decode the final latents to a PIL image
            latents = (1 / 0.18215) * latents
            with torch.no_grad():
                image = pipe.vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            image = (image * 255).round().astype("uint8")
            pil_image = Image.fromarray(image[0])
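            # gr.Gallery accepts (image, caption) tuples, so keep the loss
            # name as a caption alongside each generated image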
            all_images.append((pil_image, f"{loss_type.capitalize()} Loss"))
        except Exception as e:
            print(f"Error generating {loss_type} image: {e}")
            continue

    if not all_images:
        print("Error in generate_images: no images were generated successfully")
        return None

    print("\nAll images generated successfully!")
    return all_images


def create_interface():
    default_prompts = [
        "A realistic image of Boy with a cowboy hat in the style of",
        "A realistic image of Rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
    ]

    concepts = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop"
    ]

    interface = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
            gr.Dropdown(choices=concepts, label="Select SD Concept")
        ],
        outputs=gr.Gallery(
            label="Generated Images (From Left to Right: Original, Blue Channel, Elastic, Symmetry, Saturation)",
            show_label=True,
            elem_id="gallery",
            columns=5,
            rows=1,
            height="auto"
        ),
        title="Stable Diffusion using Textual Inversion",
        description="""Generate images using Stable Diffusion with different style concepts.
        The gallery shows 5 images in this order:
        1. Left-most: Original Image (No Loss) - Base generation without modifications
        2. Second: Blue Channel Loss - Enhanced blue tones for atmospheric effects
        3. Middle: Elastic Loss - Added elastic deformation for artistic distortion
        4. Fourth: Symmetry Loss - Enforced symmetrical features
        5. Right-most: Saturation Loss - Modified color saturation for vibrant effects

        Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
        cache_examples=False,
        max_batch_size=1,
        flagging_mode="never"
    )
    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.queue(max_size=5)
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )