tahirsher commited on
Commit
36f8896
1 Parent(s): 1946e07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -15
app.py CHANGED
@@ -3,35 +3,30 @@ from diffusers import DiffusionPipeline
3
  from PIL import Image
4
  import torch
5
 
6
- # Load the FLUX model
7
  @st.cache_resource
8
  def load_pipeline():
9
- # Using the 'black-forest-labs/FLUX.1-schnell' model
10
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
11
- pipe.enable_model_cpu_offload() # Offload to CPU to save memory
12
  return pipe
13
 
14
  pipe = load_pipeline()
15
 
16
  # Streamlit app
17
- st.title("Text-to-Image Generation with FLUX.1-schnell")
18
 
19
  # User input for prompt
20
- user_prompt = st.text_input("Enter your image prompt", value="A cat holding a sign that says hello world")
21
 
22
  # Button to generate the image
23
  if st.button("Generate Image"):
24
  if user_prompt:
25
  with st.spinner("Generating image..."):
26
- # Generate the image using the FLUX model
27
- image = pipe(
28
- user_prompt,
29
- guidance_scale=0.0, # No guidance
30
- num_inference_steps=4, # Number of steps for faster generation
31
- ).images[0]
32
 
33
- # Save and display the image
34
- image.save("generated_image.png")
35
- st.image(image, caption="Generated Image (FLUX.1-schnell)", use_column_width=True)
36
  else:
37
  st.error("Please enter a valid prompt.")
 
 
3
  from PIL import Image
4
  import torch
5
 
6
+ # Load the diffusion pipeline model
7
  @st.cache_resource
8
  def load_pipeline():
9
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
10
+ pipe.load_lora_weights("Melonie/text_to_image_finetuned")
 
11
  return pipe
12
 
13
  pipe = load_pipeline()
14
 
15
  # Streamlit app
16
+ st.title("Text-to-Image Generation App")
17
 
18
  # User input for prompt
19
+ user_prompt = st.text_input("Enter your image prompt", value="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k")
20
 
21
  # Button to generate the image
22
  if st.button("Generate Image"):
23
  if user_prompt:
24
  with st.spinner("Generating image..."):
25
+ # Generate the image
26
+ image = pipe(user_prompt).images[0]
 
 
 
 
27
 
28
+ # Display the generated image
29
+ st.image(image, caption="Generated Image", use_column_width=True)
 
30
  else:
31
  st.error("Please enter a valid prompt.")
32
+