Singularity666 commited on
Commit
b7a7c79
·
verified ·
1 Parent(s): 48dd315

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -41
app.py CHANGED
@@ -1,51 +1,75 @@
1
  import gradio as gr
2
- from main import DreamboothApp
 
 
 
 
3
 
4
- app = DreamboothApp(model_path="stable_diffusion_weights")
 
5
 
6
- def train(instance_images, instance_prompt, num_class_images, max_train_steps):
7
- app.train(instance_data_dir="instance_data",
8
- class_data_dir="class_data",
9
- instance_prompt=instance_prompt,
10
- class_prompt="photo of a person",
11
- num_class_images=num_class_images,
12
- max_train_steps=max_train_steps)
13
- return "Training completed. Model is ready for inference."
 
 
 
 
 
 
 
 
14
 
15
- def inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed):
16
- app.load_model()
17
- images = app.inference(prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed)
18
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- with gr.Blocks() as demo:
21
- gr.Markdown("# Stable Diffusion Dreambooth")
22
- with gr.Tab("Training"):
23
- with gr.Row():
24
- instance_images = gr.File(label="Upload Instance Images (5-10 images recommended)", file_count="multiple")
25
- with gr.Column():
26
- instance_prompt = gr.Textbox(label="Instance Prompt", placeholder="Enter the prompt for your instance images")
27
- num_class_images = gr.Number(label="Number of Class Images", value=50)
28
- max_train_steps = gr.Number(label="Maximum Training Steps", value=800)
29
- train_button = gr.Button("Train Model")
30
- train_output = gr.Textbox(label="Training Output")
31
- train_button.click(train, inputs=[instance_images, instance_prompt, num_class_images, max_train_steps], outputs=train_output)
32
-
33
- with gr.Tab("Inference"):
34
- with gr.Row():
35
- with gr.Column():
36
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
37
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here (optional)")
38
- with gr.Row():
39
  num_samples = gr.Number(label="Number of Samples", value=1)
40
  guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
41
- with gr.Row():
42
  height = gr.Number(label="Height", value=512)
43
  width = gr.Number(label="Width", value=512)
44
- num_inference_steps = gr.Slider(label="Steps", value=50)
45
- seed = gr.Number(label="Seed (optional)", value=0)
46
- generate_button = gr.Button("Generate Images")
47
- with gr.Column():
48
- gallery = gr.Gallery(label="Generated Images")
49
- generate_button.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed], outputs=gallery)
 
50
 
51
- demo.launch()
 
 
1
  import gradio as gr
2
+ import os
3
+ import shutil
4
+ from main import fine_tune_model
5
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
6
+ import torch
7
 
8
+ MODEL_NAME = "runwayml/stable-diffusion-v1-5"
9
+ OUTPUT_DIR = "/home/user/app/stable_diffusion_weights/custom_model"
10
 
11
+ def fine_tune(instance_prompt, image1, image2=None):
12
+ instance_data_dir = "/home/user/app/instance_images"
13
+
14
+ try:
15
+ if os.path.exists(instance_data_dir):
16
+ shutil.rmtree(instance_data_dir)
17
+ os.makedirs(instance_data_dir, exist_ok=True)
18
+
19
+ image1.save(os.path.join(instance_data_dir, "instance_0.png"))
20
+ if image2 is not None:
21
+ image2.save(os.path.join(instance_data_dir, "instance_1.png"))
22
+
23
+ fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
24
+ return "Model fine-tuning complete."
25
+ except Exception as e:
26
+ return str(e)
27
 
28
+ def generate_images(prompt, num_samples, height, width, num_inference_steps, guidance_scale):
29
+ try:
30
+ if not os.path.exists(OUTPUT_DIR):
31
+ return "The model path does not exist."
32
+
33
+ pipe = StableDiffusionPipeline.from_pretrained(OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float32).to("cpu")
34
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
35
+
36
+ with torch.autocast("cpu"), torch.inference_mode():
37
+ images = pipe(
38
+ prompt, height=height, width=width, num_images_per_prompt=num_samples,
39
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
40
+ ).images
41
+
42
+ return images
43
+ except Exception as e:
44
+ return str(e)
45
 
46
+ def gradio_app():
47
+ with gr.Blocks() as demo:
48
+ with gr.Tab("Fine-Tune Model"):
49
+ with gr.Row():
50
+ with gr.Column():
51
+ instance_prompt = gr.Textbox(label="Instance Prompt")
52
+ image1 = gr.Image(label="Upload Image 1", type="pil")
53
+ image2 = gr.Image(label="Upload Image 2 (Optional)", type="pil")
54
+ fine_tune_button = gr.Button("Fine-Tune Model")
55
+ output_text = gr.Textbox(label="Output")
56
+ fine_tune_button.click(fine_tune, inputs=[instance_prompt, image1, image2], outputs=output_text)
57
+
58
+ with gr.Tab("Generate Images"):
59
+ with gr.Row():
60
+ with gr.Column():
61
+ prompt = gr.Textbox(label="Prompt")
 
 
 
62
  num_samples = gr.Number(label="Number of Samples", value=1)
63
  guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
 
64
  height = gr.Number(label="Height", value=512)
65
  width = gr.Number(label="Width", value=512)
66
+ num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100)
67
+ generate_button = gr.Button("Generate Images")
68
+ with gr.Column():
69
+ gallery = gr.Gallery(label="Generated Images")
70
+ generate_button.click(generate_images, inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)
71
+
72
+ demo.launch()
73
 
74
+ if __name__ == "__main__":
75
+ gradio_app()