# fyp-deploy / app.py
import os

import numpy as np
import torch
import streamlit as st
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled

# Function to load models
def load_models(pretrained_model_name_or_path, device):
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    scheduler.set_timesteps(50, device=device)

    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        pretrained=True,
        source="github",
    )
    return tokenizer, text_encoder, vae, scheduler, unet
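
# Optional sketch (not part of the original app): Streamlit re-runs the whole script
# on every interaction, so wrapping the loader in st.cache_resource would avoid
# re-downloading and re-initialising the models on each button click. The wrapper
# name and the string-device argument below are assumptions for illustration only.
@st.cache_resource
def load_models_cached(pretrained_model_name_or_path: str, device_str: str):
    # Cache keyed on the string arguments; convert the device string back to torch.device here.
    return load_models(pretrained_model_name_or_path, torch.device(device_str))
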
# Function to generate images
def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
    # Preprocess the sketch: resize, force RGB, scale to [0, 1], and add a batch dimension (1xCxHxW)
    sketch = sketch.resize((512, 384)).convert("RGB")
    sketch_array = np.asarray(sketch, dtype=np.float32) / 255.0
    sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)

    # Tokenize prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Initialize pipeline
    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)

    # Generate image
    pipeline.enable_attention_slicing()
    with torch.inference_mode():
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)
    return outputs[0]
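
# Note (assumption, not confirmed from the MGD repo): depending on how MGDPipe is
# configured, the call above may return PIL images or image tensors. If it returns
# a CxHxW tensor in [0, 1], a small helper like this hypothetical one would be
# needed before passing the result to st.image.
def tensor_to_pil(image_tensor):
    # Clamp to [0, 1], move channels last, and convert to an 8-bit PIL image.
    array = (image_tensor.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(array)
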
# Streamlit UI
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")
# User Inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

# Generate button
if st.button("Generate"):
    if uploaded_file and text_prompt:
        # Load models
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "your-pretrained-model-path"  # Replace with actual model path
        tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate image
        st.write("Generating the garment design...")
        output_image = generate_image(sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device)

        # Display output
        st.image(output_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")
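
# To try this locally (assuming the multimodal-garment-designer sources are available
# under src/ and a real diffusion checkpoint path is configured above), the app is
# started with the standard Streamlit CLI:
#   streamlit run app.py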