# Hugging Face Spaces status at time of capture: Runtime error
import gradio as gr
from main import DreamboothApp

# One shared DreamboothApp instance, used by both the `train` and
# `inference` handlers below.
# NOTE(review): "stable_diffusion_weights" is presumably the directory the
# trained weights are written to / loaded from — confirm against
# main.DreamboothApp.
app = DreamboothApp(model_path="stable_diffusion_weights")
def train(instance_images, instance_prompt, num_class_images, max_train_steps):
    """Stage the uploaded instance images and run DreamBooth training.

    Args:
        instance_images: Value of the ``gr.File(file_count="multiple")``
            widget. Depending on the Gradio version, each entry is either a
            path string or a tempfile wrapper exposing its path as ``.name``.
            The value is ``None`` when nothing was uploaded.
        instance_prompt: Prompt describing the uploaded subject.
        num_class_images: Number of class images to use. ``gr.Number``
            yields a float unless ``precision=0``, so the value is coerced
            to ``int``.
        max_train_steps: Upper bound on training steps; coerced to ``int``.

    Returns:
        The status message displayed in the "Training Output" textbox.
    """
    import os
    import shutil

    instance_data_dir = "instance_data"

    # app.train is only given a directory name, so the uploads must be
    # copied into that directory; otherwise they would have no effect.
    # NOTE(review): files already in the directory are kept, so images from
    # a previous run are mixed in. Confirm whether it should be emptied first.
    os.makedirs(instance_data_dir, exist_ok=True)
    for uploaded in instance_images or []:
        src_path = getattr(uploaded, "name", uploaded)
        dst_path = os.path.join(instance_data_dir, os.path.basename(src_path))
        shutil.copy(src_path, dst_path)

    app.train(
        instance_data_dir=instance_data_dir,
        class_data_dir="class_data",
        instance_prompt=instance_prompt,
        class_prompt="photo of a person",
        num_class_images=int(num_class_images),
        max_train_steps=int(max_train_steps),
    )
    return "Training completed. Model is ready for inference."
def inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed):
    """Generate images with the (trained) model for the Inference tab.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt; may be an empty string.
        num_samples: Number of images to generate.
        height: Output image height.
        width: Output image width.
        num_inference_steps: Number of denoising steps from the "Steps" slider.
        guidance_scale: Guidance scale, passed through unchanged as a float.
        seed: Optional seed. ``None`` (an empty field) is passed through as
            ``None``.

    Returns:
        Whatever ``app.inference`` returns, rendered by the gallery component.
    """
    # NOTE(review): load_model() is called on every request. Unless
    # DreamboothApp caches the pipeline, the weights are reloaded each time.
    app.load_model()

    # gr.Number / gr.Slider deliver floats by default. Counts, pixel sizes
    # and seeds are integral quantities, so coerce them here.
    images = app.inference(
        prompt,
        negative_prompt,
        int(num_samples),
        int(height),
        int(width),
        int(num_inference_steps),
        guidance_scale,
        None if seed is None else int(seed),
    )
    return images
# UI definition: a "Training" tab wired to `train` and an "Inference" tab
# wired to `inference`.
with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion Dreambooth")

    with gr.Tab("Training"):
        with gr.Row():
            instance_images = gr.File(
                label="Upload Instance Images (5-10 images recommended)",
                file_count="multiple",
                file_types=["image"],
            )
            with gr.Column():
                instance_prompt = gr.Textbox(
                    label="Instance Prompt",
                    placeholder="Enter the prompt for your instance images",
                )
                # precision=0 makes gr.Number round its value and deliver an int.
                num_class_images = gr.Number(label="Number of Class Images", value=50, precision=0)
                max_train_steps = gr.Number(label="Maximum Training Steps", value=800, precision=0)
                train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            train,
            inputs=[instance_images, instance_prompt, num_class_images, max_train_steps],
            outputs=train_output,
        )

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Enter negative prompt here (optional)",
                )
                with gr.Row():
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                with gr.Row():
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                # The slider's default minimum is 0, which would permit a
                # zero-step run. An integer step keeps the value whole.
                num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100, step=1)
                seed = gr.Number(label="Seed (optional)", value=0, precision=0)
                generate_button = gr.Button("Generate Images")
            with gr.Column():
                gallery = gr.Gallery(label="Generated Images")
        generate_button.click(
            inference,
            inputs=[prompt, negative_prompt, num_samples, height, width,
                    num_inference_steps, guidance_scale, seed],
            outputs=gallery,
        )

# Training can run far longer than a plain request may stay open. The queue
# keeps long jobs alive; it is off by default in Gradio 3.x and on by
# default in 4.x, so calling it is harmless either way.
demo.queue()
demo.launch()