File size: 4,040 Bytes
fcc02a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from collections import OrderedDict

from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype


def flush():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed before ``torch.cuda.empty_cache()`` hands unused
    cached blocks back; in the reverse order those tensors' memory would stay
    in PyTorch's cache after this call.
    """
    gc.collect()
    torch.cuda.empty_cache()


class PureLoraGenerator(BaseExtensionProcess):
    """Extension process that loads a model, converts its UNet to LoRM, and renders samples.

    Config keys read in ``__init__`` via ``get_conf``:
      - ``output_folder`` (required): directory the sample images are written to.
      - ``device`` (default ``'cuda'``): torch device string.
      - ``model`` (required): keyword arguments for ``ModelConfig``.
      - ``sample`` (required): keyword arguments for ``SampleConfig``.
      - ``dtype`` (default ``'float16'``): dtype name, passed to ``get_torch_dtype``
        and to ``StableDiffusion``.
      - ``lorm`` (optional): keyword arguments for ``LoRMConfig``; ``self.lorm_config``
        is ``None`` when the key is absent.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None

        # Reuse the device parsed above rather than constructing a second torch.device.
        self.device_state_preset = get_train_sd_device_state_preset(
            device=self.device_torch,
        )

        self.progress_bar = None
        # Only the wrapper is constructed here; run() calls load_model() on it.
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM, and generate one image per prompt.

        ``self.sd`` is deleted and memory is released afterwards, even when loading
        or generation raises, so this method can only be run once per instance.
        """
        super().run()
        print("Loading model...")
        try:
            with torch.no_grad():
                self._load_model()

                print("Converting to LoRM UNet")
                # Replace the UNet's layers with their LoRM equivalents in place.
                # NOTE(review): self.lorm_config is None when no 'lorm' key was
                # configured — confirm convert_diffusers_unet_to_lorm accepts that.
                convert_diffusers_unet_to_lorm(
                    self.sd.unet,
                    config=self.lorm_config,
                )

                gen_img_config_list = self._build_image_configs()

                # send to be generated
                self.sd.generate_images(gen_img_config_list, sampler=self.generate_config.sampler)
                print("Done generating images")
        finally:
            # cleanup: drop the model reference, then release cached GPU memory
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()

    def _load_model(self):
        """Load weights and put the UNet and every text encoder in eval mode on ``self.device_torch``."""
        self.sd.load_model()
        self.sd.unet.eval()
        self.sd.unet.to(self.device_torch)
        # text_encoder is either a single module or a list of modules; normalize
        # to a list so both cases are moved to the device the same way.
        encoder = self.sd.text_encoder
        text_encoders = encoder if isinstance(encoder, list) else [encoder]
        for te in text_encoders:
            te.eval()
            te.to(self.device_torch)

    def _build_image_configs(self):
        """Return one ``GenerateImageConfig`` per prompt in ``self.generate_config``.

        With ``walk_seed`` enabled the seed for prompt ``i`` is ``seed + i``;
        otherwise every prompt uses ``seed``.
        """
        sample_config = self.generate_config
        # "[time]" and "[count]" are left literal here; presumably they are
        # substituted when images are saved — TODO confirm in generate_images.
        filename = f"[time]_[count].{sample_config.ext}"
        output_path = os.path.join(self.output_folder, filename)

        configs = []
        for i, prompt in enumerate(sample_config.prompts):
            seed = sample_config.seed + i if sample_config.walk_seed else sample_config.seed
            configs.append(GenerateImageConfig(
                prompt=prompt,  # it will autoparse the prompt
                width=sample_config.width,
                height=sample_config.height,
                negative_prompt=sample_config.neg,
                seed=seed,
                guidance_scale=sample_config.guidance_scale,
                guidance_rescale=sample_config.guidance_rescale,
                num_inference_steps=sample_config.sample_steps,
                network_multiplier=sample_config.network_multiplier,
                output_path=output_path,
                output_ext=sample_config.ext,
                adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
            ))
        return configs