|
import gradio as gr |
|
import torch |
|
import gc |
|
from PIL import Image |
|
import torchvision.transforms as T |
|
import torch.nn.functional as F |
|
from diffusers import DiffusionPipeline, LMSDiscreteScheduler |
|
|
|
|
|
|
|
# Lazily initialised module-level singletons. All three start as ``None``;
# ``init_model`` fills in ``pipe`` and ``device`` and ``init_transformers``
# fills in ``elastic_transformer`` the first time each is called.
pipe = device = elastic_transformer = None
|
|
|
def init_model():
    """Load the Stable Diffusion pipeline once and cache it in module globals.

    On the first successful call this loads ``CompVis/stable-diffusion-v1-4``,
    moves it to the preferred device (CUDA, then MPS, then CPU), loads the
    textual-inversion concepts listed below, and stores the result in the
    module-level ``pipe`` and ``device``. Later calls return the cached pair
    without reloading anything.

    Returns:
        tuple: ``(pipe, device)`` where ``device`` is one of the strings
        ``"cuda"``, ``"mps"`` or ``"cpu"``.

    Raises:
        Exception: anything raised by ``DiffusionPipeline.from_pretrained`` or
        ``load_textual_inversion`` propagates; in that case nothing is cached,
        so the next call retries from scratch.
    """
    global pipe, device
    if pipe is not None:
        return pipe, device

    if torch.cuda.is_available():
        torch_device = "cuda"
    elif torch.backends.mps.is_available():
        torch_device = "mps"
    else:
        torch_device = "cpu"
    # Half precision is only selected on CUDA; every other device uses float32.
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    # Build into a local first. Publishing the global before the concepts are
    # loaded would leave ``pipe`` set and ``device`` still None if a load
    # raised, and every later call would return that half-initialised pair.
    loaded_pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype,
    ).to(torch_device)

    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style",
    }

    # NOTE(review): ``mean_resizing`` is forwarded as an extra keyword argument;
    # confirm the installed diffusers version actually honours it.
    for concept in concepts.values():
        loaded_pipe.load_textual_inversion(concept, mean_resizing=False)

    # Only a fully initialised pipeline is published to the module globals.
    pipe, device = loaded_pipe, torch_device
    return pipe, device
|
|
|
def init_transformers(device):
    """Return the shared ``ElasticTransform``, creating it on first use.

    Args:
        device: Target passed to ``.to(...)`` when the transform is first
            created. It is ignored on later calls because the cached instance
            is returned unchanged.

    Returns:
        The module-level ``elastic_transformer`` instance.
    """
    global elastic_transformer
    if elastic_transformer is None:
        warp = T.ElasticTransform(alpha=550.0, sigma=5.0)
        elastic_transformer = warp.to(device)
    return elastic_transformer
|
|
|
|
|
def image_loss(images, loss_type, device, elastic_transformer):
    """Compute a scalar loss for ``images`` according to ``loss_type``.

    Args:
        images: Image tensor. The code indexes channel 2 on dim 1 and flips
            dim 3, so a 4-D batch is expected (presumably ``(N, C, H, W)``
            RGB - confirm against the caller).
        loss_type: One of ``'blue'``, ``'elastic'``, ``'symmetry'`` or
            ``'saturation'``; any other value yields a zero loss.
        device: Device the resulting scalar is moved to.
        elastic_transformer: Callable applied to ``images`` for the
            ``'elastic'`` loss; unused by the other loss types.

    Returns:
        A 0-dim tensor on ``device``.
    """
    if loss_type == 'blue':
        # Mean absolute distance of channel index 2 from 0.9.
        loss = (images[:, 2] - 0.9).abs().mean()
    elif loss_type == 'elastic':
        # Mean absolute difference between the images and their warped copy.
        warped = elastic_transformer(images)
        loss = (warped - images).abs().mean()
    elif loss_type == 'symmetry':
        # Mean squared difference between the images and their flip along dim 3.
        mirrored = images.flip(3)
        loss = F.mse_loss(images, mirrored)
    elif loss_type == 'saturation':
        # Mean absolute difference from a copy with saturation factor 10.
        boosted = T.functional.adjust_saturation(images, saturation_factor=10)
        loss = (boosted - images).abs().mean()
    else:
        loss = torch.tensor(0.0)
    return loss.to(device)
