# fyp-deploy / app.py: Streamlit demo for the Multimodal Garment Designer (MGD) pipeline
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from src.mgd_pipelines.mgd_pipe import MGDPipe  # Project-local MGDPipe implementation
# Load models and pipeline
def load_models(pretrained_model_path, device):
"""
Load the models required for the MGDPipe.
Args:
pretrained_model_path (str): Path or Hugging Face identifier for the model.
device (torch.device): Device to load the models on.
Returns:
MGDPipe: Initialized MGDPipe object.
"""
# Load components of Stable Diffusion
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
scheduler.set_timesteps(50)
    # Load the MGD UNet from torch.hub (weights published with the paper's repo)
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    )

    # On CPU-only environments, re-map the checkpoint explicitly in case the
    # hub entry point loaded CUDA-mapped weights. This assumes the hub model
    # exposes its checkpoint URL in its config; it is skipped otherwise.
    if device.type == "cpu":
        checkpoint_url = unet.config.get("checkpoint")
        if checkpoint_url:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
            unet.load_state_dict(state_dict)

    # Move the UNet to the target device
    unet = unet.to(device)
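
    # Optional memory saving (assumption, not part of the original app): on a
    # CUDA device, half precision roughly halves model memory. Pipeline inputs
    # would also need casting to torch.float16, so this stays off by default.
    use_fp16 = False
    if use_fp16 and device.type == "cuda":
        unet = unet.half()
        vae = vae.half()
        text_encoder = text_encoder.half()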
# Initialize the pipeline
pipeline = MGDPipe(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)
return pipeline
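
# Optional (assumption, not part of the original app): cache the loaded
# pipeline across Streamlit reruns so the models are not re-downloaded and
# re-initialized on every button click. st.cache_resource is Streamlit's
# standard cache for heavyweight objects such as models. Swap this in for the
# direct load_models call in the UI below to enable caching.
@st.cache_resource
def get_pipeline(pretrained_model_path: str, device_str: str) -> MGDPipe:
    return load_models(pretrained_model_path, torch.device(device_str))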
# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
"""
Generate an image using the MGDPipe.
Args:
pipeline (MGDPipe): Initialized MGDPipe object.
sketch (PIL.Image.Image): Sketch uploaded by the user.
prompt (str): Text prompt provided by the user.
device (torch.device): Device for inference.
Returns:
PIL.Image.Image: Generated image.
"""
    # Preprocess the sketch. PIL's resize takes (width, height), so (384, 512)
    # yields the 512x384 (H x W) resolution implied by the 64x48 pose map below.
    sketch = sketch.resize((384, 512)).convert("RGB")
    sketch_tensor = (
        torch.from_numpy(np.array(sketch, dtype=np.float32))  # (H, W, C)
        .permute(2, 0, 1)  # (C, H, W)
        .unsqueeze(0)      # (1, C, H, W)
        .div(255.0)
        .to(device)
    )
# Run the pipeline
output = pipeline(
prompt=prompt,
image=torch.zeros_like(sketch_tensor), # Placeholder for masked image
mask_image=torch.ones_like(sketch_tensor), # Placeholder for mask
pose_map=torch.zeros((1, 3, 64, 48)).to(device), # Placeholder pose map
sketch=sketch_tensor,
guidance_scale=7.5,
num_inference_steps=50,
)
return output.images[0]
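
# Hypothetical helper (assumption, not part of the original app): if a real
# garment mask were available as a PIL image, it could be converted to the
# tensor layout used above and passed in place of the ones/zeros placeholders.
def mask_to_tensor(mask_img: Image.Image, device: torch.device) -> torch.Tensor:
    mask_img = mask_img.resize((384, 512)).convert("L")  # PIL expects (width, height)
    mask = torch.from_numpy(np.array(mask_img, dtype=np.float32)) / 255.0
    return mask.unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W)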
# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")
# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")
if st.button("Generate"):
if uploaded_file and text_prompt:
st.write("Loading models...")
# Detect device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model_path = "runwayml/stable-diffusion-inpainting" # Change as required
# Load the pipeline
pipeline = load_models(pretrained_model_path, device)
# Load sketch
sketch = Image.open(uploaded_file)
# Generate the image
st.write("Generating the garment design...")
generated_image = generate_image(pipeline, sketch, text_prompt, device)
# Display the result
        st.image(generated_image, caption="Generated Garment Design", use_container_width=True)
else:
st.error("Please upload a sketch and enter a text description.")
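
# To run locally (assuming Streamlit and the model dependencies are installed):
#     streamlit run app.py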