Waseem7711 committed on
Commit
903b3c1
·
verified ·
1 Parent(s): e11a04e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -20
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from diffusers import DiffusionPipeline
3
  import torch
4
  from PIL import Image
5
 
@@ -15,12 +15,15 @@ generate_button = st.sidebar.button("Generate Image")
15
  # Load the pipeline when the app starts
16
  @st.cache_resource
17
  def load_pipeline():
18
- pipe = DiffusionPipeline.from_pretrained(
 
19
  "runwayml/stable-diffusion-v1-5",
20
- torch_dtype=torch.float16
 
21
  )
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
- return pipe.to(device)
 
24
 
25
  pipe = load_pipeline()
26
 
@@ -28,19 +31,23 @@ pipe = load_pipeline()
28
  if generate_button:
29
  st.write(f"### Prompt: {prompt}")
30
  with st.spinner("Generating image... Please wait."):
31
- # Generate the image
32
- image = pipe(prompt).images[0]
33
-
34
- # Display the generated image
35
- st.image(image, caption="Generated Image", use_column_width=True)
36
-
37
- # Option to download the image
38
- img_path = "generated_image.png"
39
- image.save(img_path)
40
- with open(img_path, "rb") as img_file:
41
- st.download_button(
42
- label="Download Image",
43
- data=img_file,
44
- file_name="generated_image.png",
45
- mime="image/png"
46
- )
 
 
 
 
 
1
  import streamlit as st
2
+ from diffusers import StableDiffusionPipeline
3
  import torch
4
  from PIL import Image
5
 
 
15
# Load the pipeline once when the app starts; st.cache_resource keeps the
# heavyweight model in memory across Streamlit reruns.
@st.cache_resource
def load_pipeline():
    """Load Stable Diffusion v1.5 and move it to the best available device.

    Returns:
        StableDiffusionPipeline ready for inference — fp16 weights on CUDA,
        fp32 weights on CPU.
    """
    has_cuda = torch.cuda.is_available()  # query once, reuse below
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        # `variant="fp16"` is the supported way to request half-precision
        # weight files; `revision="fp16"` is deprecated in diffusers.
        variant="fp16" if has_cuda else None,
    )
    return pipe.to("cuda" if has_cuda else "cpu")
27
 
28
import io

# Load the (cached) pipeline at startup so the first generation is not delayed.
pipe = load_pipeline()

if generate_button:
    st.write(f"### Prompt: {prompt}")
    with st.spinner("Generating image... Please wait."):
        try:
            # Autocast only helps on CUDA; wrapping the fp32 CPU pipeline in
            # torch.autocast("cpu") selects bfloat16, which is unsupported or
            # slow on many CPUs — so gate it to CUDA.
            if torch.cuda.is_available():
                with torch.autocast("cuda"):
                    image = pipe(prompt).images[0]
            else:
                image = pipe(prompt).images[0]

            # Display the generated image.
            # `use_container_width` replaces the deprecated `use_column_width`.
            st.image(image, caption="Generated Image", use_container_width=True)

            # Offer the image for download from an in-memory buffer: writing a
            # fixed filename to disk races between concurrent user sessions.
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            st.download_button(
                label="Download Image",
                data=buffer.getvalue(),
                file_name="generated_image.png",
                mime="image/png"
            )
        except Exception as e:
            # Top-level UI boundary: surface any failure (OOM, model download
            # error, bad prompt) in the app instead of crashing the session.
            st.error(f"An error occurred: {e}")