|
|
|
def _free_accelerator_memory():
    """Empty the CUDA cache (when CUDA is available) and run the garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def _apply_image_loss(pil_image, loss_type, loss_scale):
    """Return ``pil_image`` shifted by ``loss_scale`` times its scalar ``image_loss``."""
    image_tensor = T.ToTensor()(pil_image).unsqueeze(0).to(device)
    loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
    # ``loss`` is a single scalar, so the same offset is subtracted from every
    # element before the result is clamped to [0, 1].
    # NOTE(review): this post-processes the finished image rather than guiding
    # the sampler; confirm that a uniform offset is the intended effect.
    shifted = image_tensor - loss_scale * loss
    return T.ToPILImage()(shifted.cpu().squeeze(0).clamp(0, 1))


def _generate_variant(prompt_text, loss_type, seed, loss_scale, pipe_kwargs):
    """Generate one image for ``loss_type``; return a PIL image, or None on failure."""
    generator = torch.manual_seed(seed)
    try:
        output = pipe(prompt_text, generator=generator, **pipe_kwargs)
    except Exception as e:
        print(f"Error generating image: {e}")
        return None

    image = output.images[0]
    if loss_type != 'none':
        try:
            image = _apply_image_loss(image, loss_type, loss_scale)
        except Exception as e:
            # Best effort: keep the unmodified pipeline output for this variant.
            print(f"Error applying {loss_type} loss: {e}")

    try:
        if not isinstance(image, Image.Image):
            print(f"Warning: Converting {loss_type} image to PIL format")
            image = Image.fromarray(image)
    except Exception as e:
        print(f"Error adding {loss_type} image to gallery: {e}")
        return None
    return image


def generate_images(prompt, concept):
    """Generate one image per loss variant for ``prompt`` styled with ``concept``.

    The pipeline and the elastic transform are initialised lazily on first use.
    Five variants are produced in order (``none``, ``blue``, ``elastic``,
    ``symmetry``, ``saturation``), each with its own fixed seed. A variant that
    fails is logged and skipped, so the images that did succeed are still
    returned.

    Args:
        prompt: Prompt text. ``None`` or a blank string returns ``None``
            without running the pipeline.
        concept: Concept name appended to the prompt after a space. When it is
            ``None`` or empty the prompt is used on its own.

    Returns:
        A list of ``(PIL image, caption)`` tuples, or ``None`` when no image
        could be produced.
    """
    global pipe, device, elastic_transformer

    # An unset dropdown arrives as None; formatting it would feed the literal
    # text "None" to the model for five full generations.
    if not prompt or not str(prompt).strip():
        print("No prompt provided; nothing to generate")
        return None

    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    pipe_kwargs = {
        "height": 384,
        "width": 384,
        "num_inference_steps": 45,
        "guidance_scale": 8,
    }
    loss_scale = 10.0

    pipe.scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    prompt_text = f"{prompt} {concept}" if concept else str(prompt)

    # Insertion order doubles as the generation (and gallery) order.
    seeds = {
        'none': 42,
        'blue': 123,
        'elastic': 456,
        'symmetry': 789,
        'saturation': 1000,
    }

    images = []
    # NOTE(review): Gradio documents progress trackers as a default-valued
    # function parameter; confirm this in-function instance reaches the UI.
    progress = gr.Progress()

    for idx, (loss_type, seed) in enumerate(seeds.items()):
        progress(idx / len(seeds), f"Generating {loss_type} image...")
        image = _generate_variant(prompt_text, loss_type, seed, loss_scale, pipe_kwargs)
        if image is not None:
            images.append((image, f"{loss_type.capitalize()} Loss"))
            print(f"Added {loss_type} image to gallery")
        # Runs after every variant, including failed ones.
        _free_accelerator_memory()

    print(f"Returning {len(images)} images")
    if not images:
        return None
    return images
|
|
|
def create_interface():
    """Build the Gradio ``Interface`` that wraps ``generate_images``.

    The UI has two dropdowns (a prompt that also accepts custom text, and a
    concept name) and a single-row, five-column gallery for the results.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    preset_prompts = [
        "A realistic image of Boy with a cowboy hat in the style of",
        "A realistic image of Rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of",
    ]
    concept_names = [
        "dreams",
        "midjourney-style",
        "moebius",
        "marc-allante",
        "wlop",
    ]

    # Components are created in the same order as before: prompt, concept, gallery.
    prompt_input = gr.Dropdown(
        choices=preset_prompts,
        label="Select a preset prompt or type your own",
        allow_custom_value=True,
    )
    concept_input = gr.Dropdown(choices=concept_names, label="Select SD Concept")
    gallery = gr.Gallery(
        label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
        show_label=True,
        elem_id="gallery",
        columns=5,
        rows=1,
        height=512,
        object_fit="contain",
    )

    description = "\n".join((
        "Generate images using Stable Diffusion with different style concepts. The output shows 5 images side by side:",
        "1. Original Image (No Loss)",
        "2. Blue Channel Loss - Enhances blue tones",
        "3. Elastic Loss - Adds elastic deformation",
        "4. Symmetry Loss - Enforces symmetrical features",
        "5. Saturation Loss - Modifies color saturation",
        "",
        "Note: Image generation may take several minutes. Please be patient while the images are being processed.",
    ))

    return gr.Interface(
        fn=generate_images,
        inputs=[prompt_input, concept_input],
        outputs=gallery,
        title="Stable Diffusion using Text Inversion",
        description=description,
        flagging_mode="never",
    )
|
|
|
if __name__ == "__main__":
    # Build the UI, enable the request queue with max_size=5, then serve it.
    interface = create_interface()
    interface.queue(max_size=5)
    # share=True and server_name="0.0.0.0" are kept exactly as configured;
    # max_threads=1 is passed through to launch unchanged.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "max_threads": 1,
    }
    interface.launch(**launch_options)