File size: 7,753 Bytes
45b110b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import gradio as gr
import torch
import gc
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
# Process-wide caches. Each starts out empty (None) and is filled in on first
# use by init_model() / init_transformers(), so the expensive setup happens at
# most once per process.
elastic_transformer = None
device = None
pipe = None
def init_model():
    """Load the Stable Diffusion pipeline once and cache it module-wide.

    On the first call the pipeline is downloaded/loaded, moved to the best
    available torch device and extended with the textual-inversion concepts
    listed below. Later calls return the cached objects unchanged.

    The module globals ``pipe`` and ``device`` are only assigned after every
    loading step has succeeded, so a failure part-way through never leaves a
    half-initialised pipeline (``pipe`` set, ``device`` still ``None``) behind
    for subsequent calls to pick up.

    Returns:
        A ``(pipe, device)`` tuple: the loaded pipeline and the device name
        (``"cuda"``, ``"mps"`` or ``"cpu"``) it was moved to.

    Raises:
        Whatever ``DiffusionPipeline.from_pretrained`` or
        ``load_textual_inversion`` raise; nothing is caught here.
    """
    global pipe, device
    if pipe is not None:
        return pipe, device

    # Device preference order: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        torch_device = "cuda"
    elif torch.backends.mps.is_available():
        torch_device = "mps"
    else:
        torch_device = "cpu"

    # float16 is used on CUDA only; every other backend stays in float32.
    torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32

    loaded_pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype,
    ).to(torch_device)

    # Load SD concepts (textual-inversion embeddings), keyed by the short name
    # the UI shows for each one.
    concepts = {
        "dreams": "sd-concepts-library/dreams",
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "moebius": "sd-concepts-library/moebius",
        "marc-allante": "sd-concepts-library/style-of-marc-allante",
        "wlop": "sd-concepts-library/wlop-style",
    }
    for concept in concepts.values():
        # NOTE(review): mean_resizing is forwarded as-is; confirm the installed
        # diffusers version accepts/uses this keyword.
        loaded_pipe.load_textual_inversion(concept, mean_resizing=False)

    # Publish to the module cache only now that loading fully succeeded.
    pipe, device = loaded_pipe, torch_device
    return pipe, device
def init_transformers(device):
    """Return the shared elastic transform, building it on the first call.

    The transform is created with ``alpha=550.0`` and ``sigma=5.0``, moved to
    ``device`` and stored in the module-level ``elastic_transformer``. Once it
    exists, the cached instance is returned as-is and ``device`` is ignored.
    """
    global elastic_transformer
    if elastic_transformer is None:
        elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
    return elastic_transformer
def image_loss(images, loss_type, device, elastic_transformer):
    """Compute a scalar loss for a batch of images.

    Args:
        images: Image tensor. It is indexed as ``images[:, 2]`` and flipped
            along dimension 3, so it must have at least four dimensions
            (presumably batch, channel, height, width - confirm with caller).
        loss_type: One of ``'blue'``, ``'elastic'``, ``'symmetry'`` or
            ``'saturation'``. Any other value yields a zero loss.
        device: Device the returned scalar is moved to.
        elastic_transformer: Callable applied to ``images`` for the
            ``'elastic'`` loss; unused by the other loss types.

    Returns:
        A zero-dimensional tensor on ``device``:
        * ``'blue'``: mean absolute distance of channel index 2 from 0.9.
        * ``'elastic'``: mean absolute difference between the transformed
          images and the originals.
        * ``'symmetry'``: MSE between the images and their flip along dim 3.
        * ``'saturation'``: mean absolute difference between the images with
          saturation factor 10 applied and the originals.
        * anything else: ``0.0``.
    """
    if loss_type == 'blue':
        error = (images[:, 2] - 0.9).abs().mean()
    elif loss_type == 'symmetry':
        error = F.mse_loss(images, images.flip(3))
    elif loss_type in ('elastic', 'saturation'):
        # Both of these compare the images against a distorted copy.
        if loss_type == 'elastic':
            distorted = elastic_transformer(images)
        else:
            distorted = T.functional.adjust_saturation(images, saturation_factor=10)
        error = (distorted - images).abs().mean()
    else:
        # Unknown loss types contribute nothing.
        error = torch.tensor(0.0)
    return error.to(device)
def _apply_image_loss(base_image, loss_type, loss_scale, device, elastic_transformer):
    """Return ``base_image`` shifted by ``loss_scale`` times its loss, as a PIL image."""
    # PIL -> tensor with a leading batch dimension, on the compute device.
    image_tensor = T.ToTensor()(base_image).unsqueeze(0).to(device)
    loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
    # NOTE(review): image_loss returns a scalar, so this subtracts the same
    # value from every pixel (a uniform darkening) rather than steering the
    # image content - confirm this is the intended effect.
    image_tensor = image_tensor - loss_scale * loss
    # Back to CPU, drop the batch dimension and clamp into the displayable range.
    return T.ToPILImage()(image_tensor.cpu().squeeze(0).clamp(0, 1))


