editx / main.py
Singularity666's picture
Create main.py
d953dcd verified
raw
history blame
3.02 kB
import os
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from train_dreambooth import train_dreambooth
class DreamboothApp:
    """Wrap DreamBooth fine-tuning and Stable Diffusion inference behind one object.

    The instance tracks which weights to serve (``model_path``), the base model
    that :meth:`train` fine-tunes, a ``StableDiffusionPipeline`` that is loaded on
    demand, and one CUDA random generator shared by every :meth:`inference` call.
    """

    def __init__(self, model_path, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"):
        """Record model locations; no weights are loaded here.

        Args:
            model_path: Directory or hub id handed to
                ``StableDiffusionPipeline.from_pretrained`` by :meth:`load_model`.
            pretrained_model_name_or_path: Base model that :meth:`train` fine-tunes.
        """
        self.model_path = model_path
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        # Created by load_model(), either explicitly or lazily from inference().
        self.pipe = None
        # One generator on the CUDA device, reused by all inference() calls so that
        # an explicit seed makes a call reproducible.
        self.g_cuda = torch.Generator(device='cuda')

    def load_model(self):
        """Load ``model_path`` as an fp16 pipeline on CUDA with a DDIM scheduler.

        The safety checker is disabled (``safety_checker=None``). xFormers
        memory-efficient attention is enabled, which needs the ``xformers``
        package to be installed.
        """
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.model_path, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        # Swap in DDIM while keeping the scheduler settings shipped with the model.
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()

    def train(self, instance_data_dir, class_data_dir, instance_prompt, class_prompt,
              num_class_images=50, max_train_steps=800, output_dir="stable_diffusion_weights"):
        """Fine-tune the base model with DreamBooth and serve the result afterwards.

        Training uses prior preservation with a single concept, fp16 mixed
        precision, 8-bit Adam, a constant LR of 1e-6, batch size 1 at 512px, and
        also trains the text encoder.

        Args:
            instance_data_dir: Directory with the subject's training images.
            class_data_dir: Directory for the prior-preservation class images.
            instance_prompt: Prompt describing the subject; also used as the
                sample prompt.
            class_prompt: Prompt describing the generic class.
            num_class_images: Number of class images requested for prior preservation.
            max_train_steps: Number of optimisation steps.
            output_dir: Where training writes its weights; becomes ``model_path``.

        Any pipeline that is already loaded is released before training starts.
        The next :meth:`inference` call then reloads from ``model_path``: the new
        weights on success, or the previous model if training raises.
        """
        concepts_list = [
            {
                "instance_prompt": instance_prompt,
                "class_prompt": class_prompt,
                "instance_data_dir": instance_data_dir,
                "class_data_dir": class_data_dir,
            }
        ]
        # Release the old pipeline so it neither holds GPU memory during training
        # nor keeps serving the old weights once model_path changes.
        self.pipe = None
        train_dreambooth(
            pretrained_model_name_or_path=self.pretrained_model_name_or_path,
            pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
            output_dir=output_dir,
            revision="fp16",
            with_prior_preservation=True,
            prior_loss_weight=1.0,
            seed=1337,
            resolution=512,
            train_batch_size=1,
            train_text_encoder=True,
            mixed_precision="fp16",
            use_8bit_adam=True,
            gradient_accumulation_steps=1,
            learning_rate=1e-6,
            lr_scheduler="constant",
            lr_warmup_steps=0,
            num_class_images=num_class_images,
            sample_batch_size=4,
            max_train_steps=max_train_steps,
            save_interval=10000,
            save_sample_prompt=instance_prompt,
            concepts_list=concepts_list,
        )
        # NOTE(review): assumes train_dreambooth writes loadable pipeline weights
        # directly into output_dir (not e.g. a per-step subdirectory) — confirm.
        self.model_path = output_dir

    def inference(self, prompt, negative_prompt, num_samples, height=512, width=512,
                  num_inference_steps=50, guidance_scale=7.5, seed=None):
        """Generate images for ``prompt`` and return the pipeline's ``.images``.

        The pipeline is loaded first if it is not loaded yet (for example on first
        use, or after :meth:`train`). ``num_samples``, ``height``, ``width`` and
        ``num_inference_steps`` are coerced with ``int()``, so numeric strings and
        floats coming from a UI are accepted.

        Args:
            prompt: Text prompt.
            negative_prompt: Text the sampler should steer away from.
            num_samples: Images to generate for the prompt.
            height, width: Output size in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: If given, the shared generator is reseeded, which makes the call
                reproducible. Otherwise the generator continues from its current
                state.
        """
        if self.pipe is None:
            self.load_model()
        if seed is not None:
            self.g_cuda.manual_seed(seed)
        with autocast("cuda"), torch.inference_mode():
            return self.pipe(
                prompt,
                height=int(height),
                width=int(width),
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=self.g_cuda,
            ).images