RohitGandikota commited on
Commit
47a88ae
·
1 Parent(s): 6491cdf

pushing training code

Browse files
__init__.py CHANGED
@@ -1 +1,2 @@
1
- from trainscripts.textsliders import lora
 
 
1
+ from trainscripts.textsliders import lora
2
+ from trainscripts.textsliders import demotrain
app.py CHANGED
@@ -6,7 +6,7 @@ from diffusers.pipelines import StableDiffusionXLPipeline
6
  StableDiffusionXLPipeline.__call__ = call
7
  import os
8
  from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
9
-
10
 
11
  os.environ['CURL_CA_BUNDLE'] = ''
12
  model_map = {'Age' : 'models/age.pt',
@@ -204,10 +204,26 @@ class Demo:
204
  )
205
 
206
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- # if self.training:
209
- # return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
210
 
 
 
 
 
211
  # if train_method == 'ESD-x':
212
 
213
  # modules = ".*attn2$"
@@ -223,7 +239,7 @@ class Demo:
223
  # modules = ".*attn1$"
224
  # frozen = []
225
 
226
- # randn = torch.randint(1, 10000000, (1,)).item()
227
 
228
  # save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
229
 
@@ -237,7 +253,7 @@ class Demo:
237
 
238
  # model_map['Custom'] = save_path
239
 
240
- # return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
241
  return [None, None, None, None]
242
 
243
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
@@ -267,10 +283,12 @@ class Demo:
267
  name = os.path.basename(model_path)
268
  rank = 4
269
  alpha = 1
270
- if 'rank4' in model_path:
271
- rank = 4
272
- if 'rank8' in model_path:
273
- rank = 8
 
 
274
  if 'alpha1' in model_path:
275
  alpha = 1.0
