import os
from collections import OrderedDict
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype


def flush():
    """Free unreferenced Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()`` tries to return
    cached allocator blocks; in the opposite order those blocks stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()


class PureLoraGenerator(BaseExtensionProcess):
    """Generate sample images with a model whose UNet is converted to LoRM.

    Config keys read in ``__init__``:
        output_folder (required): directory generated images are written to.
        device: torch device string, default ``'cuda'``.
        model (required): keyword arguments for ``ModelConfig``.
        sample (required): keyword arguments for ``SampleConfig``.
        dtype: dtype name, default ``'float16'``.
        lorm: optional keyword arguments for ``LoRMConfig``.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None

        self.device_state_preset = get_train_sd_device_state_preset(
            device=self.device_torch,
        )

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM, and generate every prompt.

        The model is deleted and memory flushed afterwards -- also when loading
        or generation raises -- so an instance can only be run once.
        """
        super().run()
        try:
            with torch.no_grad():
                self._load_model()

                print("Converting to LoRM UNet")
                # replace the unet with LoRMUnet
                # NOTE(review): self.lorm_config is None when no 'lorm' section
                # is configured; confirm convert_diffusers_unet_to_lorm accepts that.
                convert_diffusers_unet_to_lorm(
                    self.sd.unet,
                    config=self.lorm_config,
                )

                gen_img_config_list = self._build_gen_img_configs()
                # send to be generated
                self.sd.generate_images(gen_img_config_list, sampler=self.generate_config.sampler)
                print("Done generating images")
        finally:
            # cleanup: drop the model so its GPU memory can be reclaimed
            del self.sd
            flush()

    def _load_model(self):
        """Load weights and put the UNet and text encoder(s) in eval mode on the device."""
        print("Loading model...")
        self.sd.load_model()
        self.sd.unet.eval()
        self.sd.unet.to(self.device_torch)

        # text_encoder may be a single module or a list of modules; treat both
        # the same way so every encoder is switched to eval and moved explicitly.
        text_encoders = self.sd.text_encoder
        if not isinstance(text_encoders, list):
            text_encoders = [text_encoders]
        for te in text_encoders:
            te.eval()
            te.to(self.device_torch)

    def _build_gen_img_configs(self):
        """Build one ``GenerateImageConfig`` per prompt of the sample config.

        With ``walk_seed`` set, prompt ``i`` uses ``seed + i``; otherwise every
        prompt uses ``seed``. All images share the same output path pattern.
        """
        sample_config = self.generate_config
        start_seed = sample_config.seed
        # '[time]' and '[count]' are kept as literal placeholders; presumably
        # expanded per image by generate_images -- TODO confirm.
        filename = f"[time]_[count].{sample_config.ext}"
        output_path = os.path.join(self.output_folder, filename)

        return [
            GenerateImageConfig(
                prompt=prompt,  # it will autoparse the prompt
                width=sample_config.width,
                height=sample_config.height,
                negative_prompt=sample_config.neg,
                seed=start_seed + i if sample_config.walk_seed else start_seed,
                guidance_scale=sample_config.guidance_scale,
                guidance_rescale=sample_config.guidance_rescale,
                num_inference_steps=sample_config.sample_steps,
                network_multiplier=sample_config.network_multiplier,
                output_path=output_path,
                output_ext=sample_config.ext,
                adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
            )
            for i, prompt in enumerate(sample_config.prompts)
        ]