"
#
# pbar = tqdm.tqdm(total=steps - initial_step)
# try:
# sd_hijack_checkpoint.add()
#
# for _ in range((steps-initial_step) * gradient_step):
# if scheduler.finished:
# break
# if shared.state.interrupted:
# break
# for j, batch in enumerate(dl):
# # works as a drop_last=True for gradient accumulation
# if j == max_steps_per_epoch:
# break
# scheduler.apply(optimizer, hypernetwork.step)
# if scheduler.finished:
# break
# if shared.state.interrupted:
# break
#
# if clip_grad:
# clip_grad_sched.step(hypernetwork.step)
#
# with devices.autocast():
# x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
# if use_weight:
# w = batch.weight.to(devices.device, non_blocking=pin_memory)
# if tag_drop_out != 0 or shuffle_tags:
# shared.sd_model.cond_stage_model.to(devices.device)
# c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
# shared.sd_model.cond_stage_model.to(devices.cpu)
# else:
# c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
# if use_weight:
# loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
# del w
# else:
# loss = shared.sd_model.forward(x, c)[0] / gradient_step
# del x
# del c
#
# _loss_step += loss.item()
# scaler.scale(loss).backward()
#
# # go back until we reach gradient accumulation steps
# if (j + 1) % gradient_step != 0:
# continue
# loss_logging.append(_loss_step)
# if clip_grad:
# clip_grad(weights, clip_grad_sched.learn_rate)
#
# scaler.step(optimizer)
# scaler.update()
# hypernetwork.step += 1
# pbar.update()
# optimizer.zero_grad(set_to_none=True)
# loss_step = _loss_step
# _loss_step = 0
#
# steps_done = hypernetwork.step + 1
#
# epoch_num = hypernetwork.step // steps_per_epoch
# epoch_step = hypernetwork.step % steps_per_epoch
#
# description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
# pbar.set_description(description)
# if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# # Before saving, change name to match current checkpoint.
# hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
# last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
# hypernetwork.optimizer_name = optimizer_name
# if shared.opts.save_optimizer_state:
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
#
#
#
# if shared.opts.training_enable_tensorboard:
# epoch_num = hypernetwork.step // len(ds)
# epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
# mean_loss = sum(loss_logging) / len(loss_logging)
# textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
#
# textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
# "loss": f"{loss_step:.7f}",
# "learn_rate": scheduler.learn_rate
# })
#
# if images_dir is not None and steps_done % create_image_every == 0:
# forced_filename = f'{hypernetwork_name}-{steps_done}'
# last_saved_image = os.path.join(images_dir, forced_filename)
# hypernetwork.eval()
# rng_state = torch.get_rng_state()
# cuda_rng_state = None
# if torch.cuda.is_available():
# cuda_rng_state = torch.cuda.get_rng_state_all()
# shared.sd_model.cond_stage_model.to(devices.device)
# shared.sd_model.first_stage_model.to(devices.device)
#
# p = processing.StableDiffusionProcessingTxt2Img(
# sd_model=shared.sd_model,
# do_not_save_grid=True,
# do_not_save_samples=True,
# )
#
# p.disable_extra_networks = True
#
# if preview_from_txt2img:
# p.prompt = preview_prompt
# p.negative_prompt = preview_negative_prompt
# p.steps = preview_steps
# p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
# p.cfg_scale = preview_cfg_scale
# p.seed = preview_seed
# p.width = preview_width
# p.height = preview_height
# else:
# p.prompt = batch.cond_text[0]
# p.steps = 20
# p.width = training_width
# p.height = training_height
#
# preview_text = p.prompt
#
# with closing(p):
# processed = processing.process_images(p)
# image = processed.images[0] if len(processed.images) > 0 else None
#
# if unload:
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
# torch.set_rng_state(rng_state)
# if torch.cuda.is_available():
# torch.cuda.set_rng_state_all(cuda_rng_state)
# hypernetwork.train()
# if image is not None:
# shared.state.assign_current_image(image)
# if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
# textual_inversion.tensorboard_add_image(tensorboard_writer,
# f"Validation at epoch {epoch_num}", image,
# hypernetwork.step)
# last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
# last_saved_image += f", prompt: {preview_text}"
#
# shared.state.job_no = hypernetwork.step
#
# shared.state.textinfo = f"""
#
# Loss: {loss_step:.7f}
# Step: {steps_done}
# Last prompt: {html.escape(batch.cond_text[0])}
# Last saved hypernetwork: {html.escape(last_saved_file)}
# Last saved image: {html.escape(last_saved_image)}
#
# """
# except Exception:
# errors.report("Exception in training hypernetwork", exc_info=True)
# finally:
# pbar.leave = False
# pbar.close()
# hypernetwork.eval()
# sd_hijack_checkpoint.remove()
#
#
#
# filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
# hypernetwork.optimizer_name = optimizer_name
# if shared.opts.save_optimizer_state:
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
#
# del optimizer
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
# shared.sd_model.cond_stage_model.to(devices.device)
# shared.sd_model.first_stage_model.to(devices.device)
# shared.parallel_processing_allowed = old_parallel_processing_allowed
#
# return hypernetwork, filename
#
# NOTE(review): disabled (commented-out) helper, kept for reference only.
# When enabled, `save_hypernetwork` sets three attributes on `hypernetwork`:
#   - `sd_checkpoint` to `checkpoint.shorthash`
#   - `sd_checkpoint_name` to `checkpoint.model_name`
#   - `name` to `hypernetwork_name`
# It then calls `hypernetwork.save(filename)`.
# If any of those steps raises, the previous `name`, `sd_checkpoint` and
# `sd_checkpoint_name` values are restored, and the exception is re-raised.
# On success, the new values are left in place (there is no `finally`).
# Caveats:
#   - If `sd_checkpoint` / `sd_checkpoint_name` did not exist beforehand, a
#     failure leaves them set to None rather than removing them.
#   - The bare `except:` also catches KeyboardInterrupt/SystemExit. That is
#     harmless here because it always re-raises after restoring state.
# def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
# old_hypernetwork_name = hypernetwork.name
# old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
# old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
# try:
# hypernetwork.sd_checkpoint = checkpoint.shorthash
# hypernetwork.sd_checkpoint_name = checkpoint.model_name
# hypernetwork.name = hypernetwork_name
# hypernetwork.save(filename)
# except:
# hypernetwork.sd_checkpoint = old_sd_checkpoint
# hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
# hypernetwork.name = old_hypernetwork_name
# raise