bguisard commited on
Commit
c0c58ec
·
1 Parent(s): 5be066f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import jax
 
3
  from flax.jax_utils import replicate
4
  from flax.training.common_utils import shard
5
- from diffusers import FlaxStableDiffusionPipeline
6
 
7
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
8
  "bguisard/stable-diffusion-nano",
@@ -13,11 +13,11 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
13
  rng = jax.random.PRNGKey(int(prng_seed))
14
  rng = jax.random.split(rng, jax.device_count())
15
  p_params = replicate(pipeline_params)
16
-
17
  num_samples = 1
18
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
19
  prompt_ids = shard(prompt_ids)
20
-
21
  images = pipeline(
22
  prompt_ids=prompt_ids,
23
  params=p_params,
@@ -30,7 +30,7 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
30
 
31
  images = images.reshape((num_samples,) + images.shape[-3:])
32
  images = pipeline.numpy_to_pil(images)
33
- return images
34
 
35
 
36
  prompt_input = gr.inputs.Textbox(
@@ -44,7 +44,7 @@ seed_input = gr.inputs.Number(default=0, label="Seed")
44
  app = gr.Interface(
45
  fn=generate_image,
46
  inputs=[prompt_input, inf_steps_input, seed_input],
47
- outputs=gr.Image(shape=(128, 128)),
48
  title="Stable Diffusion Nano",
49
  description=(
50
  "Based on stable diffusion and fine-tuned on 128x128 images, "
 
1
  import gradio as gr
2
  import jax
3
+ from diffusers import FlaxStableDiffusionPipeline
4
  from flax.jax_utils import replicate
5
  from flax.training.common_utils import shard
 
6
 
7
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
8
  "bguisard/stable-diffusion-nano",
 
13
  rng = jax.random.PRNGKey(int(prng_seed))
14
  rng = jax.random.split(rng, jax.device_count())
15
  p_params = replicate(pipeline_params)
16
+
17
  num_samples = 1
18
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
19
  prompt_ids = shard(prompt_ids)
20
+
21
  images = pipeline(
22
  prompt_ids=prompt_ids,
23
  params=p_params,
 
30
 
31
  images = images.reshape((num_samples,) + images.shape[-3:])
32
  images = pipeline.numpy_to_pil(images)
33
+ return images[0]
34
 
35
 
36
  prompt_input = gr.inputs.Textbox(
 
44
  app = gr.Interface(
45
  fn=generate_image,
46
  inputs=[prompt_input, inf_steps_input, seed_input],
47
+ outputs="image",
48
  title="Stable Diffusion Nano",
49
  description=(
50
  "Based on stable diffusion and fine-tuned on 128x128 images, "