editx / app.py
Singularity666's picture
Create app.py
aa4292f verified
raw
history blame
2.76 kB
import os
import shutil
from pathlib import Path

import gradio as gr
from main import DreamboothApp
# One shared DreamboothApp instance, created once at import time and used by
# both the `train` and `inference` callbacks below.
# NOTE(review): "stable_diffusion_weights" is presumably where weights are
# read from / written to by DreamboothApp — confirm against main.py.
app = DreamboothApp(
    model_path="stable_diffusion_weights",
)
def train(instance_images, instance_prompt, num_class_images, max_train_steps):
    """Stage the uploaded instance images and run Dreambooth training.

    The uploaded files are copied into the ``instance_data`` directory (which
    is recreated empty first) and then ``app.train`` is called on it, with the
    fixed class prompt "photo of a person" and ``class_data`` as the class
    data directory.

    Args:
        instance_images: Files from the multi-file ``gr.File`` upload. Depending
            on the Gradio version each item is either a path string or a
            tempfile wrapper exposing its path as ``.name``; both are accepted.
            ``None`` or an empty list means nothing was uploaded.
        instance_prompt: Prompt describing the instance images.
        num_class_images: Number of class images; coerced to ``int`` because
            ``gr.Number`` returns floats unless ``precision=0`` is set.
        max_train_steps: Maximum number of training steps; coerced to ``int``.

    Returns:
        A status message for the "Training Output" textbox.
    """
    if not instance_images:
        return "Please upload at least one instance image before training."

    instance_dir = Path("instance_data")
    # Start from an empty directory so images left over from a previous run
    # are not mixed into this run's training set.
    shutil.rmtree(instance_dir, ignore_errors=True)
    instance_dir.mkdir(parents=True, exist_ok=True)

    for index, upload in enumerate(instance_images):
        # Gradio 4.x passes path strings; Gradio 3.x passes tempfile wrappers
        # whose ``.name`` holds the path on disk.
        raw_path = upload if isinstance(upload, (str, os.PathLike)) else upload.name
        src = Path(raw_path)
        # Index prefix keeps two uploads that share a basename from
        # overwriting each other.
        shutil.copy(src, instance_dir / f"{index:03d}_{src.name}")

    app.train(
        instance_data_dir=str(instance_dir),
        class_data_dir="class_data",
        instance_prompt=instance_prompt,
        class_prompt="photo of a person",
        num_class_images=int(num_class_images),
        max_train_steps=int(max_train_steps),
    )
    return "Training completed. Model is ready for inference."
def inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed):
    """Load the model and generate images for the Inference tab's gallery.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt (may be an empty string).
        num_samples: Number of images to generate; coerced to ``int``.
        height: Image height in pixels; coerced to ``int``.
        width: Image width in pixels; coerced to ``int``.
        num_inference_steps: Number of sampling steps; coerced to ``int``.
        guidance_scale: Guidance scale, passed through unchanged.
        seed: Seed value; coerced to ``int``, or passed through as ``None``
            if the field was cleared.

    Returns:
        Whatever ``app.inference`` returns; it is displayed in ``gr.Gallery``.
    """
    # NOTE(review): load_model() runs on every request, as before. Confirm
    # whether it is cheap/idempotent or should be cached until the next
    # training run.
    app.load_model()
    # gr.Number (and possibly gr.Slider) deliver floats. These settings are
    # integer-valued (counts, pixel sizes, seed), and presumably app.inference
    # uses them where floats are rejected, so coerce them here.
    images = app.inference(
        prompt,
        negative_prompt,
        int(num_samples),
        int(height),
        int(width),
        int(num_inference_steps),
        guidance_scale,
        None if seed is None else int(seed),
    )
    return images
with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion Dreambooth")

    with gr.Tab("Training"):
        with gr.Row():
            instance_images = gr.File(
                label="Upload Instance Images (5-10 images recommended)",
                file_count="multiple",
            )
            with gr.Column():
                instance_prompt = gr.Textbox(
                    label="Instance Prompt",
                    placeholder="Enter the prompt for your instance images",
                )
                # precision=0 makes gr.Number round to and return an int
                # instead of a float; both values are counts.
                num_class_images = gr.Number(label="Number of Class Images", value=50, precision=0)
                max_train_steps = gr.Number(label="Maximum Training Steps", value=800, precision=0)
                train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            train,
            inputs=[instance_images, instance_prompt, num_class_images, max_train_steps],
            outputs=train_output,
        )

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Enter negative prompt here (optional)",
                )
                with gr.Row():
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                with gr.Row():
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                # The slider previously allowed 0 steps (default minimum); at
                # least one sampling step is required. The default maximum of
                # 100 is kept.
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                seed = gr.Number(label="Seed (optional)", value=0, precision=0)
                generate_button = gr.Button("Generate Images")
            with gr.Column():
                gallery = gr.Gallery(label="Generated Images")
        generate_button.click(
            inference,
            inputs=[
                prompt,
                negative_prompt,
                num_samples,
                height,
                width,
                num_inference_steps,
                guidance_scale,
                seed,
            ],
            outputs=gallery,
        )

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported (tools, tests, reload mode) without side effects.
if __name__ == "__main__":
    demo.launch()