276
  network = LoRANetwork(
 
6
  StableDiffusionXLPipeline.__call__ = call
7
  import os
8
  from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
9
+ from trainscripts.textsliders.demotrain import train_xl
10
 
11
  os.environ['CURL_CA_BUNDLE'] = ''
12
  model_map = {'Age' : 'models/age.pt',
 
204
  )
205
 
206
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
207
+
208
+
209
+ randn = torch.randint(1, 10000000, (1,)).item()
210
+ save_name = f'{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}'
211
+ save_name += f'_alpha-{1}'
212
+ save_name += f'_noxattn'
213
+ save_name += f'_rank_{rank}.pt'
214
+
215
+ if self.training:
216
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
217
+
218
+ self.training = True
219
+ train_xl(target, postive, negative, lr, iterations, config_file, rank, device, attributes)
220
 
221
+ self.training = False
 
222
 
223
+ torch.cuda.empty_cache()
224
+ model_map['Custom Slider'] = f'models/{save_name}'
225
+
226
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom Slider')]
227
  # if train_method == 'ESD-x':
228
 
229
  # modules = ".*attn2$"
 
239
  # modules = ".*attn1$"
240
  # frozen = []
241
 
242
+ #
243
 
244
  # save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
245
 
 
253
 
254
  # model_map['Custom'] = save_path
255
 
256
+ #
257
  return [None, None, None, None]
258
 
259
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
 
283
  name = os.path.basename(model_path)
284
  rank = 4
285
  alpha = 1
286
+ if rank in model_path:
287
+ rank = int(model_path.split('_')[-1].replace('.pt',''))
288
+ # if 'rank4' in model_path:
289
+ # rank = 4
290
+ # if 'rank8' in model_path:
291
+ # rank = 8
292
  if 'alpha1' in model_path:
293
  alpha = 1.0
294
  network = LoRANetwork(
trainscripts/textsliders/data/config-xl.yaml CHANGED
@@ -19,7 +19,7 @@ train:
19
  save:
20
  name: "temp"
21
  path: "./models"
22
- per_steps: 500
23
  precision: "bfloat16"
24
  logging:
25
  use_wandb: false
 
19
  save:
20
  name: "temp"
21
  path: "./models"
22
+ per_steps: 5000000
23
  precision: "bfloat16"
24
  logging:
25
  use_wandb: false
trainscripts/textsliders/data/prompts-xl.yaml CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  ####################################################################################################### AGE SLIDER
2
  # - target: "male person" # what word for erasing the positive concept from
3
  # positive: "male person, very old" # concept to erase
@@ -257,24 +266,24 @@
257
  # dynamic_resolution: false
258
  # batch_size: 1
259
  ####################################################################################################### SCULPTURE SLIDER
260
- - target: "male person" # what word for erasing the positive concept from
261
- positive: "male person, cement sculpture, cement greek statue style" # concept to erase
262
- unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
263
- neutral: "male person" # starting point for conditioning the target
264
- action: "enhance" # erase or enhance
265
- guidance_scale: 4
266
- resolution: 512
267
- dynamic_resolution: false
268
- batch_size: 1
269
- - target: "female person" # what word for erasing the positive concept from
270
- positive: "female person, cement sculpture, cement greek statue style" # concept to erase
271
- unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
272
- neutral: "female person" # starting point for conditioning the target
273
- action: "enhance" # erase or enhance
274
- guidance_scale: 4
275
- resolution: 512
276
- dynamic_resolution: false
277
- batch_size: 1
278
  ####################################################################################################### METAL SLIDER
279
  # - target: "" # what word for erasing the positive concept from
280
  # positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase
 
1
+ - target: "" # what word for erasing the positive concept from
2
+ positive: "" # concept to erase
3
+ unconditional: "" # word to take the difference from the positive concept
4
+ neutral: "" # starting point for conditioning the target
5
+ action: "enhance" # erase or enhance
6
+ guidance_scale: 4
7
+ resolution: 512
8
+ dynamic_resolution: false
9
+ batch_size: 1
10
  ####################################################################################################### AGE SLIDER
11
  # - target: "male person" # what word for erasing the positive concept from
12
  # positive: "male person, very old" # concept to erase
 
266
  # dynamic_resolution: false
267
  # batch_size: 1
268
  ####################################################################################################### SCULPTURE SLIDER
269
+ # - target: "male person" # what word for erasing the positive concept from
270
+ # positive: "male person, cement sculpture, cement greek statue style" # concept to erase
271
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
272
+ # neutral: "male person" # starting point for conditioning the target
273
+ # action: "enhance" # erase or enhance
274
+ # guidance_scale: 4
275
+ # resolution: 512
276
+ # dynamic_resolution: false
277
+ # batch_size: 1
278
+ # - target: "female person" # what word for erasing the positive concept from
279
+ # positive: "female person, cement sculpture, cement greek statue style" # concept to erase
280
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
281
+ # neutral: "female person" # starting point for conditioning the target
282
+ # action: "enhance" # erase or enhance
283
+ # guidance_scale: 4
284
+ # resolution: 512
285
+ # dynamic_resolution: false
286
+ # batch_size: 1
287
  ####################################################################################################### METAL SLIDER
288
  # - target: "" # what word for erasing the positive concept from
289
  # positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase
trainscripts/textsliders/demotrain.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+
15
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import train_util
17
+ import model_util
18
+ import prompt_util
19
+ from prompt_util import (
20
+ PromptEmbedsCache,
21
+ PromptEmbedsPair,
22
+ PromptSettings,
23
+ PromptEmbedsXL,
24
+ )
25
+ import debug_util
26
+ import config_util
27
+ from config_util import RootConfig
28
+
29
+ import wandb
30
+
31
+ NUM_IMAGES_PER_PROMPT = 1
32
+
33
+
34
+ def flush():
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+
39
+ def train(
40
+ config: RootConfig,
41
+ prompts: list[PromptSettings],
42
+ device,
43
+ ):
44
+ metadata = {
45
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
46
+ "config": config.json(),
47
+ }
48
+ save_path = Path(config.save.path)
49
+
50
+ modules = DEFAULT_TARGET_REPLACE
51
+ if config.network.type == "c3lier":
52
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
53
+
54
+ if config.logging.verbose:
55
+ print(metadata)
56
+
57
+ if config.logging.use_wandb:
58
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
59
+
60
+ weight_dtype = config_util.parse_precision(config.train.precision)
61
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
62
+
63
+ (
64
+ tokenizers,
65
+ text_encoders,
66
+ unet,
67
+ noise_scheduler,
68
+ ) = model_util.load_models_xl(
69
+ config.pretrained_model.name_or_path,
70
+ scheduler_name=config.train.noise_scheduler,
71
+ )
72
+
73
+ for text_encoder in text_encoders:
74
+ text_encoder.to(device, dtype=weight_dtype)
75
+ text_encoder.requires_grad_(False)
76
+ text_encoder.eval()
77
+
78
+ unet.to(device, dtype=weight_dtype)
79
+ if config.other.use_xformers:
80
+ unet.enable_xformers_memory_efficient_attention()
81
+ unet.requires_grad_(False)
82
+ unet.eval()
83
+
84
+ network = LoRANetwork(
85
+ unet,
86
+ rank=config.network.rank,
87
+ multiplier=1.0,
88
+ alpha=config.network.alpha,
89
+ train_method=config.network.training_method,
90
+ ).to(device, dtype=weight_dtype)
91
+
92
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
93
+ #optimizer_args
94
+ optimizer_kwargs = {}
95
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
96
+ for arg in config.train.optimizer_args.split(" "):
97
+ key, value = arg.split("=")
98
+ value = ast.literal_eval(value)
99
+ optimizer_kwargs[key] = value
100
+
101
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
102
+ lr_scheduler = train_util.get_lr_scheduler(
103
+ config.train.lr_scheduler,
104
+ optimizer,
105
+ max_iterations=config.train.iterations,
106
+ lr_min=config.train.lr / 100,
107
+ )
108
+ criteria = torch.nn.MSELoss()
109
+
110
+ print("Prompts")
111
+ for settings in prompts:
112
+ print(settings)
113
+
114
+ # debug
115
+ debug_util.check_requires_grad(network)
116
+ debug_util.check_training_mode(network)
117
+
118
+ cache = PromptEmbedsCache()
119
+ prompt_pairs: list[PromptEmbedsPair] = []
120
+
121
+ with torch.no_grad():
122
+ for settings in prompts:
123
+ print(settings)
124
+ for prompt in [
125
+ settings.target,
126
+ settings.positive,
127
+ settings.neutral,
128
+ settings.unconditional,
129
+ ]:
130
+ if cache[prompt] == None:
131
+ tex_embs, pool_embs = train_util.encode_prompts_xl(
132
+ tokenizers,
133
+ text_encoders,
134
+ [prompt],
135
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
136
+ )
137
+ cache[prompt] = PromptEmbedsXL(
138
+ tex_embs,
139
+ pool_embs
140
+ )
141
+
142
+ prompt_pairs.append(
143
+ PromptEmbedsPair(
144
+ criteria,
145
+ cache[settings.target],
146
+ cache[settings.positive],
147
+ cache[settings.unconditional],
148
+ cache[settings.neutral],
149
+ settings,
150
+ )
151
+ )
152
+
153
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
154
+ del tokenizer, text_encoder
155
+
156
+ flush()
157
+
158
+ pbar = tqdm(range(config.train.iterations))
159
+
160
+ loss = None
161
+
162
+ for i in pbar:
163
+ with torch.no_grad():
164
+ noise_scheduler.set_timesteps(
165
+ config.train.max_denoising_steps, device=device
166
+ )
167
+
168
+ optimizer.zero_grad()
169
+
170
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
171
+ torch.randint(0, len(prompt_pairs), (1,)).item()
172
+ ]
173
+
174
+ # 1 ~ 49 からランダム
175
+ timesteps_to = torch.randint(
176
+ 1, config.train.max_denoising_steps, (1,)
177
+ ).item()
178
+
179
+ height, width = prompt_pair.resolution, prompt_pair.resolution
180
+ if prompt_pair.dynamic_resolution:
181
+ height, width = train_util.get_random_resolution_in_bucket(
182
+ prompt_pair.resolution
183
+ )
184
+
185
+ if config.logging.verbose:
186
+ print("gudance_scale:", prompt_pair.guidance_scale)
187
+ print("resolution:", prompt_pair.resolution)
188
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
189
+ if prompt_pair.dynamic_resolution:
190
+ print("bucketed resolution:", (height, width))
191
+ print("batch_size:", prompt_pair.batch_size)
192
+ print("dynamic_crops:", prompt_pair.dynamic_crops)
193
+
194
+ latents = train_util.get_initial_latents(
195
+ noise_scheduler, prompt_pair.batch_size, height, width, 1
196
+ ).to(device, dtype=weight_dtype)
197
+
198
+ add_time_ids = train_util.get_add_time_ids(
199
+ height,
200
+ width,
201
+ dynamic_crops=prompt_pair.dynamic_crops,
202
+ dtype=weight_dtype,
203
+ ).to(device, dtype=weight_dtype)
204
+
205
+ with network:
206
+ # ちょっとデノイズされれたものが返る
207
+ denoised_latents = train_util.diffusion_xl(
208
+ unet,
209
+ noise_scheduler,
210
+ latents, # 単純なノイズのlatentsを渡す
211
+ text_embeddings=train_util.concat_embeddings(
212
+ prompt_pair.unconditional.text_embeds,
213
+ prompt_pair.target.text_embeds,
214
+ prompt_pair.batch_size,
215
+ ),
216
+ add_text_embeddings=train_util.concat_embeddings(
217
+ prompt_pair.unconditional.pooled_embeds,
218
+ prompt_pair.target.pooled_embeds,
219
+ prompt_pair.batch_size,
220
+ ),
221
+ add_time_ids=train_util.concat_embeddings(
222
+ add_time_ids, add_time_ids, prompt_pair.batch_size
223
+ ),
224
+ start_timesteps=0,
225
+ total_timesteps=timesteps_to,
226
+ guidance_scale=3,
227
+ )
228
+
229
+ noise_scheduler.set_timesteps(1000)
230
+
231
+ current_timestep = noise_scheduler.timesteps[
232
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
233
+ ]
234
+
235
+ # with network: の外では空のLoRAのみが有効になる
236
+ positive_latents = train_util.predict_noise_xl(
237
+ unet,
238
+ noise_scheduler,
239
+ current_timestep,
240
+ denoised_latents,
241
+ text_embeddings=train_util.concat_embeddings(
242
+ prompt_pair.unconditional.text_embeds,
243
+ prompt_pair.positive.text_embeds,
244
+ prompt_pair.batch_size,
245
+ ),
246
+ add_text_embeddings=train_util.concat_embeddings(
247
+ prompt_pair.unconditional.pooled_embeds,
248
+ prompt_pair.positive.pooled_embeds,
249
+ prompt_pair.batch_size,
250
+ ),
251
+ add_time_ids=train_util.concat_embeddings(
252
+ add_time_ids, add_time_ids, prompt_pair.batch_size
253
+ ),
254
+ guidance_scale=1,
255
+ ).to(device, dtype=weight_dtype)
256
+ neutral_latents = train_util.predict_noise_xl(
257
+ unet,
258
+ noise_scheduler,
259
+ current_timestep,
260
+ denoised_latents,
261
+ text_embeddings=train_util.concat_embeddings(
262
+ prompt_pair.unconditional.text_embeds,
263
+ prompt_pair.neutral.text_embeds,
264
+ prompt_pair.batch_size,
265
+ ),
266
+ add_text_embeddings=train_util.concat_embeddings(
267
+ prompt_pair.unconditional.pooled_embeds,
268
+ prompt_pair.neutral.pooled_embeds,
269
+ prompt_pair.batch_size,
270
+ ),
271
+ add_time_ids=train_util.concat_embeddings(
272
+ add_time_ids, add_time_ids, prompt_pair.batch_size
273
+ ),
274
+ guidance_scale=1,
275
+ ).to(device, dtype=weight_dtype)
276
+ unconditional_latents = train_util.predict_noise_xl(
277
+ unet,
278
+ noise_scheduler,
279
+ current_timestep,
280
+ denoised_latents,
281
+ text_embeddings=train_util.concat_embeddings(
282
+ prompt_pair.unconditional.text_embeds,
283
+ prompt_pair.unconditional.text_embeds,
284
+ prompt_pair.batch_size,
285
+ ),
286
+ add_text_embeddings=train_util.concat_embeddings(
287
+ prompt_pair.unconditional.pooled_embeds,
288
+ prompt_pair.unconditional.pooled_embeds,
289
+ prompt_pair.batch_size,
290
+ ),
291
+ add_time_ids=train_util.concat_embeddings(
292
+ add_time_ids, add_time_ids, prompt_pair.batch_size
293
+ ),
294
+ guidance_scale=1,
295
+ ).to(device, dtype=weight_dtype)
296
+
297
+ if config.logging.verbose:
298
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
299
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
300
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
301
+
302
+ with network:
303
+ target_latents = train_util.predict_noise_xl(
304
+ unet,
305
+ noise_scheduler,
306
+ current_timestep,
307
+ denoised_latents,
308
+ text_embeddings=train_util.concat_embeddings(
309
+ prompt_pair.unconditional.text_embeds,
310
+ prompt_pair.target.text_embeds,
311
+ prompt_pair.batch_size,
312
+ ),
313
+ add_text_embeddings=train_util.concat_embeddings(
314
+ prompt_pair.unconditional.pooled_embeds,
315
+ prompt_pair.target.pooled_embeds,
316
+ prompt_pair.batch_size,
317
+ ),
318
+ add_time_ids=train_util.concat_embeddings(
319
+ add_time_ids, add_time_ids, prompt_pair.batch_size
320
+ ),
321
+ guidance_scale=1,
322
+ ).to(device, dtype=weight_dtype)
323
+
324
+ if config.logging.verbose:
325
+ print("target_latents:", target_latents[0, 0, :5, :5])
326
+
327
+ positive_latents.requires_grad = False
328
+ neutral_latents.requires_grad = False
329
+ unconditional_latents.requires_grad = False
330
+
331
+ loss = prompt_pair.loss(
332
+ target_latents=target_latents,
333
+ positive_latents=positive_latents,
334
+ neutral_latents=neutral_latents,
335
+ unconditional_latents=unconditional_latents,
336
+ )
337
+
338
+ # 1000倍しないとずっと0.000...になってしまって見た目的に面白くない
339
+ pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}")
340
+ if config.logging.use_wandb:
341
+ wandb.log(
342
+ {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
343
+ )
344
+
345
+ loss.backward()
346
+ optimizer.step()
347
+ lr_scheduler.step()
348
+
349
+ del (
350
+ positive_latents,
351
+ neutral_latents,
352
+ unconditional_latents,
353
+ target_latents,
354
+ latents,
355
+ )
356
+ flush()
357
+
358
+ # if (
359
+ # i % config.save.per_steps == 0
360
+ # and i != 0
361
+ # and i != config.train.iterations - 1
362
+ # ):
363
+ # print("Saving...")
364
+ # save_path.mkdir(parents=True, exist_ok=True)
365
+ # network.save_weights(
366
+ # save_path / f"{config.save.name}_{i}steps.pt",
367
+ # dtype=save_weight_dtype,
368
+ # )
369
+
370
+ print("Saving...")
371
+ save_path.mkdir(parents=True, exist_ok=True)
372
+ network.save_weights(
373
+ save_path / f"{config.save.name}",
374
+ dtype=save_weight_dtype,
375
+ )
376
+
377
+ del (
378
+ unet,
379
+ noise_scheduler,
380
+ loss,
381
+ optimizer,
382
+ network,
383
+ )
384
+
385
+ flush()
386
+
387
+ print("Done.")
388
+
389
+
390
+ # def main(args):
391
+ # config_file = args.config_file
392
+
393
+ # config = config_util.load_config_from_yaml(config_file)
394
+ # if args.name is not None:
395
+ # config.save.name = args.name
396
+ # attributes = []
397
+ # if args.attributes is not None:
398
+ # attributes = args.attributes.split(',')
399
+ # attributes = [a.strip() for a in attributes]
400
+
401
+ # config.network.alpha = args.alpha
402
+ # config.network.rank = args.rank
403
+ # config.save.name += f'_alpha{args.alpha}'
404
+ # config.save.name += f'_rank{config.network.rank }'
405
+ # config.save.name += f'_{config.network.training_method}'
406
+ # config.save.path += f'/{config.save.name}'
407
+
408
+ # prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
409
+
410
+ # device = torch.device(f"cuda:{args.device}")
411
+ # train(config, prompts, device)
412
+
413
+
414
+ def train_xl(target, postive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
415
+
416
+ config = config_util.load_config_from_yaml(config_file)
417
+ randn = torch.randint(1, 10000000, (1,)).item()
418
+ config.save.name = save_name
419
+
420
+ config.train.lr = float(lr)
421
+ config.train.iterations=int(iterations)
422
+
423
+ if attributes is not None:
424
+ attributes = attributes.split(',')
425
+ attributes = [a.strip() for a in attributes]
426
+ config.network.alpha = 1.0
427
+ config.network.rank = rank
428
+
429
+ config.save.path += f'/{config.save.name}'
430
+
431
+ prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
432
+
433
+ device = torch.device(f"cuda:{device}")
434
+ train(config, prompts, device)
trainscripts/textsliders/prompt_util.py CHANGED
@@ -148,9 +148,18 @@ class PromptEmbedsPair:
148
  raise ValueError("action must be erase or enhance")
149
 
150
 
151
- def load_prompts_from_yaml(path, attributes = []):
152
  with open(path, "r") as f:
153
  prompts = yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
154
  print(prompts)
155
  if len(prompts) == 0:
156
  raise ValueError("prompts file is empty")
 
148
  raise ValueError("action must be erase or enhance")
149
 
150
 
151
+ def load_prompts_from_yaml(path, target, positive, negative, attributes = []):
152
  with open(path, "r") as f:
153
  prompts = yaml.safe_load(f)
154
+ new = []
155
+ for prompt in prompts:
156
+ copy_ = copy.deepcopy(prompt)
157
+ copy_['target'] = target
158
+ copy_['positive'] = positive
159
+ copy_['neutral'] = target
160
+ copy_['unconditional'] = negative
161
+ new.append(copy_)
162
+ prompts = new
163
  print(prompts)
164
  if len(prompts) == 0:
165
  raise ValueError("prompts file is empty")