File size: 1,912 Bytes
beed73c
 
b92dd65
 
 
beed73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b92dd65
 
 
beed73c
b92dd65
beed73c
b92dd65
 
beed73c
b92dd65
beed73c
 
 
b92dd65
 
 
beed73c
 
 
b92dd65
 
beed73c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# main.py

import os
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from huggingface_hub import HfApi
from app import launch_gradio_app
from dreambooth import train_dreambooth

def fine_tune_model(
    instance_images,
    class_images,
    instance_prompt,
    class_prompt,
    num_train_steps=800,
    *,
    model_name="runwayml/stable-diffusion-v1-5",
    output_dir="dreambooth_model",
):
    """Fine-tune a Stable Diffusion checkpoint with DreamBooth.

    All training work is delegated to ``train_dreambooth``; this function only
    assembles its arguments.

    Args:
        instance_images: Forwarded as ``instance_data_dir``; presumably a
            directory of subject images (TODO confirm against train_dreambooth).
        class_images: Forwarded as ``class_data_dir``; presumably a directory
            of prior-preservation class images (TODO confirm).
        instance_prompt: Prompt describing the specific subject instance.
        class_prompt: Prompt describing the generic class of the subject.
        num_train_steps: Number of training steps forwarded to the trainer.
        model_name: Base model id or path forwarded as
            ``pretrained_model_name_or_path``. Defaults to the checkpoint that
            used to be hard-coded here.
        output_dir: Output directory forwarded to the trainer. NOTE: keeping
            the default means successive runs all target the same directory.

    Returns:
        ``output_dir``, so the caller can load or upload the trained model.
    """
    train_dreambooth(
        pretrained_model_name_or_path=model_name,
        instance_data_dir=instance_images,
        class_data_dir=class_images,
        output_dir=output_dir,
        instance_prompt=instance_prompt,
        class_prompt=class_prompt,
        num_train_steps=num_train_steps,
    )
    return output_dir

def load_model(model_path):
    """Load a Stable Diffusion pipeline in fp16 on CUDA, ready for inference.

    The safety checker is disabled, the scheduler is replaced by a DDIM
    scheduler built from the pipeline's existing scheduler config, and
    xformers memory-efficient attention is enabled when it is usable.

    Args:
        model_path: Model id or local directory accepted by
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    try:
        # xformers is only a speed/memory optimisation. In current diffusers
        # releases this call raises ModuleNotFoundError (an ImportError) when
        # xformers is not installed and ValueError when CUDA is unavailable —
        # verify for the pinned version. Fall back to the default attention
        # rather than failing the whole model load.
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError) as err:
        import warnings

        warnings.warn(
            f"xformers attention not enabled ({err}); using default attention.",
            RuntimeWarning,
            stacklevel=2,
        )
    return pipe

def generate_images(
    pipe,
    prompt,
    negative_prompt,
    num_samples,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.5,
    *,
    seed=None,
    device="cuda",
):
    """Run *pipe* on *prompt* and return the generated images.

    Numeric arguments are coerced with ``int``/``float`` before being passed
    to the pipeline, presumably because UI widgets may supply them as floats
    or strings.

    Args:
        pipe: A callable diffusion pipeline (e.g. the one from ``load_model``)
            whose result exposes an ``images`` attribute.
        prompt: Text prompt.
        negative_prompt: Text the pipeline should steer away from.
        num_samples: Images to generate for the prompt.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Integer seed for reproducible output. ``None`` (default) draws a
            fresh non-deterministic seed on every call.
        device: Device for autocast and the random generator. It must match
            the device *pipe* runs on. Defaults to ``"cuda"``.

    Returns:
        ``pipe(...).images``.
    """
    generator = torch.Generator(device=device)
    # A freshly constructed Generator always starts from torch's fixed default
    # seed, so without explicit seeding every call would yield the same images.
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(int(seed))

    with torch.autocast(device), torch.inference_mode():
        result = pipe(
            prompt,
            height=int(height),
            width=int(width),
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
        )
    return result.images

def push_to_huggingface(model_path, repo_name):
    """Upload the contents of *model_path* to the Hub repository *repo_name*.

    The repository is created first if it does not exist yet
    (``exist_ok=True``), because ``upload_folder`` alone fails on a missing
    repo. Credentials are whatever ``HfApi()`` picks up by default, presumably
    a cached login token (TODO confirm for the deployment environment).

    Args:
        model_path: Local folder to upload (e.g. the ``fine_tune_model`` output).
        repo_name: Target repo id in ``"user/name"`` form.

    Returns:
        The value returned by ``HfApi.upload_folder`` (commit information).
    """
    api = HfApi()
    api.create_repo(repo_id=repo_name, exist_ok=True)
    return api.upload_folder(folder_path=model_path, repo_id=repo_name)

if __name__ == "__main__":
    # Hub repo that trained models are pushed to. Set the HF_REPO_ID
    # environment variable instead of editing the placeholder default.
    repo_name = os.environ.get("HF_REPO_ID", "your-huggingface-username/dreambooth-app")
    launch_gradio_app(fine_tune_model, load_model, generate_images, push_to_huggingface, repo_name)