# Spaces:
# Sleeping
# Sleeping
import os

import numpy as np
import streamlit as st
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
def load_models(pretrained_model_name_or_path, device, *, num_inference_steps: int = 50):
    """Load the tokenizer, text encoder, VAE, DDIM scheduler and MGD UNet.

    Args:
        pretrained_model_name_or_path: Model id or path passed to the
            ``from_pretrained`` calls; it must provide ``tokenizer``,
            ``text_encoder``, ``vae`` and ``scheduler`` subfolders.
        device: ``torch.device`` the networks are moved to and on which the
            scheduler timesteps are created.
        num_inference_steps: Number of DDIM timesteps configured on the
            scheduler (defaults to the previously hard-coded 50).

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps, device=device)

    # The MGD UNet is fetched from the GitHub repo via torch.hub, so the first
    # call needs network access.
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        pretrained=True,
        source="github",
    )

    # Move the networks to the target device once, at load time, and switch
    # them to inference mode. NOTE(review): assumes the torch.hub entry point
    # returns an nn.Module (it is later passed to MGDPipe as `unet`) — confirm.
    for module in (text_encoder, vae, unet):
        module.to(device)
        module.eval()

    return tokenizer, text_encoder, vae, scheduler, unet
def generate_image(
    sketch,
    prompt,
    tokenizer,
    text_encoder,
    vae,
    scheduler,
    unet,
    device,
    *,
    height: int = 512,
    width: int = 384,
    guidance_scale: float = 7.5,
):
    """Generate a garment design from a sketch and a text prompt with MGDPipe.

    Args:
        sketch: PIL image of the sketch (any mode; converted to RGB).
        prompt: Text description of the garment.
        tokenizer, text_encoder, vae, scheduler, unet: Components as returned
            by ``load_models``.
        device: ``torch.device`` on which inference runs.
        height, width: Size in pixels the sketch is resized to.
            NOTE(review): MGD is believed to operate at 512 (H) x 384 (W) —
            confirm against the checkpoint you load.
        guidance_scale: Classifier-free guidance weight forwarded to the
            pipeline (defaults to the previously hard-coded 7.5).

    Returns:
        ``outputs[0]`` of the pipeline call. Presumably the generated
        image(s); the exact type depends on MGDPipe, which is not shown here.
    """
    # PIL's resize takes (width, height): the old ``resize((512, 384))``
    # produced a 512-wide landscape image. Converting to RGB first means
    # palette/greyscale uploads are resampled as colour data.
    sketch = sketch.convert("RGB").resize((width, height))

    # (H, W, 3) uint8 pixels -> (1, 3, H, W) float tensor in [0, 1].
    # torch.tensor() cannot stack a list of multi-element tensors, so the batch
    # dimension is added with unsqueeze instead.
    pixels = torch.from_numpy(np.asarray(sketch, dtype=np.float32))
    sketch_tensor = (pixels.permute(2, 0, 1) / 255.0).unsqueeze(0).to(device)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    # Trade some speed for lower peak memory during attention.
    pipeline.enable_attention_slicing()

    with torch.inference_mode():
        # NOTE(review): the keyword names `images`/`text` must match
        # MGDPipe.__call__, which is defined outside this file — verify.
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=guidance_scale)
    return outputs[0]
# Streamlit UI
# The model location may be overridden with the MGD_PRETRAINED_MODEL
# environment variable; the placeholder default must be replaced for the app
# to load anything.
PRETRAINED_MODEL_PATH = os.environ.get("MGD_PRETRAINED_MODEL", "your-pretrained-model-path")


@st.cache_resource
def _load_models_cached(model_path, device_type):
    """Load the models once per (path, device type) and reuse them.

    Streamlit re-executes this script on every widget interaction, so without
    caching each click on "Generate" would reload every model from disk or
    the network. The cached objects are shared across sessions.
    """
    return load_models(model_path, torch.device(device_type))


st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    # Reject whitespace-only prompts as well as empty ones.
    if uploaded_file and text_prompt.strip():
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device_type)
        tokenizer, text_encoder, vae, scheduler, unet = _load_models_cached(PRETRAINED_MODEL_PATH, device_type)

        sketch = Image.open(uploaded_file)
        with st.spinner("Generating the garment design..."):
            output_image = generate_image(sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device)

        st.image(output_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")