def generate_images(prompt, concept, progress=gr.Progress()):
    """Generate one image per loss type for the given prompt and concept.

    Five images are produced, in this order: no loss, blue, elastic, symmetry
    and saturation. Each uses its own fixed seed. For every loss type other
    than ``'none'`` the generated image is post-processed with that loss; if
    the post-processing fails, the unmodified image is used instead.

    Args:
        prompt: Prompt text. ``None`` or a blank string aborts generation.
        concept: Concept name appended to the prompt. When empty or ``None``
            the prompt is used on its own.
        progress: Gradio progress tracker. It is declared as a default-valued
            parameter because that is how Gradio documents attaching a tracker
            to a function; callers normally do not pass it.

    Returns:
        A list of ``(PIL image, label)`` tuples for the gallery, or ``None``
        when the prompt is blank, the pipeline raises, or no image could be
        produced.
    """
    global pipe, device, elastic_transformer

    # An unselected dropdown arrives as None; do not turn it into "None ...".
    if prompt is None or not str(prompt).strip():
        print("No prompt provided; nothing to generate")
        return None

    if pipe is None:
        pipe, device = init_model()
    if elastic_transformer is None:
        elastic_transformer = init_transformers(device)

    # Configuration
    height, width = 384, 384
    guidance_scale = 8
    num_inference_steps = 45
    loss_scale = 10.0

    # Replace the pipeline's scheduler with a fresh LMS scheduler for this run.
    pipe.scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    # NOTE(review): the concept name is appended verbatim; confirm it matches
    # the placeholder token registered by load_textual_inversion.
    prompt_text = f"{prompt} {concept}" if concept else str(prompt)

    # Fixed seed per loss type; insertion order is also the gallery order.
    seeds = {
        'none': 42,
        'blue': 123,
        'elastic': 456,
        'symmetry': 789,
        'saturation': 1000,
    }

    images = []
    for idx, (loss_type, seed) in enumerate(seeds.items()):
        progress(idx / len(seeds), f"Generating {loss_type} image...")
        generator = torch.manual_seed(seed)
        try:
            # Generate base image
            try:
                output = pipe(
                    prompt_text,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                )
            except Exception as e:
                print(f"Error generating image: {e}")
                return None

            base_image = output.images[0]
            image = base_image

            # Apply loss function if not 'none'
            if loss_type != 'none':
                try:
                    image = _apply_image_loss(
                        base_image, loss_type, loss_scale, device, elastic_transformer
                    )
                except Exception as e:
                    print(f"Error applying {loss_type} loss: {e}")
                    image = base_image  # Use original image if loss fails

            # Add image with its label
            try:
                # Ensure image is in correct format (PIL.Image)
                if not isinstance(image, Image.Image):
                    print(f"Warning: Converting {loss_type} image to PIL format")
                    image = Image.fromarray(image)
                images.append((image, f"{loss_type.capitalize()} Loss"))
                print(f"Added {loss_type} image to gallery")  # Debug print
            except Exception as e:
                print(f"Error adding {loss_type} image to gallery: {e}")
        finally:
            # Release memory after every attempt, including failed ones.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    # Return all generated images
    print(f"Returning {len(images)} images")
    if not images:
        return None
    return images
def create_interface():
    """Build and return the Gradio interface for the image generator.

    The interface has two dropdown inputs (a prompt, which also accepts custom
    text, and a concept name) and a single five-column gallery output. It
    calls ``generate_images`` and has flagging disabled.
    """
    concepts = ["dreams", "midjourney-style", "moebius", "marc-allante", "wlop"]
    default_prompts = [
        "A realistic image of Boy with a cowboy hat in the style of",
        "A realistic image of Rabbit in a spacesuit in the style of",
        "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of",
    ]

    # Inputs: preset prompts (custom text allowed) and the concept to append.
    prompt_input = gr.Dropdown(
        choices=default_prompts,
        label="Select a preset prompt or type your own",
        allow_custom_value=True,
    )
    concept_input = gr.Dropdown(choices=concepts, label="Select SD Concept")

    # Output: one row holding the five generated images.
    gallery = gr.Gallery(
        label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
        show_label=True,
        elem_id="gallery",
        columns=5,
        rows=1,
        height=512,
        object_fit="contain",
    )

    description = """Generate images using Stable Diffusion with different style concepts. The output shows 5 images side by side:
1. Original Image (No Loss)
2. Blue Channel Loss - Enhances blue tones
3. Elastic Loss - Adds elastic deformation
4. Symmetry Loss - Enforces symmetrical features
5. Saturation Loss - Modifies color saturation
Note: Image generation may take several minutes. Please be patient while the images are being processed."""

    return gr.Interface(
        fn=generate_images,
        inputs=[prompt_input, concept_input],
        outputs=gallery,
        title="Stable Diffusion using Text Inversion",
        description=description,
        flagging_mode="never",
    )
if __name__ == "__main__":
    # Script entry point: build the UI, enable a request queue capped at five
    # entries, then start the server.
    interface = create_interface()
    interface.queue(max_size=5)
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        max_threads=1,
    )