import os
from typing import Optional

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler

from train_dreambooth import train_dreambooth


class DreamboothApp:
    """Thin wrapper around DreamBooth training and Stable Diffusion inference.

    Typical use: construct, optionally call :meth:`train` to fine-tune, then
    call :meth:`inference`. The pipeline is loaded from ``model_path`` either
    explicitly via :meth:`load_model` or lazily on the first
    :meth:`inference` call.

    Attributes are plain instance fields:
        model_path: location passed to ``StableDiffusionPipeline.from_pretrained``.
        pretrained_model_name_or_path: base model forwarded to ``train_dreambooth``.
        pipe: the loaded pipeline, or ``None`` when nothing is loaded.
        g_cuda: ``torch.Generator`` on the ``cuda`` device used for sampling.
    """

    def __init__(
        self,
        model_path: str,
        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
    ):
        """Store the model locations and create the CUDA random generator.

        Args:
            model_path: Path or identifier the inference pipeline is loaded from.
            pretrained_model_name_or_path: Base model handed to the training
                routine.
        """
        self.model_path = model_path
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.pipe = None
        # NOTE(review): the generator is created on the CUDA device right
        # away, so constructing this class presumably requires a CUDA-capable
        # torch build -- confirm before using on CPU-only hosts.
        self.g_cuda = torch.Generator(device='cuda')

    def load_model(self) -> None:
        """Load the pipeline from ``self.model_path`` onto the GPU.

        The pipeline is loaded in float16 with the safety checker disabled,
        its scheduler is replaced by a ``DDIMScheduler`` built from the
        loaded scheduler's config, and xformers memory-efficient attention is
        enabled. The result is stored in ``self.pipe``.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.model_path,
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        # Publish only a fully configured pipeline, so a failure above never
        # leaves a half-initialised object in ``self.pipe``.
        self.pipe = pipe

    def train(
        self,
        instance_data_dir: str,
        class_data_dir: str,
        instance_prompt: str,
        class_prompt: str,
        num_class_images: int = 50,
        max_train_steps: int = 800,
        output_dir: str = "stable_diffusion_weights",
    ) -> None:
        """Run DreamBooth training for a single concept.

        Builds a one-element concepts list from the four prompt/directory
        arguments and forwards it, together with fixed hyper-parameters, to
        ``train_dreambooth``. Afterwards ``self.model_path`` points at
        ``output_dir`` and any previously loaded pipeline is discarded so the
        next :meth:`inference` call loads the new weights.

        Args:
            instance_data_dir: Directory for the instance (subject) data.
            class_data_dir: Directory for the class (prior) data.
            instance_prompt: Prompt describing the instance; also used as the
                sample prompt during training.
            class_prompt: Prompt describing the class.
            num_class_images: Forwarded as ``num_class_images``.
            max_train_steps: Forwarded as ``max_train_steps``.
            output_dir: Where training output is written; becomes the new
                ``model_path``.
        """
        concepts_list = [
            {
                "instance_prompt": instance_prompt,
                "class_prompt": class_prompt,
                "instance_data_dir": instance_data_dir,
                "class_data_dir": class_data_dir,
            }
        ]
        train_dreambooth(
            pretrained_model_name_or_path=self.pretrained_model_name_or_path,
            pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
            output_dir=output_dir,
            revision="fp16",
            with_prior_preservation=True,
            prior_loss_weight=1.0,
            seed=1337,
            resolution=512,
            train_batch_size=1,
            train_text_encoder=True,
            mixed_precision="fp16",
            use_8bit_adam=True,
            gradient_accumulation_steps=1,
            learning_rate=1e-6,
            lr_scheduler="constant",
            lr_warmup_steps=0,
            num_class_images=num_class_images,
            sample_batch_size=4,
            max_train_steps=max_train_steps,
            save_interval=10000,
            save_sample_prompt=instance_prompt,
            concepts_list=concepts_list,
        )
        self.model_path = output_dir
        # A pipeline loaded before training still holds the old weights;
        # drop it so inference reloads from the updated ``model_path``
        # instead of silently sampling from the stale model.
        self.pipe = None

    def inference(
        self,
        prompt,
        negative_prompt,
        num_samples,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        seed: Optional[int] = None,
    ):
        """Generate images for ``prompt`` with the loaded pipeline.

        If no pipeline is loaded yet, :meth:`load_model` is called first.

        Args:
            prompt: Text prompt passed to the pipeline.
            negative_prompt: Negative prompt passed to the pipeline.
            num_samples: Images per prompt (coerced with ``int``).
            height: Output height (coerced with ``int``).
            width: Output width (coerced with ``int``).
            num_inference_steps: Number of steps (coerced with ``int``).
            guidance_scale: Guidance scale, passed through unchanged.
            seed: When not ``None``, re-seeds ``self.g_cuda`` before sampling.
                When ``None``, the generator is used in whatever state
                earlier calls left it.

        Returns:
            The ``images`` attribute of the pipeline's output.
        """
        # Previously this raised "'NoneType' object is not callable" when
        # load_model() had not been called; load on demand instead.
        if self.pipe is None:
            self.load_model()

        if seed is not None:
            self.g_cuda.manual_seed(seed)

        with autocast("cuda"), torch.inference_mode():
            return self.pipe(
                prompt,
                height=int(height),
                width=int(width),
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=self.g_cuda,
            ).images