TRaw commited on
Commit
0305586
·
verified ·
1 Parent(s): db41edc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import mlx.core as mx
5
+ from stable_diffusion import StableDiffusion
6
+
7
+ def generate_images(prompt, n_images=4, steps=50, cfg=7.5, negative_prompt="", n_rows=1):
8
+ sd = StableDiffusion()
9
+
10
+ # Generate the latent vectors using diffusion
11
+ latents = sd.generate_latents(
12
+ prompt,
13
+ n_images=n_images,
14
+ cfg_weight=cfg,
15
+ num_steps=steps,
16
+ negative_text=negative_prompt,
17
+ )
18
+ for x_t in latents:
19
+ mx.simplify(x_t)
20
+ mx.simplify(x_t)
21
+ mx.eval(x_t)
22
+
23
+ # Decode them into images
24
+ decoded = []
25
+ for i in range(0, n_images):
26
+ decoded_img = sd.decode(x_t[i:i+1])
27
+ mx.eval(decoded_img)
28
+ decoded.append(decoded_img)
29
+
30
+ # Arrange them on a grid
31
+ x = mx.concatenate(decoded, axis=0)
32
+ x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
33
+ B, H, W, C = x.shape
34
+ x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
35
+ x = x.reshape(n_rows * H, B // n_rows * W, C)
36
+ x = (x * 255).astype(mx.uint8)
37
+
38
+ # Convert to PIL Image
39
+ return Image.fromarray(x.__array__())
40
+
41
+ iface = gr.Interface(
42
+ fn=generate_images,
43
+ inputs=[
44
+ gr.Textbox(label="Prompt"),
45
+ gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Images"),
46
+ gr.Slider(minimum=20, maximum=100, step=1, value=50, label="Steps"),
47
+ gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=7.5, label="CFG Weight"),
48
+ gr.Textbox(default="", label="Negative Prompt"),
49
+ gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Rows")
50
+ ],
51
+ outputs="image",
52
+ title="Stable Diffusion Image Generator",
53
+ description="Generate images from a textual prompt using Stable Diffusion"
54
+ )
55
+
56
+ iface.launch()