import streamlit as st
import torch
from diffusers import DiffusionPipeline
from PIL import Image
@st.cache_resource
def load_pipeline():
    """Build the FLUX.1-schnell diffusion pipeline used by this app.

    The pipeline is created with ``DiffusionPipeline.from_pretrained`` using
    ``torch.bfloat16`` as the requested dtype, and ``enable_model_cpu_offload``
    is called on it before it is handed back. The function is wrapped in
    ``st.cache_resource``, so the returned object is meant to be reused
    rather than rebuilt on each call.

    Returns:
        The configured pipeline object.
    """
    model_id = "black-forest-labs/FLUX.1-schnell"
    flux = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    flux.enable_model_cpu_offload()
    return flux
# Obtain the pipeline once at module level; load_pipeline() is the single
# place that knows how it is built and cached.
pipe = load_pipeline()

st.title("Text-to-Image Generation with FLUX.1-schnell")

user_prompt = st.text_input(
    "Enter your image prompt",
    value="A cat holding a sign that says hello world",
)

if st.button("Generate Image"):
    # Strip before validating so a prompt made only of spaces/tabs is rejected
    # like an empty one instead of being sent to the model.
    prompt = user_prompt.strip()
    if prompt:
        # The spinner wraps only the pipeline call, which is the slow step.
        with st.spinner("Generating image..."):
            image = pipe(
                prompt,
                guidance_scale=0.0,
                num_inference_steps=4,
            ).images[0]

        # Keep a copy on disk (fixed filename, overwritten on every run) and
        # show the result in the page.
        image.save("generated_image.png")
        st.image(
            image,
            caption="Generated Image (FLUX.1-schnell)",
            use_column_width=True,
        )
    else:
        st.error("Please enter a valid prompt.")