Spaces:
Running
Running
File size: 4,040 Bytes
fcc02a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import os
from collections import OrderedDict
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype
def flush():
    """Free as much cached GPU memory as possible.

    Runs the Python garbage collector FIRST so that tensors whose last
    references are cycle-held actually get reclaimed, then asks the CUDA
    caching allocator to return its now-unused blocks to the driver.
    (Calling empty_cache before collect misses anything the collector
    frees.) Safe no-op on CPU-only builds where CUDA was never initialized.
    """
    gc.collect()
    torch.cuda.empty_cache()
class PureLoraGenerator(BaseExtensionProcess):
    """Extension process that loads a Stable Diffusion model, converts its
    UNet in place to a low-rank (LoRM) UNet, and renders the configured
    sample prompts to image files in ``output_folder``.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Read generation/model/LoRM settings from ``config`` and build the
        (not yet loaded) StableDiffusion wrapper.
        """
        super().__init__(process_id, job, config)
        # Destination directory for the generated images.
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        # Optional low-rank conversion settings; None means library defaults.
        lorm_raw = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_raw) if lorm_raw is not None else None
        self.device_state_preset = get_train_sd_device_state_preset(
            device=torch.device(self.device),
        )
        self.progress_bar = None
        # Model weights are loaded lazily in run(), not here.
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM, generate one image per
        configured prompt, then release the model.
        """
        super().run()
        print("Loading model...")
        with torch.no_grad():
            self.sd.load_model()
            self.sd.unet.eval()
            self.sd.unet.to(self.device_torch)
            # text_encoder may be a single module or a list of modules
            # (e.g. SDXL ships two encoders) — handle both shapes.
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.eval()
                    te.to(self.device_torch)
            else:
                self.sd.text_encoder.eval()
                # Fix: the single-encoder branch previously skipped the
                # device move that the list branch performs.
                self.sd.text_encoder.to(self.device_torch)

            self.sd.to(self.device_torch)

            print("Converting to LoRM UNet")
            # Replace the UNet's layers with their low-rank equivalents in place.
            convert_diffusers_unet_to_lorm(
                self.sd.unet,
                config=self.lorm_config,
            )

            sample_folder = self.output_folder
            gen_img_config_list = []

            sample_config = self.generate_config
            start_seed = sample_config.seed
            current_seed = start_seed
            for i in range(len(sample_config.prompts)):
                # walk_seed increments the seed per prompt for varied outputs;
                # otherwise every prompt reuses the configured seed.
                if sample_config.walk_seed:
                    current_seed = start_seed + i

                # [time]/[count] placeholders are expanded by the generator.
                filename = f"[time]_[count].{self.generate_config.ext}"
                output_path = os.path.join(sample_folder, filename)
                gen_img_config_list.append(GenerateImageConfig(
                    prompt=sample_config.prompts[i],  # it will autoparse the prompt
                    width=sample_config.width,
                    height=sample_config.height,
                    negative_prompt=sample_config.neg,
                    seed=current_seed,
                    guidance_scale=sample_config.guidance_scale,
                    guidance_rescale=sample_config.guidance_rescale,
                    num_inference_steps=sample_config.sample_steps,
                    network_multiplier=sample_config.network_multiplier,
                    output_path=output_path,
                    output_ext=sample_config.ext,
                    adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                ))

            # send to be generated
            self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)

        print("Done generating images")
        # Cleanup: drop the model, then reclaim Python and CUDA memory.
        del self.sd
        gc.collect()
        torch.cuda.empty_cache()
|