# fyp-deploy / app.py
import numpy as np
import streamlit as st
import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe  # Local implementation of MGDPipe

# Load models and pipeline
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.

    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load the Stable Diffusion components
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Load the MGD UNet from the authors' GitHub repository
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    ).to(device)
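    # Note: torch.hub.load with source="github" clones the repo into the hub
    # cache (default ~/.cache/torch/hub) on first use; the pretrained and
    # dataset kwargs are forwarded to the "mgd" entrypoint in the repo's hubconf.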

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    return pipeline
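
# Streamlit re-executes this script on every interaction, so calling
# load_models() inside the button handler below reloads every model on each
# click. A minimal sketch (assuming Streamlit >= 1.18 for st.cache_resource)
# that caches the pipeline across reruns; the handler could call
# get_pipeline(...) instead of load_models(...):
@st.cache_resource
def get_pipeline(pretrained_model_path: str, device_str: str):
    # A plain string keeps the cache key hashable; the torch.device is built here.
    return load_models(pretrained_model_path, torch.device(device_str))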

# Function to preprocess inputs and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.

    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch. PIL's resize takes (width, height); (384, 512)
    # produces a 512x384 (HxW) image, matching the 64x48 latent pose map below.
    sketch = sketch.resize((384, 512)).convert("RGB")
    # torch.tensor() cannot consume a PIL image directly; convert via NumPy,
    # scale to [0, 1], reorder to CHW, and add a batch dimension.
    sketch_array = np.asarray(sketch, dtype=np.float32) / 255.0
    sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)
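    # Note: this scales pixels to [0, 1]. Many diffusion components expect
    # inputs in [-1, 1]; depending on how MGDPipe normalizes the sketch
    # internally, a rescale such as sketch_tensor * 2.0 - 1.0 may be needed.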

    # Run the pipeline
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),  # Placeholder for the masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for the mask
        pose_map=torch.zeros((1, 3, 64, 48), device=device),  # Placeholder pose map
        sketch=sketch_tensor,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return output.images[0]
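
# If MGDPipe does not disable autograd internally (standard diffusers pipelines
# do so via a decorator), wrapping the call above in torch.inference_mode()
# avoids holding activation buffers during generation, e.g.:
#
#     with torch.inference_mode():
#         output = pipeline(prompt=prompt, sketch=sketch_tensor, ...)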

# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Load the pipeline
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required
        pipeline = load_models(pretrained_model_path, device)

        # Load the sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")