tahirsher committed on
Commit cd9d635 · verified · 1 Parent(s): c2302e5

Update app.py

Files changed (1)
  1. app.py +19 -9
app.py CHANGED
@@ -1,29 +1,39 @@
 import streamlit as st
-from diffusers import DiffusionPipeline
+from diffusers import FluxPipeline
+from PIL import Image
+import torch
 
-# Load the diffusion pipeline model
+# Load the Flux model
 @st.cache_resource
 def load_pipeline():
-    # Using the 'black-forest-labs/FLUX.1-schnell' model for fast generation
-    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+    # Using the 'black-forest-labs/FLUX.1-schnell' model
+    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+    pipe.enable_model_cpu_offload()  # Offload to CPU to save memory
     return pipe
 
 pipe = load_pipeline()
 
 # Streamlit app
-st.title("Text-to-Image Generation App (FLUX Model)")
+st.title("Text-to-Image Generation with FLUX.1-schnell")
 
 # User input for prompt
-user_prompt = st.text_input("Enter your image prompt", value="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k")
+user_prompt = st.text_input("Enter your image prompt", value="A cat holding a sign that says hello world")
 
 # Button to generate the image
 if st.button("Generate Image"):
     if user_prompt:
         with st.spinner("Generating image..."):
             # Generate the image using the FLUX model
-            image = pipe(user_prompt).images[0]
+            image = pipe(
+                user_prompt,
+                guidance_scale=0.0,  # No guidance
+                num_inference_steps=4,  # Few steps for faster generation
+                max_sequence_length=256,
+                generator=torch.Generator("cpu").manual_seed(0)  # Ensure reproducibility
+            ).images[0]
 
-            # Display the generated image
-            st.image(image, caption="Generated Image (FLUX Model)", use_column_width=True)
+            # Save and display the image
+            image.save("generated_image.png")
+            st.image(image, caption="Generated Image (FLUX.1-schnell)", use_column_width=True)
     else:
         st.error("Please enter a valid prompt.")
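
For reference, the updated generation call can be exercised outside Streamlit with a minimal sketch like the one below. It mirrors the FluxPipeline settings used in this commit and assumes diffusers, torch, and accelerate are installed (enable_model_cpu_offload relies on accelerate); it is a sketch, not part of the commit itself.

import torch
from diffusers import FluxPipeline

# Load FLUX.1-schnell in bfloat16 and offload submodules to CPU to reduce GPU memory use
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# schnell is a guidance-distilled model, so guidance_scale=0.0 and a few steps suffice
image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),  # fixed seed for reproducibility
).images[0]
image.save("generated_image.png")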