import os
import json
import shutil
import tempfile

import torch
import gradio as gr
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Paths and training/inference defaults
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"
SEED = 1337
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50


def fine_tune_model(instance_data_dir, identifier):
    """Fine-tune the base model with DreamBooth on the images in instance_data_dir."""
    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": instance_data_dir,
            "class_data_dir": "/sample_data/person",  # placeholder for regularization images
        }
    ]

    # Save concepts_list.json for the training script
    with open("concepts_list.json", "w") as f:
        json.dump(concepts_list, f, indent=4)

    # Run the DreamBooth training script. The --concepts_list and
    # --save_sample_prompt flags match ShivamShrirao's diffusers fork of
    # train_dreambooth.py, which is assumed to be in the working directory.
    os.system(f"""python3 train_dreambooth.py \
        --pretrained_model_name_or_path={MODEL_NAME} \
        --output_dir={OUTPUT_DIR} \
        --revision="fp16" \
        --with_prior_preservation --prior_loss_weight=1.0 \
        --seed={SEED} \
        --resolution={RESOLUTION} \
        --train_batch_size={TRAIN_BATCH_SIZE} \
        --train_text_encoder \
        --mixed_precision="fp16" \
        --use_8bit_adam \
        --gradient_accumulation_steps=1 \
        --learning_rate={LEARNING_RATE} \
        --max_train_steps={MAX_TRAIN_STEPS} \
        --save_sample_prompt="{instance_prompt}" \
        --concepts_list="concepts_list.json" """)


def fine_tune_from_uploads(files, identifier):
    """Stage uploaded images into a directory, fine-tune, and return a status string."""
    if not files or not identifier:
        return "Please provide an identifier and at least one image."
    # Gradio hands back temporary file paths; copy them into one instance directory
    instance_data_dir = tempfile.mkdtemp(prefix="instance_data_")
    for file_path in files:
        shutil.copy(file_path, instance_data_dir)
    fine_tune_model(instance_data_dir, identifier)
    return f"Fine-tuning finished; model saved to {OUTPUT_DIR}"


def generate_images(prompt, negative_prompt, num_samples, model_path,
                    height=512, width=512,
                    num_inference_steps=NUM_INFERENCE_STEPS,
                    guidance_scale=GUIDANCE_SCALE):
    """Load the fine-tuned pipeline and generate images for the given prompt."""
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()  # requires xformers

    g_cuda = torch.Generator(device="cuda").manual_seed(SEED)
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=int(height),
            width=int(width),
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images
    return images


def inference_ui(identifier, prompt, negative_prompt, num_samples,
                 height, width, num_inference_steps, guidance_scale):
    """Prefix the user prompt with the instance phrase and generate images."""
    full_prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + prompt
    return generate_images(full_prompt, negative_prompt, num_samples, OUTPUT_DIR,
                           height, width, num_inference_steps, guidance_scale)
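
# Optional sketch (not part of the original flow): generate_images() reloads
# the pipeline from disk on every call, which is slow. A cached variant along
# these lines loads each model once; the _PIPE_CACHE name is hypothetical, and
# generate_images() could call get_pipeline(model_path) instead of
# from_pretrained() if this trade-off is acceptable.
_PIPE_CACHE = {}


def get_pipeline(model_path):
    """Return a cached StableDiffusionPipeline for model_path, loading it at most once."""
    if model_path not in _PIPE_CACHE:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        _PIPE_CACHE[model_path] = pipe
    return _PIPE_CACHE[model_path]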

def create_gradio_ui():
    """Build the two-column Gradio app: fine-tuning on the left, generation on the right."""
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="filepath")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")
            with gr.Column():
                # inference_ui() prepends "photo of {identifier} person, " itself,
                # so the prompt box only needs the scene description
                prompt = gr.Textbox(label="Prompt", value="in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                num_samples = gr.Number(label="Number of Samples", value=4)
                guidance_scale = gr.Number(label="Guidance Scale", value=GUIDANCE_SCALE)
                height = gr.Number(label="Height", value=RESOLUTION)
                width = gr.Number(label="Width", value=RESOLUTION)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=150,
                                                step=1, value=NUM_INFERENCE_STEPS)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()

        # Wire the buttons to the fine-tuning and generation functions
        finetune_button.click(
            fine_tune_from_uploads,
            inputs=[image_upload, identifier],
            outputs=finetune_output,
        )
        generate_button.click(
            inference_ui,
            inputs=[identifier, prompt, negative_prompt, num_samples,
                    height, width, num_inference_steps, guidance_scale],
            outputs=gallery,
        )

    demo.launch()


if __name__ == "__main__":
    create_gradio_ui()
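
# Example of programmatic use without the Gradio UI (the paths and prompt
# below are hypothetical; adjust them to your setup):
#
#   fine_tune_model("./training_images", "sks")
#   images = generate_images(
#       "photo of sks person in a marriage hall",
#       negative_prompt="blurry, low quality",
#       num_samples=2,
#       model_path=OUTPUT_DIR,
#   )
#   for i, img in enumerate(images):
#       img.save(f"sample_{i}.png")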