# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
import json
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

import gradio as gr
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from torch import autocast
from transformers import CLIPTextModel, CLIPTokenizer
# --- Configuration ----------------------------------------------------------
# Base checkpoint handed to train_dreambooth.py (a Hugging Face Hub repo id).
MODEL_NAME: str = "runwayml/stable-diffusion-v1-5"
# Where the training script is told to write the fine-tuned model; the
# inference callback also loads its pipeline from this path.
# NOTE(review): absolute root-level path — confirm the process may write here.
OUTPUT_DIR: str = "/output_model"

# Prompt templates. The instance prompt names the subject through a
# user-supplied identifier; the class prompt describes the generic class used
# for prior-preservation (regularization) images.
INSTANCE_PROMPT: str = "photo of {identifier} person"
CLASS_PROMPT: str = "photo of a person"

# Training hyper-parameters (forwarded as train_dreambooth.py flags).
SEED: int = 1337  # also seeds the CUDA generator used at inference
RESOLUTION: int = 512
TRAIN_BATCH_SIZE: int = 1
LEARNING_RATE: float = 1e-6
MAX_TRAIN_STEPS: int = 800

# Sampling defaults.
GUIDANCE_SCALE: float = 8.0
NUM_INFERENCE_STEPS: int = 50
# Function to fine-tune the model
def _stage_instance_images(instance_data, staging_dir):
    """Return a directory path that contains the subject's training images.

    A ``str`` / ``os.PathLike`` is taken to already be that directory and is
    returned unchanged (as a ``str``). Anything else is treated as an iterable
    of uploaded files -- plain paths, or objects exposing their on-disk path
    as ``.name`` -- and every file is copied into ``staging_dir``.
    """
    if isinstance(instance_data, (str, os.PathLike)):
        return os.fspath(instance_data)
    staging = Path(staging_dir)
    for index, item in enumerate(instance_data):
        path = item if isinstance(item, (str, os.PathLike)) else item.name
        source = Path(path)
        # Index prefix keeps uploads that share a basename from overwriting
        # each other in the staging directory.
        shutil.copy(source, staging / f"{index:03d}_{source.name}")
    return str(staging)


def fine_tune_model(instance_data_dir, identifier):
    """Fine-tune MODEL_NAME on a single subject with DreamBooth.

    Writes ``concepts_list.json`` into the current working directory
    (overwriting any previous one), then runs ``train_dreambooth.py`` with the
    current interpreter (``sys.executable``) as a child process and blocks
    until it exits. The script is told to write the model to OUTPUT_DIR.

    Args:
        instance_data_dir: Either a directory of training images, or the list
            of files produced by a multi-file ``gr.File`` upload. Uploaded
            files are copied into a temporary directory that is deleted once
            training has finished.
        identifier: Unique token naming the subject; substituted into
            INSTANCE_PROMPT to form the instance prompt.

    Returns:
        A human-readable status message for the UI.
    """
    if not identifier:
        return "Please enter a unique identifier before fine-tuning."
    if not instance_data_dir:
        return "Please upload at least one training image."

    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)

    with tempfile.TemporaryDirectory(prefix="instance_images_") as staging_dir:
        data_dir = _stage_instance_images(instance_data_dir, staging_dir)
        concepts_list = [
            {
                "instance_prompt": instance_prompt,
                "class_prompt": CLASS_PROMPT,
                "instance_data_dir": data_dir,
                "class_data_dir": "/sample_data/person",  # Placeholder for regularization images
            }
        ]
        with open("concepts_list.json", "w", encoding="utf-8") as f:
            json.dump(concepts_list, f, indent=4)

        # Argument list with no shell: the user-typed identifier (embedded in
        # --save_sample_prompt) cannot inject shell syntax, and the exit
        # status is available to report failures.
        command = [
            sys.executable,
            "train_dreambooth.py",
            f"--pretrained_model_name_or_path={MODEL_NAME}",
            f"--output_dir={OUTPUT_DIR}",
            "--revision=fp16",
            "--with_prior_preservation",
            "--prior_loss_weight=1.0",
            f"--seed={SEED}",
            f"--resolution={RESOLUTION}",
            f"--train_batch_size={TRAIN_BATCH_SIZE}",
            "--train_text_encoder",
            "--mixed_precision=fp16",
            "--use_8bit_adam",
            "--gradient_accumulation_steps=1",
            f"--learning_rate={LEARNING_RATE}",
            f"--max_train_steps={MAX_TRAIN_STEPS}",
            f"--save_sample_prompt={instance_prompt}",
            "--concepts_list=concepts_list.json",
        ]
        # stdout/stderr are inherited, so training logs go to the server console.
        result = subprocess.run(command, check=False)

    if result.returncode != 0:
        return (
            f"Fine-tuning failed: train_dreambooth.py exited with code "
            f"{result.returncode}. See the server logs for details."
        )
    return f"Fine-tuning finished. Model saved to {OUTPUT_DIR}."
