import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from src.mgd_pipelines.mgd_pipe import MGDPipe  # Your MGDPipe implementation
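
# The MGDPipe import above assumes the aimagelab/multimodal-garment-designer
# repository layout (src/mgd_pipelines/mgd_pipe.py) is importable from this app.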


# Load models and pipeline
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.

    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load the Stable Diffusion components
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Load the MGD UNet from the authors' torch.hub entry point
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    )

    # On CPU-only environments, re-map the checkpoint weights onto the CPU
    if device.type == "cpu":
        checkpoint_url = unet.config.get("checkpoint")
        if checkpoint_url:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
            unet.load_state_dict(state_dict)

    # Move the UNet to the target device
    unet = unet.to(device)

    # Assemble the MGD pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    return pipeline
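

# Optional: cache the heavy model loading across Streamlit reruns so the weights
# are not reloaded on every click. A minimal sketch, assuming Streamlit >= 1.18
# (st.cache_resource); the leading underscore tells Streamlit not to hash the
# torch.device argument.
#
# @st.cache_resource
# def cached_load_models(pretrained_model_path, _device):
#     return load_models(pretrained_model_path, _device)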


# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.

    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch: PIL's resize takes (width, height), and 384x512
    # matches the 48x64 latent-space pose map below. Scale to [0, 1], reorder to
    # channels-first, and add a batch dimension.
    sketch = sketch.convert("RGB").resize((384, 512))
    sketch_tensor = (
        torch.from_numpy(np.array(sketch, dtype=np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    # Run the pipeline (the masked image, mask, and pose map are placeholders)
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),      # Placeholder for masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
        pose_map=torch.zeros((1, 3, 64, 48)).to(device),  # Placeholder pose map
        sketch=sketch_tensor,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return output.images[0]
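
# Note: the masked image, inpainting mask, and pose map above are placeholders.
# In the original multimodal-garment-designer setup these come from the Dress Code
# or VITON-HD preprocessing; replace the zeros/ones tensors with the real model
# image, garment mask, and keypoint pose map for meaningful results.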


# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Detect device (CPU or GPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required

        # Load the pipeline
        pipeline = load_models(pretrained_model_path, device)

        # Load the sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")
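
# To run this app locally (a sketch, assuming the file is saved as app.py and the
# repository's src/ package is importable):
#   pip install streamlit torch diffusers transformers pillow numpy
#   streamlit run app.py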