# Magix / main.py
# Author: Singularity666 (Hugging Face repository file view).
# Commit b92dd65 "Create main.py" (verified), 4.65 kB.
import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from torch import autocast
from transformers import CLIPTextModel, CLIPTokenizer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base checkpoint, and the directory the fine-tuned weights are written to
# by training and loaded from by inference.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"

# Prompt templates; ``{identifier}`` is filled with the user's unique token.
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"

# Seed used both for training and for the inference random generator.
SEED = 1337

# DreamBooth hyper-parameters forwarded to train_dreambooth.py.
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800

# Inference settings. NOTE(review): generate_images does not read these;
# its own keyword defaults are used instead — confirm intended values.
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50
# Function to fine-tune the model
def fine_tune_model(instance_data_dir, identifier):
    """Fine-tune MODEL_NAME with DreamBooth on images of a single subject.

    Writes ``concepts_list.json`` to the current working directory and runs
    ``train_dreambooth.py`` (looked up relative to the working directory) in
    a subprocess; the trained weights are written to OUTPUT_DIR.

    Args:
        instance_data_dir: Either a path to a directory of subject images, or
            the uploaded file(s) delivered by the Gradio ``File`` component
            (objects exposing ``.name`` with the on-disk path, or plain
            paths). Uploaded files are copied into a fresh temp directory.
        identifier: Unique token that names the subject in INSTANCE_PROMPT.

    Returns:
        A human-readable status message (shown in the Gradio output box).
    """
    if not identifier or not str(identifier).strip():
        return "Please enter an identifier before fine-tuning."

    try:
        data_dir = _prepare_instance_dir(instance_data_dir)
    except (OSError, ValueError) as err:
        return f"Could not prepare the training images: {err}"

    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": data_dir,
            "class_data_dir": "/sample_data/person",  # Placeholder for regularization images
        }
    ]
    with open("concepts_list.json", "w", encoding="utf-8") as f:
        json.dump(concepts_list, f, indent=4)

    # Pass an argument list (no shell): ``identifier`` is user-supplied text
    # and must never be interpreted by a shell.
    command = [
        "python3", "train_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--output_dir={OUTPUT_DIR}",
        "--revision=fp16",
        "--with_prior_preservation", "--prior_loss_weight=1.0",
        f"--seed={SEED}",
        f"--resolution={RESOLUTION}",
        f"--train_batch_size={TRAIN_BATCH_SIZE}",
        "--train_text_encoder",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        "--gradient_accumulation_steps=1",
        f"--learning_rate={LEARNING_RATE}",
        f"--max_train_steps={MAX_TRAIN_STEPS}",
        f"--save_sample_prompt={instance_prompt}",
        "--concepts_list=concepts_list.json",
    ]
    result = subprocess.run(command, check=False)
    if result.returncode != 0:
        return f"Fine-tuning failed: train_dreambooth.py exited with code {result.returncode}."
    return f"Fine-tuning finished; model saved to {OUTPUT_DIR}."


def _prepare_instance_dir(instance_data):
    """Return a directory path that contains the instance images.

    A ``str``/``os.PathLike`` is taken to be an existing image directory and
    returned as-is; anything else is treated as one or more uploaded files,
    which are copied into a newly created temporary directory.

    Raises:
        ValueError: if nothing was provided.
        OSError: if an uploaded file cannot be copied.
    """
    if not instance_data:
        raise ValueError("no images were provided")
    if isinstance(instance_data, (str, os.PathLike)):
        return str(instance_data)
    if not isinstance(instance_data, (list, tuple)):
        instance_data = [instance_data]
    target = Path(tempfile.mkdtemp(prefix="instance_images_"))
    for index, item in enumerate(instance_data):
        # Gradio 3 yields tempfile wrappers whose ``.name`` is the on-disk
        # path; other versions may yield plain path strings.
        source = Path(getattr(item, "name", item))
        # Index prefix keeps same-named uploads from overwriting each other.
        shutil.copy(source, target / f"{index:03d}_{source.name}")
    return str(target)
# Function for inference
def generate_images(prompt, negative_prompt, num_samples, model_path, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    """Generate images from a Stable Diffusion checkpoint with a DDIM scheduler.

    The pipeline is loaded in fp16 onto CUDA with the safety checker
    disabled. Sampling is seeded with SEED, so identical arguments reproduce
    identical images.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Text to steer away from ("" for none).
        num_samples: Number of images to generate for the prompt.
        model_path: Directory or hub id accepted by
            ``StableDiffusionPipeline.from_pretrained``.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        The list of images produced by the pipeline (``pipe(...).images``).
    """
    # Gradio Number/Slider widgets can deliver floats such as 4.0; diffusers
    # needs real ints for the batch repeat, latent shape and step schedule.
    num_samples = max(1, int(num_samples))
    height = int(height)
    width = int(width)
    num_inference_steps = max(1, int(num_inference_steps))
    guidance_scale = float(guidance_scale)

    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    try:
        # Optional memory/speed optimisation: xformers is not always installed,
        # and generation works without it.
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError):
        pass

    g_cuda = torch.Generator(device="cuda").manual_seed(SEED)
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images
    return images
# Gradio UI function
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    """Gradio callback: render images of the fine-tuned subject from OUTPUT_DIR.

    The subject can be referenced in ``prompt`` in two ways:

    * a literal ``{identifier}`` placeholder (as in the UI's default prompt),
      which is replaced by ``identifier``;
    * otherwise, the instance prompt ("photo of <identifier> person") is
      prepended, separated by ", ".

    An empty prompt falls back to the bare instance prompt.
    """
    model_path = OUTPUT_DIR
    prompt = prompt or ""
    if "{identifier}" in prompt:
        # str.replace rather than str.format: user text may contain other braces.
        prompt = prompt.replace("{identifier}", identifier)
    elif prompt.strip():
        prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + prompt
    else:
        prompt = INSTANCE_PROMPT.format(identifier=identifier)
    images = generate_images(prompt, negative_prompt, num_samples, model_path, height, width, num_inference_steps, guidance_scale)
    return images
# Define Gradio interface
def create_gradio_ui():
    """Build and launch the Gradio app.

    The left column collects the identifier and training images and triggers
    fine-tuning. The right column holds the generation settings and shows the
    resulting images in a gallery.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="file")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="photo of {identifier} person in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                # precision=0 makes gr.Number return ints; diffusers needs ints
                # for the image count and the output size.
                num_samples = gr.Number(label="Number of Samples", value=4, precision=0)
                guidance_scale = gr.Number(label="Guidance Scale", value=8)
                height = gr.Number(label="Height", value=512, precision=0)
                width = gr.Number(label="Width", value=512, precision=0)
                # At least one denoising step is required.
                num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100, step=1)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()
        # Event listeners must be registered inside the Blocks context.
        finetune_button.click(fine_tune_model, inputs=[image_upload, identifier], outputs=finetune_output)
        generate_button.click(
            inference_ui,
            inputs=[identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale],
            outputs=gallery,
        )
    demo.launch()


if __name__ == "__main__":
    create_gradio_ui()