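"""DreamBooth fine-tuning and inference demo with a Gradio UI.

Fine-tunes Stable Diffusion v1.5 on user-uploaded photos of a subject
(bound to a unique identifier token), then generates new images of that
subject. Assumes a CUDA GPU and a DreamBooth training script named
train_dreambooth.py (e.g. from ShivamShrirao's diffusers fork, whose
flags the command below matches) in the working directory.
"""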
import os
import json
import shutil
import subprocess

import torch
import gradio as gr
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Define necessary paths and variables
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"
SEED = 1337
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50

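# These defaults follow common DreamBooth recipes for SD 1.5; LEARNING_RATE
# and MAX_TRAIN_STEPS in particular usually need per-subject tuning.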
# Fine-tune the base model on the uploaded instance images
def fine_tune_model(uploaded_files, identifier):
    if not uploaded_files:
        return "Please upload at least one instance image."
    if not identifier:
        return "Please enter a unique identifier."

    # Copy the uploads into a dedicated instance-data directory for this subject
    instance_data_dir = os.path.join("instance_data", identifier)
    os.makedirs(instance_data_dir, exist_ok=True)
    for uploaded in uploaded_files:
        # Gradio may hand back plain paths or tempfile objects depending on version
        shutil.copy(getattr(uploaded, "name", uploaded), instance_data_dir)

    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": instance_data_dir,
            "class_data_dir": "/sample_data/person"  # Placeholder for regularization images
        }
    ]

    # Save concepts_list.json for the training script
    with open("concepts_list.json", "w") as f:
        json.dump(concepts_list, f, indent=4)

    # Run the training script as a subprocess; an argument list avoids
    # shell-quoting issues with the prompt
    cmd = [
        "python3", "train_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--output_dir={OUTPUT_DIR}",
        "--revision=fp16",
        "--with_prior_preservation", "--prior_loss_weight=1.0",
        f"--seed={SEED}",
        f"--resolution={RESOLUTION}",
        f"--train_batch_size={TRAIN_BATCH_SIZE}",
        "--train_text_encoder",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        "--gradient_accumulation_steps=1",
        f"--learning_rate={LEARNING_RATE}",
        f"--max_train_steps={MAX_TRAIN_STEPS}",
        f"--save_sample_prompt={instance_prompt}",
        "--concepts_list=concepts_list.json",
    ]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        return "Fine-tuning failed; check the console output for details."
    return f"Fine-tuning complete. Model saved to {OUTPUT_DIR}."

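# NOTE: with prior preservation enabled, DreamBooth training scripts such as
# ShivamShrirao's train_dreambooth.py typically generate the regularization
# ("class") images themselves when class_data_dir is empty, so the placeholder
# directory above does not have to be pre-filled.
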
# Run inference with the fine-tuned model
def generate_images(prompt, negative_prompt, num_samples, model_path,
                    height=RESOLUTION, width=RESOLUTION,
                    num_inference_steps=NUM_INFERENCE_STEPS,
                    guidance_scale=GUIDANCE_SCALE):
    # The pipeline is reloaded on every call; acceptable for a demo, but a
    # production app should load it once and reuse it (see the sketch below)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # xformers is optional; fall back to the default attention
    g_cuda = torch.Generator(device="cuda").manual_seed(SEED)

    with torch.autocast("cuda"), torch.inference_mode():
        # Cast numeric UI values to int: Gradio Number components return floats
        images = pipe(
            prompt,
            height=int(height),
            width=int(width),
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=g_cuda
        ).images

    return images
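
# A minimal sketch (hypothetical, not wired into generate_images above) of
# caching the loaded pipeline so repeated generations skip the model load.
# Assumes the weights at a given model_path do not change between calls.
_PIPE_CACHE = {}

def _get_pipeline(model_path):
    # Load once per model path, reuse on subsequent calls
    if model_path not in _PIPE_CACHE:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        _PIPE_CACHE[model_path] = pipe
    return _PIPE_CACHE[model_path]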

# Gradio handler: compose the final prompt around the subject token and generate
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    model_path = OUTPUT_DIR
    # Prepend the instance prompt so the fine-tuned subject token is always present
    full_prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + prompt
    return generate_images(full_prompt, negative_prompt, num_samples, model_path, height, width, num_inference_steps, guidance_scale)

# Define Gradio interface
def create_gradio_ui():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                # "filepath" yields plain paths; the older type="file" is deprecated in recent Gradio releases
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="filepath")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")

            with gr.Column():
                # The instance prompt ("photo of {identifier} person") is prepended
                # automatically in inference_ui, so only the scene goes here
                prompt = gr.Textbox(label="Prompt", value="in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                num_samples = gr.Number(label="Number of Samples", value=4, precision=0)
                guidance_scale = gr.Number(label="Guidance Scale", value=GUIDANCE_SCALE)
                height = gr.Number(label="Height", value=RESOLUTION, precision=0)
                width = gr.Number(label="Width", value=RESOLUTION, precision=0)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=150, step=1, value=NUM_INFERENCE_STEPS)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()

            finetune_button.click(fine_tune_model, inputs=[image_upload, identifier], outputs=finetune_output)
            generate_button.click(inference_ui, inputs=[identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)

    demo.launch()

if __name__ == "__main__":
    create_gradio_ui()