Singularity666 commited on
Commit
aa4292f
·
verified ·
1 Parent(s): d953dcd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from main import DreamboothApp
3
+
4
+ app = DreamboothApp(model_path="stable_diffusion_weights")
5
+
6
+ def train(instance_images, instance_prompt, num_class_images, max_train_steps):
7
+ app.train(instance_data_dir="instance_data",
8
+ class_data_dir="class_data",
9
+ instance_prompt=instance_prompt,
10
+ class_prompt="photo of a person",
11
+ num_class_images=num_class_images,
12
+ max_train_steps=max_train_steps)
13
+ return "Training completed. Model is ready for inference."
14
+
15
+ def inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed):
16
+ app.load_model()
17
+ images = app.inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed)
18
+ return images
19
+
20
+ with gr.Blocks() as demo:
21
+ gr.Markdown("# Stable Diffusion Dreambooth")
22
+ with gr.Tab("Training"):
23
+ with gr.Row():
24
+ instance_images = gr.File(label="Upload Instance Images (5-10 images recommended)", file_count="multiple")
25
+ with gr.Column():
26
+ instance_prompt = gr.Textbox(label="Instance Prompt", placeholder="Enter the prompt for your instance images")
27
+ num_class_images = gr.Number(label="Number of Class Images", value=50)
28
+ max_train_steps = gr.Number(label="Maximum Training Steps", value=800)
29
+ train_button = gr.Button("Train Model")
30
+ train_output = gr.Textbox(label="Training Output")
31
+ train_button.click(train, inputs=[instance_images, instance_prompt, num_class_images, max_train_steps], outputs=train_output)
32
+
33
+ with gr.Tab("Inference"):
34
+ with gr.Row():
35
+ with gr.Column():
36
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
37
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here (optional)")
38
+ with gr.Row():
39
+ num_samples = gr.Number(label="Number of Samples", value=1)
40
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
41
+ with gr.Row():
42
+ height = gr.Number(label="Height", value=512)
43
+ width = gr.Number(label="Width", value=512)
44
+ num_inference_steps = gr.Slider(label="Steps", value=50)
45
+ seed = gr.Number(label="Seed (optional)", value=0)
46
+ generate_button = gr.Button("Generate Images")
47
+ with gr.Column():
48
+ gallery = gr.Gallery(label="Generated Images")
49
+ generate_button.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed], outputs=gallery)
50
+
51
+ demo.launch()