import streamlit as st import torch from PIL import Image from transformers import CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler from src.mgd_pipelines.mgd_pipe import MGDPipe # Your MGDPipe implementation # Load models and pipeline def load_models(pretrained_model_path, device): """ Load the models required for the MGDPipe. Args: pretrained_model_path (str): Path or Hugging Face identifier for the model. device (torch.device): Device to load the models on. Returns: MGDPipe: Initialized MGDPipe object. """ # Load components of Stable Diffusion tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device) vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") scheduler.set_timesteps(50) # Handle torch.hub checkpoint loading for CPU-only environments map_location = torch.device("cpu") if device.type == "cpu" else None # Load the UNet model and force map_location for state_dict loading unet = torch.hub.load( repo_or_dir="aimagelab/multimodal-garment-designer", source="github", model="mgd", pretrained=True, dataset="dresscode", # Change to "vitonhd" if needed ) # Ensure the model state dict is mapped correctly to the CPU if needed if device.type == "cpu": checkpoint_url = unet.config.get("checkpoint") if checkpoint_url: state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") unet.load_state_dict(state_dict) # Move UNet to the appropriate device unet = unet.to(device) # Initialize the pipeline pipeline = MGDPipe( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) return pipeline # Function to preprocess and generate images def generate_image(pipeline, sketch, prompt, device): """ Generate an image using the MGDPipe. Args: pipeline (MGDPipe): Initialized MGDPipe object. sketch (PIL.Image.Image): Sketch uploaded by the user. prompt (str): Text prompt provided by the user. device (torch.device): Device for inference. Returns: PIL.Image.Image: Generated image. """ # Preprocess the sketch sketch = sketch.resize((512, 384)).convert("RGB") sketch_tensor = torch.tensor([torch.tensor(sketch, dtype=torch.float32).permute(2, 0, 1) / 255.0]).to(device) # Run the pipeline output = pipeline( prompt=prompt, image=torch.zeros_like(sketch_tensor), # Placeholder for masked image mask_image=torch.ones_like(sketch_tensor), # Placeholder for mask pose_map=torch.zeros((1, 3, 64, 48)).to(device), # Placeholder pose map sketch=sketch_tensor, guidance_scale=7.5, num_inference_steps=50, ) return output.images[0] # Streamlit Interface st.title("Garment Designer") st.write("Upload a sketch and provide a text description to generate garment designs!") # User inputs uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"]) text_prompt = st.text_input("Enter a text description for the garment") if st.button("Generate"): if uploaded_file and text_prompt: st.write("Loading models...") # Detect device (CPU or GPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") pretrained_model_path = "runwayml/stable-diffusion-inpainting" # Change as required # Load the pipeline pipeline = load_models(pretrained_model_path, device) # Load sketch sketch = Image.open(uploaded_file) # Generate the image st.write("Generating the garment design...") generated_image = generate_image(pipeline, sketch, text_prompt, device) # Display the result st.image(generated_image, caption="Generated Garment Design", use_column_width=True) else: st.error("Please upload a sketch and enter a text description.")