import streamlit as st
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from src.mgd_pipelines.mgd_pipe import MGDPipe  # Use your implementation of MGDPipe

# Load models and pipeline
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.

    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load components of Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Load the UNet model
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    ).to(device)
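    # torch.hub.load pulls the pretrained MGD denoising UNet from the official
    # repository; the checkpoint is downloaded on first use and cached by torch.hub.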

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    return pipeline
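
# NOTE: loading these weights is slow; when running under Streamlit, wrapping the
# loader with st.cache_resource avoids re-initialising the models on every
# "Generate" click, e.g. load_models_cached = st.cache_resource(load_models).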

# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.

    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch: PIL's resize takes (width, height), so (384, 512)
    # matches the 512x384 (H x W) resolution implied by the pose-map shape below.
    sketch = sketch.resize((384, 512)).convert("RGB")
    # Convert to a (1, 3, H, W) float tensor in [0, 1]
    sketch_tensor = (torch.from_numpy(np.array(sketch, dtype=np.float32)).permute(2, 0, 1) / 255.0).unsqueeze(0).to(device)

    # Run the pipeline
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),  # Placeholder for masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
        pose_map=torch.zeros((1, 3, 64, 48)).to(device),  # Placeholder pose map
        sketch=sketch_tensor,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return output.images[0]
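
# NOTE: the zero/one tensors passed above are placeholders so the demo runs end to
# end. In the original MGD setup the inpainting mask comes from human parsing of
# the model image and the pose map from keypoint detection; real inputs for both
# are needed to obtain faithful garment designs.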

# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Load the pipeline
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required
        pipeline = load_models(pretrained_model_path, device)

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")