Singularity666 commited on
Commit
d953dcd
·
verified ·
1 Parent(s): 4423a12

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +62 -0
main.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
5
+ from train_dreambooth import train_dreambooth
6
+
7
+ class DreamboothApp:
8
+ def __init__(self, model_path, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"):
9
+ self.model_path = model_path
10
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
11
+ self.pipe = None
12
+ self.g_cuda = torch.Generator(device='cuda')
13
+
14
+ def load_model(self):
15
+ self.pipe = StableDiffusionPipeline.from_pretrained(self.model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
16
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
17
+ self.pipe.enable_xformers_memory_efficient_attention()
18
+
19
+ def train(self, instance_data_dir, class_data_dir, instance_prompt, class_prompt, num_class_images=50, max_train_steps=800, output_dir="stable_diffusion_weights"):
20
+ concepts_list = [
21
+ {
22
+ "instance_prompt": instance_prompt,
23
+ "class_prompt": class_prompt,
24
+ "instance_data_dir": instance_data_dir,
25
+ "class_data_dir": class_data_dir
26
+ }
27
+ ]
28
+ train_dreambooth(pretrained_model_name_or_path=self.pretrained_model_name_or_path,
29
+ pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
30
+ output_dir=output_dir,
31
+ revision="fp16",
32
+ with_prior_preservation=True,
33
+ prior_loss_weight=1.0,
34
+ seed=1337,
35
+ resolution=512,
36
+ train_batch_size=1,
37
+ train_text_encoder=True,
38
+ mixed_precision="fp16",
39
+ use_8bit_adam=True,
40
+ gradient_accumulation_steps=1,
41
+ learning_rate=1e-6,
42
+ lr_scheduler="constant",
43
+ lr_warmup_steps=0,
44
+ num_class_images=num_class_images,
45
+ sample_batch_size=4,
46
+ max_train_steps=max_train_steps,
47
+ save_interval=10000,
48
+ save_sample_prompt=instance_prompt,
49
+ concepts_list=concepts_list)
50
+ self.model_path = output_dir
51
+
52
+ def inference(self, prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, seed=None):
53
+ if seed is not None:
54
+ self.g_cuda.manual_seed(seed)
55
+ with autocast("cuda"), torch.inference_mode():
56
+ return self.pipe(
57
+ prompt, height=int(height), width=int(width),
58
+ negative_prompt=negative_prompt,
59
+ num_images_per_prompt=int(num_samples),
60
+ num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
61
+ generator=self.g_cuda
62
+ ).images