import os
import json
import shutil
from pathlib import Path
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Define necessary paths and variables
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"
SEED = 1337
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50
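# These defaults mirror common DreamBooth practice: a low learning rate (1e-6)
# is typical when --train_text_encoder is enabled, and ~800 steps suits a small
# set of instance images; scale MAX_TRAIN_STEPS with the size of your dataset.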
# Fine-tune the base model on the instance images via DreamBooth
def fine_tune_model(instance_data_dir, identifier):
    # Build the instance prompt and the concepts list the training script expects
    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": instance_data_dir,
            "class_data_dir": "/sample_data/person"  # Placeholder for regularization images
        }
    ]
    # Save concepts_list.json for the training script
    with open("concepts_list.json", "w") as f:
        json.dump(concepts_list, f, indent=4)
    # Run the training script (assumes a DreamBooth train_dreambooth.py that
    # supports --concepts_list and --save_sample_prompt, e.g. ShivamShrirao's
    # diffusers fork, is present in the working directory)
    os.system(f"""
    python3 train_dreambooth.py \
        --pretrained_model_name_or_path={MODEL_NAME} \
        --output_dir={OUTPUT_DIR} \
        --revision="fp16" \
        --with_prior_preservation --prior_loss_weight=1.0 \
        --seed={SEED} \
        --resolution={RESOLUTION} \
        --train_batch_size={TRAIN_BATCH_SIZE} \
        --train_text_encoder \
        --mixed_precision="fp16" \
        --use_8bit_adam \
        --gradient_accumulation_steps=1 \
        --learning_rate={LEARNING_RATE} \
        --max_train_steps={MAX_TRAIN_STEPS} \
        --save_sample_prompt="{instance_prompt}" \
        --concepts_list="concepts_list.json"
    """)
    return f"Fine-tuning complete. Model saved to {OUTPUT_DIR}."
# Load the fine-tuned pipeline and run inference
def generate_images(prompt, negative_prompt, num_samples, model_path, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    try:
        # xformers attention is optional; fall back gracefully if unavailable
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    g_cuda = torch.Generator(device="cuda").manual_seed(SEED)
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=g_cuda
        ).images
    return images
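# Note: the pipeline is reloaded from disk on every call, which keeps GPU
# memory free between requests at the cost of a few seconds of load time;
# caching the pipeline in a module-level variable would trade VRAM for speed.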
# Gradio callback: prepend the identifier prompt, then generate
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    model_path = OUTPUT_DIR
    prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + prompt
    # gr.Number and gr.Slider return floats; the pipeline expects ints
    images = generate_images(prompt, negative_prompt, int(num_samples), model_path,
                             int(height), int(width), int(num_inference_steps), guidance_scale)
    return images
# Define the Gradio interface
def create_gradio_ui():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                # type="filepath" returns local path strings for the uploads
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="filepath")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")
            with gr.Column():
                # The identifier prompt is prepended automatically in inference_ui,
                # so the default here only describes the scene
                prompt = gr.Textbox(label="Prompt", value="in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                num_samples = gr.Number(label="Number of Samples", value=4)
                guidance_scale = gr.Number(label="Guidance Scale", value=GUIDANCE_SCALE)
                height = gr.Number(label="Height", value=512)
                width = gr.Number(label="Width", value=512)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=150, step=1, value=NUM_INFERENCE_STEPS)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()
        finetune_button.click(fine_tune_from_uploads, inputs=[image_upload, identifier], outputs=finetune_output)
        generate_button.click(inference_ui, inputs=[identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)
    demo.launch()
if __name__ == "__main__":
    create_gradio_ui()
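# Assumed runtime environment: a CUDA GPU, the DreamBooth training script in
# the working directory, and regularization images under /sample_data/person.
# Running the script directly launches the Gradio UI in the browser.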