Spaces:
Runtime error
Runtime error
import os
import warnings

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from torch import autocast

from train_dreambooth import train_dreambooth
class DreamboothApp:
    """Fine-tune Stable Diffusion with DreamBooth and sample images from it.

    The instance remembers which weights to sample from (``model_path``) and
    holds an fp16 ``StableDiffusionPipeline`` on CUDA in ``self.pipe``. That
    pipeline is loaded lazily and discarded whenever ``train`` produces new
    weights.
    """

    def __init__(self, model_path, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"):
        """Store configuration; no model is loaded here.

        Args:
            model_path: Path or hub id that ``load_model`` loads the inference
                pipeline from.
            pretrained_model_name_or_path: Base checkpoint that ``train``
                fine-tunes.
        """
        self.model_path = model_path
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        # Loaded on demand by load_model() (called from inference() if needed).
        self.pipe = None
        # Shared CUDA RNG: reseeded in inference() when a seed is supplied so
        # sampling is reproducible. Creating it needs a CUDA device.
        self.g_cuda = torch.Generator(device='cuda')

    def load_model(self):
        """Load the fp16 pipeline from ``self.model_path`` onto CUDA.

        The safety checker is disabled and the scheduler is replaced by a DDIM
        scheduler built from the pipeline's own scheduler config. xFormers
        memory-efficient attention is enabled when possible. If it cannot be
        enabled, a ``RuntimeWarning`` is emitted and default attention is used.
        ``self.pipe`` is only replaced once the new pipeline is fully set up.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.model_path, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError) as err:
            # xFormers is optional. diffusers signals a missing install /
            # missing CUDA with ModuleNotFoundError / ValueError, and neither
            # should prevent inference.
            warnings.warn(f"xFormers attention not enabled: {err}", RuntimeWarning)
        self.pipe = pipe

    def train(self, instance_data_dir, class_data_dir, instance_prompt, class_prompt,
              num_class_images=50, max_train_steps=800, output_dir="stable_diffusion_weights"):
        """Run DreamBooth fine-tuning for a single concept.

        Fine-tuning starts from ``self.pretrained_model_name_or_path`` with the
        ``stabilityai/sd-vae-ft-mse`` VAE. It uses prior preservation (loss
        weight 1.0), trains the text encoder, and runs fp16 mixed precision
        with 8-bit Adam at a constant learning rate of 1e-6, batch size 1,
        resolution 512 and seed 1337.

        Afterwards ``self.model_path`` points at ``output_dir`` and any loaded
        pipeline has been discarded, so the next ``inference`` call loads the
        fine-tuned weights.

        Args:
            instance_data_dir: Directory with images of the subject to learn.
            class_data_dir: Directory for prior-preservation class images.
            instance_prompt: Prompt identifying the subject. Also used as the
                sample prompt when checkpoints are saved.
            class_prompt: Prompt describing the generic class.
            num_class_images: Number of class images requested for prior
                preservation.
            max_train_steps: Total optimisation steps.
            output_dir: Where ``train_dreambooth`` writes its output.
        """
        concepts_list = [
            {
                "instance_prompt": instance_prompt,
                "class_prompt": class_prompt,
                "instance_data_dir": instance_data_dir,
                "class_data_dir": class_data_dir,
            }
        ]
        # Drop the current pipeline first. Its fp16 weights live on the GPU and
        # would otherwise compete with training for memory. After training they
        # would also be stale.
        self.pipe = None
        torch.cuda.empty_cache()
        train_dreambooth(pretrained_model_name_or_path=self.pretrained_model_name_or_path,
                         pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
                         output_dir=output_dir,
                         revision="fp16",
                         with_prior_preservation=True,
                         prior_loss_weight=1.0,
                         seed=1337,
                         resolution=512,
                         train_batch_size=1,
                         train_text_encoder=True,
                         mixed_precision="fp16",
                         use_8bit_adam=True,
                         gradient_accumulation_steps=1,
                         learning_rate=1e-6,
                         lr_scheduler="constant",
                         lr_warmup_steps=0,
                         num_class_images=num_class_images,
                         sample_batch_size=4,
                         max_train_steps=max_train_steps,
                         save_interval=10000,
                         save_sample_prompt=instance_prompt,
                         concepts_list=concepts_list)
        # NOTE(review): assumes train_dreambooth leaves a loadable pipeline
        # directly in output_dir (some versions save into per-step
        # subfolders) -- confirm against train_dreambooth.
        self.model_path = output_dir

    def inference(self, prompt, negative_prompt, num_samples, height=512, width=512,
                  num_inference_steps=50, guidance_scale=7.5, seed=None):
        """Generate images for ``prompt``, loading the pipeline first if needed.

        ``num_samples``, ``height``, ``width`` and ``num_inference_steps`` are
        coerced with ``int()``, since UI widgets may deliver floats. When
        ``seed`` is given, the shared generator is reseeded first. Otherwise it
        continues from its current state.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        if self.pipe is None:
            self.load_model()
        if seed is not None:
            self.g_cuda.manual_seed(seed)
        with autocast("cuda"), torch.inference_mode():
            return self.pipe(
                prompt, height=int(height), width=int(width),
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
                generator=self.g_cuda
            ).images