import gradio as gr
import torch
import gc
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers import DiffusionPipeline, LMSDiscreteScheduler

# Global model state, initialized lazily on first request
pipe = None
device = None
elastic_transformer = None

def init_model():
    global pipe, device
    if pipe is not None:
        return pipe, device

    torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype
    ).to(torch_device)

    # Load SD concepts
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style"
    }
    for concept in concepts.values():
        pipe.load_textual_inversion(concept, mean_resizing=False)

    device = torch_device
    return pipe, device

def init_transformers(device):
    global elastic_transformer
    if elastic_transformer is not None:
        return elastic_transformer

    elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer

# Loss functions used to nudge the generated image toward a visual property
def image_loss(images, loss_type, device, elastic_transformer):
    if loss_type == 'blue':
        error = torch.abs(images[:, 2] - 0.9).mean()
        return error.to(device)
    elif loss_type == 'elastic':
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    elif loss_type == 'symmetry':
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
        return error.to(device)
    elif loss_type == 'saturation':
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    else:
        return torch.tensor(0.0).to(device)

def generate_images(prompt, concept):
    global pipe, device, elastic_transformer

    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    # Configuration
    height, width = 384, 384
    guidance_scale = 8
    num_inference_steps = 45
    loss_scale = 10.0

    # Create scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000
    )
    pipe.scheduler = scheduler  # Set the scheduler

    # Create prompt text
    prompt_text = f"{prompt} {concept}"

    # Predefined seeds for each loss function
    seeds = {
        'none': 42,
        'blue': 123,
        'elastic': 456,
        'symmetry': 789,
        'saturation': 1000
    }

    loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
    images = []
    progress = gr.Progress()

    # Generate one image per loss function
    for idx, loss_type in enumerate(loss_functions):
        progress(idx / len(loss_functions), f"Generating {loss_type} image...")
        generator = torch.manual_seed(seeds[loss_type])

        # Generate base image
        try:
            output = pipe(
                prompt_text,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator
            )
        except Exception as e:
            print(f"Error generating image: {e}")
            return None

        # Apply loss function if not 'none'
        if loss_type != 'none':
            try:
                # Convert PIL image to tensor and move to device
                image_tensor = T.ToTensor()(output.images[0]).unsqueeze(0).to(device)

                # Apply loss and update image
                loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
                image_tensor = image_tensor - loss_scale * loss

                # Move back to CPU and convert to PIL
                image = T.ToPILImage()(image_tensor.cpu().squeeze(0).clamp(0, 1))
            except Exception as e:
                print(f"Error applying {loss_type} loss: {e}")
                image = output.images[0]  # Use original image if loss fails
        else:
            image = output.images[0]

        # Add image with its label
        try:
            # Ensure image is in correct format (PIL.Image)
            if not isinstance(image, Image.Image):
                print(f"Warning: Converting {loss_type} image to PIL format")
                image = Image.fromarray(image)

            # Add tuple of (image, label) to list
            images.append((image, f"{loss_type.capitalize()} Loss"))
            print(f"Added {loss_type} image to gallery")  # Debug print
        except Exception as e:
            print(f"Error adding {loss_type} image to gallery: {e}")
            continue

        # Clear GPU memory after each image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

    # Return all generated images
    print(f"Returning {len(images)} images")
    if not images:
        return None
    return images

def create_interface():
    default_prompts = [
        "A realistic image of Boy with a cowboy hat in the style of",
        "A realistic image of Rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
    ]

    concepts = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop"
    ]

    interface = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
            gr.Dropdown(choices=concepts, label="Select SD Concept")
        ],
        outputs=gr.Gallery(
            label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
            show_label=True,
            elem_id="gallery",
            columns=5,
            rows=1,
            height=512,
            object_fit="contain"
        ),  # Simplified Gallery definition
        title="Stable Diffusion using Textual Inversion",
        description="""Generate images using Stable Diffusion with different style concepts.
        The output shows 5 images side by side:
        1. Original Image (No Loss)
        2. Blue Channel Loss - Enhances blue tones
        3. Elastic Loss - Adds elastic deformation
        4. Symmetry Loss - Enforces symmetrical features
        5. Saturation Loss - Modifies color saturation

        Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
        flagging_mode="never"  # Updated from allow_flagging
    )

    return interface

if __name__ == "__main__":
    interface = create_interface()
    interface.queue(max_size=5)  # Simplified queue configuration
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        max_threads=1
    )