# Function for inference
def generate_images(prompt, negative_prompt, num_samples, model_path, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    """Sample images from a Stable Diffusion pipeline stored at ``model_path``.

    The fp16 pipeline is loaded onto CUDA on every call, with the safety
    checker disabled and a DDIM scheduler. This is slow, but it always
    reflects the weights currently stored at ``model_path``.

    Args:
        prompt: Text prompt.
        negative_prompt: Text to steer away from ("" for none).
        num_samples: Number of images to generate for the prompt.
        model_path: Anything ``StableDiffusionPipeline.from_pretrained``
            accepts (local directory or Hub id).
        height, width: Output size in pixels.
        num_inference_steps: Number of DDIM denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        The ``.images`` list from the pipeline output.
    """
    # UI number widgets may deliver floats (e.g. 4.0); the pipeline needs
    # integer counts and sizes.
    num_samples = int(num_samples)
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)

    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    # xformers is an optional memory/speed optimisation: fall back to the
    # default attention when it is missing or unusable instead of failing
    # every generation.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as exc:
        print(f"xformers attention unavailable, using default attention: {exc}")

    # Fixed seed: identical prompt + settings reproduce identical images.
    g_cuda = torch.Generator(device="cuda").manual_seed(SEED)
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images
    return images
# Gradio UI function
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    """Gradio callback: build the final prompt and return generated images.

    If ``prompt`` contains the literal placeholder ``{identifier}`` (as the
    UI's default prompt does), the placeholder is replaced by ``identifier``.
    Otherwise the formatted INSTANCE_PROMPT is prepended to ``prompt``.
    Images are generated from the model at OUTPUT_DIR.

    Returns:
        Whatever ``generate_images`` returns, for display in the gallery.

    Raises:
        gr.Error: If ``identifier`` is empty.
    """
    if not identifier:
        raise gr.Error("Please enter the identifier you fine-tuned with.")

    # NOTE(review): assumes train_dreambooth.py writes a loadable pipeline
    # directly into OUTPUT_DIR (some DreamBooth scripts save per-step
    # sub-folders instead) -- confirm against the training script.
    model_path = OUTPUT_DIR

    prompt = prompt or ""
    if "{identifier}" in prompt:
        # str.replace rather than str.format so any other braces the user
        # typed are left untouched.
        full_prompt = prompt.replace("{identifier}", identifier)
    else:
        subject = INSTANCE_PROMPT.format(identifier=identifier)
        full_prompt = f"{subject}, {prompt}" if prompt else subject

    return generate_images(
        full_prompt,
        negative_prompt,
        num_samples,
        model_path,
        height,
        width,
        num_inference_steps,
        guidance_scale,
    )
# Define Gradio interface
def create_gradio_ui():
    """Build the DreamBooth Gradio app and launch it.

    Left column: identifier and training-image upload, wired to
    ``fine_tune_model``. Right column: prompt and sampling controls, wired to
    ``inference_ui``, whose images fill the gallery.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                # `type` is left at its default: Gradio 4+ rejects the old
                # type="file" value. The callback receives temp-file objects
                # on Gradio 3 and file paths on Gradio 4+.
                image_upload = gr.File(label="Upload Images", file_count="multiple")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="photo of {identifier} person in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                # precision=0 makes these widgets return ints; the pipeline
                # needs integer image counts and sizes.
                num_samples = gr.Number(label="Number of Samples", value=4, precision=0)
                guidance_scale = gr.Number(label="Guidance Scale", value=GUIDANCE_SCALE)
                height = gr.Number(label="Height", value=RESOLUTION, precision=0)
                width = gr.Number(label="Width", value=RESOLUTION, precision=0)
                # Explicit bounds: the slider's default range starts at 0
                # steps, which cannot produce an image.
                num_inference_steps = gr.Slider(
                    label="Steps", minimum=1, maximum=100, step=1, value=NUM_INFERENCE_STEPS
                )
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery(label="Generated Images")

        # Event listeners are registered inside the Blocks context.
        finetune_button.click(
            fine_tune_model,
            inputs=[image_upload, identifier],
            outputs=finetune_output,
        )
        generate_button.click(
            inference_ui,
            inputs=[identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale],
            outputs=gallery,
        )

    # Queue events so the long-running fine-tuning call is not cut off by
    # request timeouts on Gradio versions that do not queue by default.
    demo.queue().launch()
# Script entry point: build and serve the UI only when run directly, not on import.
if __name__ == "__main__":
    create_gradio_ui()