Commit 94be4c7 · Damian Stewart committed · 1 Parent(s): b58675c

add train seed
Files changed (4)
  1. StableDiffuser.py +8 -5
  2. app.py +27 -14
  3. memory_efficiency.py +1 -1
  4. train.py +18 -4
StableDiffuser.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from baukit import TraceDict
 from diffusers import StableDiffusionPipeline
 from PIL import Image
+from torch.cuda.amp import autocast
 from tqdm.auto import tqdm
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
@@ -142,6 +143,7 @@ class StableDiffuser(torch.nn.Module):
                   pred_x0=False,
                   trace_args=None,
                   show_progress=True,
+                  use_amp=False,
                   **kwargs):
 
         latents_steps = []
@@ -153,11 +155,12 @@ class StableDiffuser(torch.nn.Module):
             if trace_args:
                 trace = TraceDict(self, **trace_args)
 
-            noise_pred = self.predict_noise(
-                iteration,
-                latents,
-                text_embeddings,
-                **kwargs)
+            with autocast(enabled=use_amp):
+                noise_pred = self.predict_noise(
+                    iteration,
+                    latents,
+                    text_embeddings,
+                    **kwargs)
 
             # compute the previous noisy sample x_t -> x_t-1
             output = self.scheduler.step(noise_pred, self.scheduler.timesteps[iteration], latents)
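
Note: the new `use_amp` flag simply gates PyTorch's autocast context around the noise prediction. A minimal sketch of the pattern, assuming a diffusers-style UNet whose forward takes (sample, timestep, encoder_hidden_states); the helper name denoise_step is illustrative, not part of the commit:

    from torch.cuda.amp import autocast

    def denoise_step(unet, latents, t, text_embeddings, use_amp=False):
        # autocast(enabled=False) is a transparent no-op, so the same code
        # path serves both full-precision and mixed-precision runs
        with autocast(enabled=use_amp):
            noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample
        return noise_pred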
app.py CHANGED
@@ -191,12 +191,20 @@ class Demo:
                         label="Learning Rate",
                         info='Learning rate used to train'
                     )
+                    self.train_seed_input = gr.Number(
+                        value=-1,
+                        label="Seed",
+                        info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
+                    )
 
-                with gr.Row():
-                    self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=False)
-                    self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
-                    self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
-                    #self.train_use_gradient_checkpointing_input = gr.Checkbox(label="Gradient checkpointing", value=True)
+                with gr.Column():
+                    self.train_memory_options = gr.Markdown(interactive=False,
+                        value='Performance and VRAM usage optimizations, may not work on all devices.')
+                    with gr.Row():
+                        self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
+                        self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
+                        self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
+                        self.train_use_gradient_checkpointing_input = gr.Checkbox(label="Gradient checkpointing", value=True)
 
                 with gr.Column(scale=1):
@@ -209,16 +217,13 @@ class Demo:
             self.download = gr.Files()
 
         with gr.Tab("Export") as export_column:
-
             with gr.Row():
-
                 self.explain_train= gr.Markdown(interactive=False,
-                                                value='Export a model to Diffusers format. Please enter the base model and select the editing weights.')
+                    value='Export a model to Diffusers format. Please enter the base model and select the editing weights.')
 
             with gr.Row():
 
                 with gr.Column(scale=3):
-
                     self.base_repo_id_or_path_input_export = gr.Text(
                         label="Base model",
                         value="CompVis/stable-diffusion-v1-4",
@@ -272,7 +277,8 @@ class Demo:
                 self.train_use_adamw8bit_input,
                 self.train_use_xformers_input,
                 self.train_use_amp_input,
-                #self.train_use_gradient_checkpointing_input
+                self.train_use_gradient_checkpointing_input,
+                self.train_seed_input,
             ],
             outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
@@ -287,6 +293,7 @@ class Demo:
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
               use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
+              seed = -1,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
@@ -311,19 +318,25 @@ class Demo:
         modules = ".*attn1$"
         frozen = []
 
-        randn = torch.randint(1, 10000000, (1,)).item()
-        save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}_{train_method}_ng{neg_guidance}_lr{lr}_iter{iterations}.pt"
+        # build a save path, ensure it isn't in use
+        while True:
+            randn = torch.randint(1, 10000000, (1,)).item()
+            options = f'{"a8" if use_adamw8bit else ""}{"AM" if use_amp else ""}{"xf" if use_xformers else ""}{"gc" if use_gradient_checkpointing else ""}'
+            save_path = f"models/{prompt.lower().replace(' ', '')}_{train_method}_ng{neg_guidance}_lr{lr}_iter{iterations}_seed{seed}_{options}__{randn}.pt"
+            if not os.path.exists(save_path):
+                break
+        # repeat until a not-in-use path is found
 
         try:
             self.training = True
             train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
-                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing)
+                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=seed)
         finally:
             self.training = False
 
         torch.cuda.empty_cache()
 
-        new_model_name = f'*new* {os.path.basename(save_path)}'
+        new_model_name = f'{os.path.basename(save_path)}'
         model_map[new_model_name] = save_path
 
         return [gr.update(interactive=True, value='Train'),
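
The new save-path scheme encodes the prompt, hyperparameters, seed, and enabled optimizations in the filename, appends a random suffix, and retries on collision. A standalone sketch of the same logic; the helper name make_save_path is illustrative, while the option abbreviations (a8/AM/xf/gc) follow the diff:

    import os
    import torch

    def make_save_path(prompt, train_method, neg_guidance, lr, iterations, seed,
                       use_adamw8bit, use_amp, use_xformers, use_gradient_checkpointing):
        # short flags record which optimizations were active for this run
        options = (("a8" if use_adamw8bit else "") + ("AM" if use_amp else "")
                   + ("xf" if use_xformers else "") + ("gc" if use_gradient_checkpointing else ""))
        while True:
            randn = torch.randint(1, 10000000, (1,)).item()  # random suffix
            path = (f"models/{prompt.lower().replace(' ', '')}_{train_method}"
                    f"_ng{neg_guidance}_lr{lr}_iter{iterations}_seed{seed}_{options}__{randn}.pt")
            if not os.path.exists(path):  # retry until the name is unused
                return path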
memory_efficiency.py CHANGED
@@ -37,7 +37,7 @@ class MemoryEfficiencyWrapper:
                 print("failed to load xformers, using attention slicing instead")
                 self.diffuser.unet.set_attention_slice("auto")
                 pass
-            elif (not self.amp and self.is_sd1attn):
+            elif (not self.use_amp and self.is_sd1attn):
                 print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                 self.diffuser.unet.set_attention_slice("auto")
             else:
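
For context, the attribute this one-line fix renames drives the wrapper's choice between xformers and attention slicing. A hedged sketch of that decision, assuming a diffusers UNet; apply_attention_optimizations and its flag names are illustrative, only set_attention_slice and enable_xformers_memory_efficient_attention are real diffusers calls:

    def apply_attention_optimizations(unet, want_xformers, use_amp, is_sd1attn):
        # per the wrapper's log message, SD1.x attention without AMP falls
        # back to diffusers' built-in attention slicing instead of xformers
        if not want_xformers or (not use_amp and is_sd1attn):
            unet.set_attention_slice("auto")
        else:
            unet.enable_xformers_memory_efficient_attention()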
train.py CHANGED
@@ -1,3 +1,6 @@
+import random
+
+from accelerate.utils import set_seed
 from torch.cuda.amp import autocast
 
 from StableDiffuser import StableDiffuser
@@ -9,7 +12,7 @@ from memory_efficiency import MemoryEfficiencyWrapper
 
 
 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
-          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
+          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
 
     nsteps = 50
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
@@ -47,6 +50,10 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
 
     print(f"using img_size of {img_size}")
 
+    if seed == -1:
+        seed = random.randint(0, 2 ** 30)
+    set_seed(seed)
+
     for i in pbar:
         with torch.no_grad():
             diffuser.set_scheduler_timesteps(nsteps)
@@ -55,14 +62,15 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
             iteration = torch.randint(1, nsteps - 1, (1,)).item()
             latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
 
-            with autocast(enabled=use_amp), finetuner:
+            with finetuner:
                 latents_steps, _ = diffuser.diffusion(
                     latents,
                     positive_text_embeddings,
                     start_iteration=0,
                     end_iteration=iteration,
                     guidance_scale=3,
-                    show_progress=False
+                    show_progress=False,
+                    use_amp=use_amp
                 )
 
             diffuser.set_scheduler_timesteps(1000)
@@ -82,7 +90,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         # loss = criteria(e_n, e_0) works the best try 5000 epochs
         loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
         memory_efficiency_wrapper.step(optimizer, loss)
-        optimizer.step()
+        optimizer.zero_grad()
 
     torch.save(finetuner.state_dict(), save_path)
@@ -104,5 +112,11 @@ if __name__ == '__main__':
     parser.add_argument('--iterations', type=int, required=True)
     parser.add_argument('--lr', type=float, required=True)
     parser.add_argument('--negative_guidance', type=float, required=True)
+    parser.add_argument('--seed', type=int, required=False, default=-1,
+                        help='Training seed for reproducible results, or -1 to pick a random seed')
+    parser.add_argument('--use_adamw8bit', action='store_true')
+    parser.add_argument('--use_xformers', action='store_true')
+    parser.add_argument('--use_amp', action='store_true')
+    parser.add_argument('--use_gradient_checkpointing', action='store_true')
 
     train(**vars(parser.parse_args()))
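
A minimal, self-contained sketch of the seeding flow train.py now performs, where -1 means "draw a fresh seed"; the helper name resolve_and_set_seed is illustrative, and accelerate.utils.set_seed seeds Python, NumPy, and torch (including CUDA) in one call:

    import random

    from accelerate.utils import set_seed

    def resolve_and_set_seed(seed: int = -1) -> int:
        if seed == -1:
            # draw a fresh seed so even "random" runs can be reproduced later
            seed = random.randint(0, 2 ** 30)
        set_seed(seed)
        return seed  # surface the effective seed for logging

    effective_seed = resolve_and_set_seed(-1)
    print(f"training with seed {effective_seed}")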