"""Streamlit app that generates images from a rough sketch, a text prompt, or both.

The diffusion pipeline (``MGDPipe``) is built once via ``st.cache_resource``
and reused across Streamlit reruns.
"""

from io import BytesIO

import streamlit as st
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe


@st.cache_resource
def load_model():
    """Build the MGD pipeline and move it to the GPU when one is available.

    Returns:
        The assembled ``MGDPipe`` instance.

    Raises:
        OSError: propagated from ``from_pretrained`` when a checkpoint cannot
            be found or downloaded; the caller reports it to the user.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    print("VAE model loaded successfully.")

    # NOTE(review): repo ids / subfolders are kept from the original; confirm
    # they exist on the Hub and match the text encoder the MGD UNet expects.
    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
    scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")

    pipe = MGDPipe(
        text_encoder=text_encoder,
        vae=vae,
        # Cast the UNet to the VAE's dtype so both components share a precision.
        unet=unet.to(vae.dtype),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    return pipe


try:
    pipe = load_model()
except OSError as e:
    # Without a pipeline nothing below can run: report and halt this rerun
    # instead of continuing with undefined components (the original fell
    # through to a NameError).
    print(f"Error loading the model: {e}")
    st.error(f"Error loading the model: {e}")
    st.stop()


def _as_image_list(result):
    """Return the generated images held by a single pipeline call result."""
    # Diffusers-style pipelines return an output object exposing ``.images``;
    # fall back to the raw result if the pipeline returns a plain sequence.
    # NOTE(review): confirm MGDPipe's actual return type.
    return list(getattr(result, "images", result))


def generate_images(pipe, text_input=None, sketch=None):
    """Run the pipeline on a text prompt and/or a sketch.

    Each provided input triggers its own pipeline call; the resulting images
    are concatenated, text-conditioned ones first.

    Args:
        pipe: The pipeline returned by ``load_model``.
        text_input: Prompt string; skipped when ``None`` or empty.
        sketch: Binary file-like object holding image data; skipped when ``None``.

    Returns:
        A list of generated images (empty if no input was given).
    """
    images = []
    if text_input:
        images.extend(_as_image_list(pipe(prompt=[text_input])))
    if sketch is not None:
        sketch_image = Image.open(sketch).convert("RGB")
        # NOTE(review): MGDPipe may expect a preprocessed tensor for ``sketch``
        # rather than a PIL image; verify against the pipeline's signature.
        images.extend(_as_image_list(pipe(sketch=sketch_image)))
    return images


# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))

# Initialise both inputs so later code never hits a NameError when the
# selected option hides one of the widgets.
sketch_file = None
text_input = None

if option in ["Sketch", "Both"]:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
if option in ["Text", "Both"]:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    needs_sketch = option in ("Sketch", "Both")
    needs_text = option in ("Text", "Both")
    # Treat a whitespace-only prompt as missing.
    has_text = bool(text_input and text_input.strip())

    if needs_sketch and sketch_file is None:
        st.error("Please upload a sketch.")
    elif needs_text and not has_text:
        st.error("Please provide text input.")
    else:
        # Generate images based on user input
        with st.spinner("Generating images..."):
            # getvalue() returns the whole upload regardless of the current
            # read position, unlike read().
            sketches = BytesIO(sketch_file.getvalue()) if sketch_file is not None else None
            images = generate_images(
                pipe,
                text_input=text_input if has_text else None,
                sketch=sketches,
            )
        # Display results
        for i, img in enumerate(images, start=1):
            st.image(img, caption=f"Generated Image {i}")