import gradio as gr
import torch
import gc
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
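
# Gradio demo: Stable Diffusion v1.4 plus textual-inversion style concepts.
# Each request renders the same prompt five ways: an unguided baseline and
# four loss-guided variants (blue channel, elastic, symmetry, saturation).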

# Lazily-initialized globals, shared across Gradio requests.
pipe = None
device = None
elastic_transformer = None


def init_model():
    global pipe, device
    if pipe is not None:
        return pipe, device

    # Prefer CUDA, then Apple-silicon MPS, then CPU.
    torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    # Half precision saves memory on CUDA; MPS and CPU stay in float32.
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype
    ).to(torch_device)

    # Style concepts from the sd-concepts-library organization on the Hub.
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style"
    }

    for concept in concepts.values():
        pipe.load_textual_inversion(concept, mean_resizing=False)

    device = torch_device
    return pipe, device
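
# Usage note: load_textual_inversion() registers each concept's placeholder
# token in the tokenizer and adds its learned embedding to the text encoder.
# Token names are defined by each concept repository (conventionally
# angle-bracketed, e.g. "<midjourney-style>").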


def init_transformers(device):
    global elastic_transformer
    if elastic_transformer is not None:
        return elastic_transformer
    # ElasticTransform is an nn.Module, so it can be moved to the target device.
    elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer
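

# Guidance losses. Each returns a scalar; during sampling, the gradient of
# this scalar w.r.t. the latents steers generation (see generate_images).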
def image_loss(images, loss_type, device, elastic_transformer):
    if loss_type == 'blue':
        # Drive the blue channel toward 0.9.
        error = torch.abs(images[:, 2] - 0.9).mean()
        return error.to(device)
    elif loss_type == 'elastic':
        # Penalize the difference between the image and an elastically warped copy.
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    elif loss_type == 'symmetry':
        # Penalize asymmetry by comparing against the horizontal mirror image.
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
        return error.to(device)
    elif loss_type == 'saturation':
        # Penalize sensitivity to a strong saturation boost.
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
        return error.to(device)
    else:
        # 'none' (or anything unrecognized): no guidance.
        return torch.tensor(0.0, device=device)
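
# Quick sanity check (hypothetical, CPU-only): a batch whose blue channel is
# exactly 0.9 has zero 'blue' loss:
#   imgs = torch.full((1, 3, 64, 64), 0.5); imgs[:, 2] = 0.9
#   image_loss(imgs, 'blue', 'cpu', None)  # -> tensor(0.)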

# Generation settings.
height, width = 384, 384    # output resolution (latents are height/8 x width/8)
guidance_scale = 7.5        # classifier-free guidance strength
num_inference_steps = 30    # denoising steps per image
loss_scale = 150            # weight on the image-loss gradient


def generate_images(prompt, concept, progress=gr.Progress()):
    # gr.Progress must be a default parameter so Gradio injects a tracked
    # instance; creating one inside the body would not update the UI.
    global pipe, device, elastic_transformer
    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    # The dropdown value is appended verbatim; each concept repo defines its
    # own placeholder token (often angle-bracketed, e.g. "<midjourney-style>"),
    # so adjust this if a concept is not being picked up.
    prompt_text = f"{prompt} {concept}"
    all_images = []

    # 'none' is the unguided baseline; the rest select a guidance loss.
    loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']

    for idx, loss_type in enumerate(loss_functions):
        try:
            progress(idx / len(loss_functions), f"Starting {loss_type} image generation...")

            # Free as much GPU memory as possible between variants.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
                torch.cuda.empty_cache()

            # Fresh scheduler per variant, using the beta schedule that
            # Stable Diffusion v1 was trained with.
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                num_train_timesteps=1000
            )
            scheduler.set_timesteps(num_inference_steps)
            scheduler.timesteps = scheduler.timesteps.to(device)
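
            # Classifier-free guidance needs two embeddings: one for the real
            # prompt and one for an empty ("unconditional") prompt; the UNet
            # runs on both and the two predictions are blended each step.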
            text_input = pipe.tokenizer(
                [prompt_text],
                padding='max_length',
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )
            with torch.no_grad():
                text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

            uncond_input = pipe.tokenizer(
                [""],
                padding="max_length",
                max_length=text_input.input_ids.shape[-1],
                return_tensors="pt"
            )
            with torch.no_grad():
                uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]

            # Batch the unconditional and conditional embeddings together.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            # Deterministic per-variant seed so reruns are reproducible.
            generator = torch.manual_seed(idx * 1000)
            latents = torch.randn(
                (1, pipe.unet.config.in_channels, height // 8, width // 8),
                generator=generator,
            )
            latents = latents.to(device=device, dtype=pipe.unet.dtype)
            # LMS expects the initial noise scaled by the largest sigma.
            latents = latents * scheduler.init_noise_sigma
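
            # Denoising loop. Every few steps it detours: estimate the fully
            # denoised image from the current latents, evaluate image_loss on
            # it, and nudge the latents down the loss gradient before the
            # regular scheduler step.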
            for i, t in enumerate(scheduler.timesteps):
                current_progress = (idx + i / len(scheduler.timesteps)) / len(loss_functions)
                progress(current_progress, f"Generating {loss_type} image: Step {i+1}/{len(scheduler.timesteps)}")

                # Duplicate the latents: one copy per embedding in the CFG batch.
                latent_model_input = torch.cat([latents] * 2)
                sigma = scheduler.sigmas[i]
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = latent_model_input.to(dtype=pipe.unet.dtype)

                with torch.no_grad():
                    noise_pred = pipe.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings
                    )["sample"]

                # Classifier-free guidance: push the prediction away from the
                # unconditional result and toward the prompt-conditioned one.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Loss guidance every 5th step. (The original code split this
                # across two loops, so the second reused stale predictions;
                # the guidance and scheduler step now share one loop.)
                if loss_type != 'none' and i % 5 == 0:
                    latents = latents.detach().requires_grad_()
                    # Predicted clean latents (x0) at the current sigma.
                    latents_x0 = latents - sigma * noise_pred
                    with torch.set_grad_enabled(True):
                        # 0.18215 is the SD v1 VAE scaling factor; decode to
                        # images in [0, 1] and backprop the loss to the latents.
                        denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
                        loss = image_loss(denoised_images, loss_type, device, elastic_transformer)
                        cond_grad = torch.autograd.grad(loss * loss_scale, latents)[0]
                    latents = latents.detach() - cond_grad * sigma**2

                latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Periodically release cached GPU memory mid-run.
                if torch.cuda.is_available() and i % 10 == 0:
                    torch.cuda.empty_cache()

            progress(idx / len(loss_functions), f"Finalizing {loss_type} image...")

            # Decode the final latents to pixel space and convert to PIL.
            latents = (1 / 0.18215) * latents
            with torch.no_grad():
                image = pipe.vae.decode(latents).sample

            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            image = (image * 255).round().astype("uint8")
            pil_image = Image.fromarray(image[0])

            all_images.append((pil_image, f"{loss_type.capitalize()} Loss"))

        except Exception as e:
            print(f"Error generating {loss_type} image: {e}")
            continue

    if not all_images:
        print("Error in generate_images: no images were generated successfully")
        return None
    # Return (image, caption) pairs; gr.Gallery displays the captions.
    return all_images


def create_interface():
    default_prompts = [
        "A realistic image of a boy with a cowboy hat in the style of",
        "A realistic image of a rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
    ]

    concepts = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop"
    ]

    interface = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
            gr.Dropdown(choices=concepts, label="Select SD Concept")
        ],
        outputs=gr.Gallery(
            label="Generated Images (From Left to Right: Original, Blue Channel, Elastic, Symmetry, Saturation)",
            show_label=True,
            elem_id="gallery",
            columns=5,
            rows=1,
            height="auto"
        ),
        title="Stable Diffusion using Textual Inversion",
        description="""Generate images using Stable Diffusion with different style concepts. The gallery shows 5 images in this order:

1. Left-most: Original Image (No Loss) - Base generation without modifications
2. Second: Blue Channel Loss - Enhanced blue tones for atmospheric effects
3. Middle: Elastic Loss - Added elastic deformation for artistic distortion
4. Fourth: Symmetry Loss - Enforced symmetrical features
5. Right-most: Saturation Loss - Modified color saturation for vibrant effects

Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
        cache_examples=False,
        max_batch_size=1,
        flagging_mode="never"
    )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.queue(max_size=5)
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )