MilindChawre committed
Commit 45b110b · 1 Parent(s): a240ddb

Adding code for stable diffusion using text inversion

Files changed (3)
  1. README.md +55 -1
  2. app.py +219 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -10,4 +10,58 @@ pinned: false
  short_description: Stable Diffusion using Text Inversion
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Stable Diffusion using Text Inversion
+
+ A Gradio web application that generates images with Stable Diffusion, applying textual inversion style concepts and several loss functions.
+
+ ## Features
+
+ - Generate images using Stable Diffusion v1.4
+ - Apply different artistic styles using textual inversion concepts:
+   - Dreams
+   - Midjourney Style
+   - Moebius
+   - Marc Allante
+   - WLOP
+ - Automatic application of multiple loss functions:
+   - No Loss (base generation)
+   - Blue Channel Loss
+   - Elastic Transformation Loss
+   - Symmetry Loss
+   - Saturation Loss
+ - User-friendly interface with preset prompts and custom prompt input
+ - Side-by-side comparison of the different loss function effects
+
+ ## Usage
+
+ 1. Select a preset prompt or enter your own custom prompt
+ 2. Choose a style concept from the dropdown menu
+ 3. Click "Submit" to generate images
+ 4. View the results showing the different loss function effects side by side
+
+ ## Installation
+
+ 1. Clone this repository
+ 2. Install the required dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 3. Run the application:
+ ```bash
+ python app.py
+ ```
+
+ ## Requirements
+ - Python 3.7+
+ - PyTorch
+ - Diffusers
+ - Transformers
+ - Gradio
+ - Torchvision
+ - Pillow (PIL)
+
+ ## Model Details
+ The application uses the CompVis/stable-diffusion-v1-4 model with textual inversion concepts from the Hugging Face SD concepts library. It runs in float16 precision on CUDA and in float32 on MPS and CPU devices.
+
+ ## License
+ This project uses the CompVis/stable-diffusion-v1-4 model, which is subject to the CreativeML Open RAIL-M license.
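The Model Details section above describes the same pipeline setup that app.py (below) implements. As a minimal standalone sketch of that setup, assuming current diffusers APIs and using the model and concept repositories listed in the README (the prompt text and step count here are illustrative only):

```python
import torch
from diffusers import DiffusionPipeline

# Use float16 only on CUDA; fall back to float32 elsewhere, as the app does
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=dtype
).to(device)

# Load one textual inversion concept from the Hugging Face SD concepts library
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")

# Generate a single styled image; the prompt follows the app's
# "... in the style of <concept>" pattern
image = pipe(
    "A rabbit in a spacesuit in the style of midjourney-style",
    num_inference_steps=30,
).images[0]
image.save("styled_rabbit.png")
```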
app.py ADDED
@@ -0,0 +1,219 @@
+ import gradio as gr
+ import torch
+ import gc
+ from PIL import Image
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+ from diffusers import DiffusionPipeline, LMSDiscreteScheduler
+
+ # Initialize model and configurations
+ # Module-level globals so the pipeline and transforms are created only once
+ pipe = None
+ device = None
+ elastic_transformer = None
+
+ def init_model():
+     global pipe, device
+     if pipe is not None:
+         return pipe, device
+
+     torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+     torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
+
+     pipe = DiffusionPipeline.from_pretrained(
+         "CompVis/stable-diffusion-v1-4",
+         torch_dtype=torch_dtype
+     ).to(torch_device)
+
+     # Load SD concepts
+     concepts = {
+         "dreams": "sd-concepts-library/dreams",
+         "midjourney-style": "sd-concepts-library/midjourney-style",
+         "moebius": "sd-concepts-library/moebius",
+         "marc-allante": "sd-concepts-library/style-of-marc-allante",
+         "wlop": "sd-concepts-library/wlop-style"
+     }
+
+     for concept in concepts.values():
+         pipe.load_textual_inversion(concept, mean_resizing=False)
+
+     device = torch_device
+     return pipe, device
+
+ def init_transformers(device):
+     global elastic_transformer
+     if elastic_transformer is not None:
+         return elastic_transformer
+     elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0).to(device)
+     return elastic_transformer
+
+ # Compute a scalar loss for the selected loss type
+ def image_loss(images, loss_type, device, elastic_transformer):
+     if loss_type == 'blue':
+         error = torch.abs(images[:, 2] - 0.9).mean()
+         return error.to(device)
+     elif loss_type == 'elastic':
+         transformed_imgs = elastic_transformer(images)
+         error = torch.abs(transformed_imgs - images).mean()
+         return error.to(device)
+     elif loss_type == 'symmetry':
+         flipped_image = torch.flip(images, [3])
+         error = F.mse_loss(images, flipped_image)
+         return error.to(device)
+     elif loss_type == 'saturation':
+         transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
+         error = torch.abs(transformed_imgs - images).mean()
+         return error.to(device)
+     else:
+         return torch.tensor(0.0).to(device)
+
+ def generate_images(prompt, concept):
+     global pipe, device, elastic_transformer
+     if pipe is None:
+         pipe, device = init_model()
+     if elastic_transformer is None:
+         elastic_transformer = init_transformers(device)
+
+     # Configuration
+     height, width = 384, 384
+     guidance_scale = 8
+     num_inference_steps = 45
+     loss_scale = 10.0
+
+     # Create scheduler
+     scheduler = LMSDiscreteScheduler(
+         beta_start=0.00085,
+         beta_end=0.012,
+         beta_schedule="scaled_linear",
+         num_train_timesteps=1000
+     )
+     pipe.scheduler = scheduler  # Set the scheduler
+
+     # Create prompt text
+     prompt_text = f"{prompt} {concept}"
+
+     # Predefined seeds for each loss function
+     seeds = {
+         'none': 42,
+         'blue': 123,
+         'elastic': 456,
+         'symmetry': 789,
+         'saturation': 1000
+     }
+
+     loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
+     images = []
+     progress = gr.Progress()
+
+     # Generate an image for each loss function
+     for idx, loss_type in enumerate(loss_functions):
+         progress(idx / len(loss_functions), f"Generating {loss_type} image...")
+         generator = torch.manual_seed(seeds[loss_type])
+
+         # Generate base image
+         try:
+             output = pipe(
+                 prompt_text,
+                 height=height,
+                 width=width,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 generator=generator
+             )
+         except Exception as e:
+             print(f"Error generating image: {e}")
+             return None
+
+         # Apply loss function if not 'none'
+         if loss_type != 'none':
+             try:
+                 # Convert PIL image to tensor and move to device
+                 image_tensor = T.ToTensor()(output.images[0]).unsqueeze(0).to(device)
+                 # Compute the scalar loss and subtract it from every pixel
+                 loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
+                 image_tensor = image_tensor - loss_scale * loss
+                 # Move back to CPU and convert to PIL
+                 image = T.ToPILImage()(image_tensor.cpu().squeeze(0).clamp(0, 1))
+             except Exception as e:
+                 print(f"Error applying {loss_type} loss: {e}")
+                 image = output.images[0]  # Use original image if loss fails
+         else:
+             image = output.images[0]
+
+         # Add image with its label
+         try:
+             # Ensure image is in correct format (PIL.Image)
+             if not isinstance(image, Image.Image):
+                 print(f"Warning: Converting {loss_type} image to PIL format")
+                 image = Image.fromarray(image)
+
+             # Add tuple of (image, label) to list
+             images.append((image, f"{loss_type.capitalize()} Loss"))
+             print(f"Added {loss_type} image to gallery")  # Debug print
+         except Exception as e:
+             print(f"Error adding {loss_type} image to gallery: {e}")
+             continue
+
+         # Clear GPU memory after each image
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         gc.collect()
+
+     # Return all generated images
+     print(f"Returning {len(images)} images")
+     if not images:
+         return None
+     return images
+
+ def create_interface():
+     default_prompts = [
+         "A realistic image of Boy with a cowboy hat in the style of",
+         "A realistic image of Rabbit in a spacesuit in the style of",
+         "A rugged soldier in full combat gear, standing on a battlefield at dusk, dramatic lighting, highly detailed, cinematic style in the style of"
+     ]
+
+     concepts = [
+         "dreams",
+         "midjourney-style",
+         "moebius",
+         "marc-allante",
+         "wlop"
+     ]
+
+     interface = gr.Interface(
+         fn=generate_images,
+         inputs=[
+             gr.Dropdown(choices=default_prompts, label="Select a preset prompt or type your own", allow_custom_value=True),
+             gr.Dropdown(choices=concepts, label="Select SD Concept")
+         ],
+         outputs=gr.Gallery(
+             label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
+             show_label=True,
+             elem_id="gallery",
+             columns=5,
+             rows=1,
+             height=512,
+             object_fit="contain"
+         ),
+         title="Stable Diffusion using Text Inversion",
+         description="""Generate images using Stable Diffusion with different style concepts. The output shows 5 images side by side:
+ 1. Original Image (No Loss)
+ 2. Blue Channel Loss - Enhances blue tones
+ 3. Elastic Loss - Adds elastic deformation
+ 4. Symmetry Loss - Enforces symmetrical features
+ 5. Saturation Loss - Modifies color saturation
+
+ Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
+         flagging_mode="never"  # flagging_mode replaces the deprecated allow_flagging argument
+     )
+
+     return interface
+
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.queue(max_size=5)  # Limit the request queue to five pending jobs
+     interface.launch(
+         share=True,
+         server_name="0.0.0.0",
+         max_threads=1
+     )
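Note that generate_images applies each loss as a post-processing step on the decoded image: image_loss returns a single scalar, and loss_scale * loss is subtracted from every pixel before the result is clamped back to [0, 1]. A minimal sketch of that step in isolation, using a random tensor as a stand-in for the decoded image:

```python
import torch

# Stand-in for a decoded RGB image batch in [0, 1], shape (N, C, H, W)
images = torch.rand(1, 3, 64, 64)

# Blue-channel loss as defined in image_loss(): mean |blue - 0.9|
loss = torch.abs(images[:, 2] - 0.9).mean()

# The app's update: subtract the scaled scalar loss from every pixel, then clamp
loss_scale = 10.0
adjusted = (images - loss_scale * loss).clamp(0, 1)

print(f"loss={loss.item():.4f}, "
      f"adjusted range=({adjusted.min().item():.3f}, {adjusted.max().item():.3f})")
```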
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ diffusers
+ transformers
+ gradio
+ torchvision
+ Pillow
+ scipy