# fyp-deploy/app.py
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from src.mgd_pipelines.mgd_pipe import MGDPipe  # Project-local MGDPipe implementation

# Load models and pipeline
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.

    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load the Stable Diffusion components
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Map checkpoint tensors to the CPU when no GPU is available
    map_location = torch.device("cpu") if device.type == "cpu" else None

    # Load the MGD UNet from torch.hub
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    )

    # On CPU-only environments, re-load the checkpoint state dict with the CPU
    # map_location so CUDA-saved tensors can be deserialized
    if device.type == "cpu":
        checkpoint_url = unet.config.get("checkpoint")
        if checkpoint_url:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=map_location)
            unet.load_state_dict(state_dict)

    # Move the UNet to the target device
    unet = unet.to(device)

    # Assemble the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    return pipeline
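

# Optional (sketch): cache the loaded pipeline across Streamlit reruns so the
# models are not re-downloaded and re-initialised on every button press.
# Assumes a Streamlit version that provides st.cache_resource; the helper name
# below is illustrative and not wired into the handler further down. The device
# is passed as a string so the cache key stays hashable.
@st.cache_resource
def get_cached_pipeline(pretrained_model_path: str, device_str: str):
    return load_models(pretrained_model_path, torch.device(device_str))
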
# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.

    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch. PIL's resize() takes (width, height), so this gives
    # a 512x384 (height x width) image, matching the 64x48 latent/pose-map size.
    sketch = sketch.resize((384, 512)).convert("RGB")
    # Convert the PIL image to a (1, 3, H, W) float tensor scaled to [0, 1]
    sketch_tensor = (
        torch.from_numpy(np.array(sketch, dtype=np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
    ).to(device)

    # Run the pipeline
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),  # Placeholder for masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
        pose_map=torch.zeros((1, 3, 64, 48)).to(device),  # Placeholder pose map
        sketch=sketch_tensor,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return output.images[0]
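

# Example usage outside Streamlit (illustrative sketch only; assumes a local
# "sketch.png" file and is not executed by the app):
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   pipe = load_models("runwayml/stable-diffusion-inpainting", device)
#   result = generate_image(pipe, Image.open("sketch.png"), "a red cotton dress", device)
#   result.save("result.png")
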
# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")
# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")
if st.button("Generate"):
if uploaded_file and text_prompt:
st.write("Loading models...")
# Detect device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model_path = "runwayml/stable-diffusion-inpainting" # Change as required
# Load the pipeline
pipeline = load_models(pretrained_model_path, device)
# Load sketch
sketch = Image.open(uploaded_file)
# Generate the image
st.write("Generating the garment design...")
generated_image = generate_image(pipeline, sketch, text_prompt, device)
# Display the result
st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
else:
st.error("Please upload a sketch and enter a text description.")