File size: 3,020 Bytes
d953dcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from train_dreambooth import train_dreambooth

class DreamboothApp:
    """Fine-tune Stable Diffusion with DreamBooth and generate images.

    Holds a ``StableDiffusionPipeline`` loaded from ``model_path`` (loaded
    lazily) and delegates training to ``train_dreambooth``.
    """

    def __init__(self, model_path, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5", device="cuda"):
        """Store configuration; the pipeline itself is loaded on demand.

        Args:
            model_path: Weights (path or hub id) used by ``load_model``.
            pretrained_model_name_or_path: Base model that ``train`` starts from.
            device: Torch device for the pipeline, autocast and RNG
                (default ``"cuda"``).
        """
        self.model_path = model_path
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.device = device
        # Populated by load_model(); inference() loads it on first use.
        self.pipe = None
        # Attribute name kept for backward compatibility; the generator is
        # created on self.device (CUDA by default).
        self.g_cuda = torch.Generator(device=device)

    def load_model(self):
        """Load the fp16 pipeline from ``self.model_path`` onto ``self.device``.

        The scheduler is replaced with a DDIM scheduler built from the
        pipeline's own scheduler config; on CUDA devices xFormers
        memory-efficient attention is enabled.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.model_path, safety_checker=None, torch_dtype=torch.float16
        ).to(self.device)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        # xFormers attention kernels are CUDA-only.
        if torch.device(self.device).type == "cuda":
            pipe.enable_xformers_memory_efficient_attention()
        # Publish only a fully configured pipeline.
        self.pipe = pipe

    def train(self, instance_data_dir, class_data_dir, instance_prompt, class_prompt, num_class_images=50, max_train_steps=800, output_dir="stable_diffusion_weights"):
        """Run DreamBooth fine-tuning for a single concept.

        Uses prior preservation, fp16 mixed precision, 8-bit Adam and a
        constant learning rate of 1e-6, training the text encoder as well.
        On success ``self.model_path`` points at ``output_dir`` and the next
        ``inference`` call reloads the pipeline from there.

        Args:
            instance_data_dir: Directory of images of the new subject.
            class_data_dir: Directory for class (prior-preservation) images.
            instance_prompt: Prompt describing the subject; also used as the
                sample prompt during training.
            class_prompt: Prompt describing the generic class.
            num_class_images: Number of class images for prior preservation.
            max_train_steps: Total optimisation steps.
            output_dir: Where the trainer writes its weights.
        """
        concepts_list = [
            {
                "instance_prompt": instance_prompt,
                "class_prompt": class_prompt,
                "instance_data_dir": instance_data_dir,
                "class_data_dir": class_data_dir
            }
        ]
        # Release any loaded pipeline: it would hold GPU memory the trainer
        # needs and would be stale once new weights are written. inference()
        # reloads lazily, so nothing is lost if training fails.
        self.pipe = None
        torch.cuda.empty_cache()
        train_dreambooth(pretrained_model_name_or_path=self.pretrained_model_name_or_path,
                         pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
                         output_dir=output_dir,
                         revision="fp16",
                         with_prior_preservation=True,
                         prior_loss_weight=1.0,
                         seed=1337,
                         resolution=512,
                         train_batch_size=1,
                         train_text_encoder=True,
                         mixed_precision="fp16",
                         use_8bit_adam=True,
                         gradient_accumulation_steps=1,
                         learning_rate=1e-6,
                         lr_scheduler="constant",
                         lr_warmup_steps=0,
                         num_class_images=num_class_images,
                         sample_batch_size=4,
                         max_train_steps=max_train_steps,
                         save_interval=10000,
                         save_sample_prompt=instance_prompt,
                         concepts_list=concepts_list)
        # NOTE(review): assumes train_dreambooth leaves a loadable pipeline
        # directly in output_dir (not in a per-step subfolder) — confirm.
        self.model_path = output_dir

    def inference(self, prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, seed=None):
        """Generate images for ``prompt`` and return the pipeline's ``.images``.

        Loads the pipeline first if it is not loaded yet. ``num_samples``,
        ``height``, ``width``, ``num_inference_steps`` and ``seed`` are coerced
        with ``int()``. When ``seed`` is None the generator is not reseeded
        and continues from its current state.
        """
        if self.pipe is None:
            self.load_model()
        if seed is not None:
            self.g_cuda.manual_seed(int(seed))
        # autocast wants a bare device type ("cuda"), not e.g. "cuda:1".
        device_type = torch.device(self.device).type
        with autocast(device_type), torch.inference_mode():
            return self.pipe(
                prompt, height=int(height), width=int(width),
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
                generator=self.g_cuda
            ).images