Spaces:
Sleeping
Sleeping
File size: 3,061 Bytes
8689a3c 4198ed7 831b686 4198ed7 8689a3c 4198ed7 8689a3c 4198ed7 8689a3c 4198ed7 8689a3c 4198ed7 8689a3c 4198ed7 8689a3c 4198ed7 40cbb76 4198ed7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
import torch
import streamlit as st
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
# Function to load models
def load_models(pretrained_model_name_or_path, device):
    """Load every model component used by the garment-design pipeline.

    The tokenizer, text encoder, VAE and DDIM scheduler are each read from
    their own subfolder of ``pretrained_model_name_or_path``. The UNet is
    instead obtained through ``torch.hub`` from the
    ``aimagelab/multimodal-garment-designer`` GitHub repository.

    Args:
        pretrained_model_name_or_path: Hub id or local directory containing
            the ``tokenizer``, ``text_encoder``, ``vae`` and ``scheduler``
            subfolders.
        device: Device handed to ``scheduler.set_timesteps``. The models are
            NOT moved to this device here; the caller does that.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    def _from_subfolder(component_cls, subfolder):
        # Every locally loaded component shares the same root path.
        return component_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)

    tokenizer = _from_subfolder(CLIPTokenizer, "tokenizer")
    text_encoder = _from_subfolder(CLIPTextModel, "text_encoder")
    vae = _from_subfolder(AutoencoderKL, "vae")

    scheduler = _from_subfolder(DDIMScheduler, "scheduler")
    # Configure 50 denoising steps, with the timesteps placed on `device`.
    scheduler.set_timesteps(50, device=device)

    # Unlike the other components, the UNet comes from the project's
    # torch.hub entry point (with pretrained weights) rather than from the
    # path above.
    hub_kwargs = {
        "repo_or_dir": "aimagelab/multimodal-garment-designer",
        "model": "mgd",
        "pretrained": True,
        "source": "github",
    }
    unet = torch.hub.load(**hub_kwargs)

    return tokenizer, text_encoder, vae, scheduler, unet
# Function to generate images
def generate_image(
    sketch,
    prompt,
    tokenizer,
    text_encoder,
    vae,
    scheduler,
    unet,
    device,
    size=(512, 384),
):
    """Generate a garment design from a sketch image and a text prompt.

    Args:
        sketch: PIL image of the user's sketch (any mode; it is converted to
            RGB before use).
        prompt: Text description of the garment.
        tokenizer, text_encoder, vae, scheduler, unet: Model components, as
            returned by ``load_models``.
        device: Torch device on which preprocessing output and all model
            components are placed.
        size: Target size passed to ``PIL.Image.resize``, i.e.
            ``(width, height)``. NOTE(review): the default therefore yields a
            512-wide, 384-tall image; confirm this matches the resolution the
            MGD UNet was trained on.

    Returns:
        ``outputs[0]`` from the pipeline call. NOTE(review): presumably the
        generated image; confirm against ``MGDPipe.__call__``.
    """
    import numpy as np

    # Convert first so palette ("P") or greyscale uploads are resampled in
    # RGB, then resize.
    rgb_sketch = sketch.convert("RGB").resize(size)
    # PIL image -> HWC float32 array in [0, 1] -> CHW -> NCHW batch of one.
    # torch.tensor() accepts neither a PIL image nor a list of multi-element
    # tensors, so go through NumPy explicitly.
    pixels = np.asarray(rgb_sketch, dtype=np.float32) / 255.0
    sketch_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pad/truncate to the tokenizer's maximum length so long descriptions
    # cannot exceed the text encoder's context window.
    inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Assemble the pipeline with every component on the target device.
    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    # Presumably the diffusers attention-slicing switch (lower peak memory at
    # a small speed cost).
    pipeline.enable_attention_slicing()

    # NOTE(review): the keyword names `images` / `text` and the use of
    # `outputs[0]` are kept from the original; verify them against
    # MGDPipe.__call__ in src/mgd_pipelines/mgd_pipe.py.
    with torch.inference_mode():
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)
    return outputs[0]
# Streamlit UI
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

# Generate button: the body below runs only on the rerun triggered by a click.
if st.button("Generate"):
    # Reject a missing upload or an empty / whitespace-only description.
    if uploaded_file and text_prompt and text_prompt.strip():
        # Open the upload before the (slow) model loading, so an unreadable
        # file is reported immediately. PIL signals an unrecognised image
        # with UnidentifiedImageError, a subclass of OSError.
        try:
            sketch = Image.open(uploaded_file)
        except OSError:
            st.error("The uploaded file could not be read as an image.")
            st.stop()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Model location can be supplied through an environment variable; the
        # placeholder default must be replaced (or the variable set) before
        # real weights can be loaded.
        pretrained_model_path = os.environ.get(
            "MGD_PRETRAINED_MODEL_PATH", "your-pretrained-model-path"
        )

        # TODO(review): every click reloads all models; consider caching them
        # (e.g. st.cache_resource on Streamlit >= 1.18).
        with st.spinner("Loading models..."):
            tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)

        with st.spinner("Generating the garment design..."):
            output_image = generate_image(
                sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device
            )

        # Display output
        st.image(output_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")
|