import io
import os

import numpy as np
import torch
import streamlit as st
from PIL import Image
from accelerate import Accelerator
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.set_seeds import set_seed
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.datasets.dresscode import DressCodeDataset

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"
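
# Accelerate's `mixed_precision` option accepts "no", "fp16", or "bf16"; the UI
# below exposes a friendlier "fp32" label, so map it to "no" (full precision)
# before handing it to the Accelerator.
PRECISION_MAP = {"fp32": "no", "fp16": "fp16", "bf16": "bf16"}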
# Process the inputs and run inference.
def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16", disentangle=True):
    # Initialize the accelerator (handles device placement and mixed precision).
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device

    # Load the frozen Stable Diffusion components. MGD builds on a Stable
    # Diffusion inpainting checkpoint; the fine-tuned sd-vae-ft-mse autoencoder
    # is a drop-in replacement for the stock SD 1.x VAE.
    base_model = "runwayml/stable-diffusion-inpainting"
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    val_scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")

    # Load the pretrained MGD denoising UNet from torch hub.
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)

    # Freeze the VAE and text encoder; only the UNet drives generation.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set the seed for reproducibility.
    if seed is not None:
        set_seed(seed)

    # Load the requested DressCode category for the test split.
    category = [category]
    test_dataset = DressCodeDataset(
        dataroot_path="path_to_dataset", phase="test", category=category, size=(512, 384)
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
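
    # Fail fast if the placeholder dataroot_path above has not been pointed at
    # a real DressCode checkout (assumes DressCodeDataset implements __len__,
    # as torch datasets normally do).
    if len(test_dataset) == 0:
        raise ValueError("No test samples found; set dataroot_path to your DressCode directory.")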
    # Move the models to the accelerator device; inference only, so eval mode.
    text_encoder.to(device)
    vae.to(device)
    unet.to(device).eval()

    # Preprocess the sketch: PIL's resize takes (width, height), so the target
    # is (384, 512) to match the dataset's 512x384 (height x width) resolution.
    # The tensor is scaled to [0, 1] and laid out as NCHW.
    sketch_tensor = None
    if sketch_image is not None:
        sketch_image = sketch_image.convert("RGB").resize((384, 512))
        sketch_array = np.array(sketch_image).astype(np.float32) / 255.0
        sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)

    # Select the pipeline: MGDPipeDisentangled implements the paper's
    # disentangled classifier-free guidance, MGDPipe the standard variant.
    pipe_class = MGDPipeDisentangled if disentangle else MGDPipe
    val_pipe = pipe_class(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=val_scheduler,
    ).to(device)
    val_pipe.enable_attention_slicing()

    # Generate the images.
    generated_images = generate_images_from_mgd_pipe(
        test_dataloader=test_dataloader,
        pipe=val_pipe,
        guidance_scale=7.5,
        seed=seed,
        sketch_image=sketch_tensor,
        prompt=prompt,
    )
    return generated_images[0]  # batch_size=1, so a single image is produced
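
# The return type of generate_images_from_mgd_pipe is not pinned down here:
# diffusers-style helpers usually yield PIL images, but some return CHW float
# tensors. This hedged adapter normalizes either case into something st.image
# can display.
def to_display_image(img):
    if isinstance(img, torch.Tensor):
        array = img.detach().cpu().float().clamp(0, 1).numpy()
        if array.ndim == 3 and array.shape[0] in (1, 3):  # CHW -> HWC
            array = array.transpose(1, 2, 0)
        return (array * 255).round().astype("uint8")
    return img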
# Streamlit UI
st.title("Fashion Image Generator")
st.write("Generate colorful fashion images from a rough sketch, a text prompt, or both.")

# Upload a rough sketch
uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])

# Text prompt
prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")

# Generation options (DressCode defines the dresses, upper_body, and lower_body categories)
category = st.selectbox("Category:", ["dresses", "upper_body", "lower_body"])
seed = st.slider("Seed", min_value=0, max_value=10000, step=1, value=42)
precision = st.selectbox("Precision:", ["fp16", "fp32"])
# Show the uploaded sketch, if any (sketch_image stays None otherwise, so the
# generate button below works without an upload)
sketch_image = None
if uploaded_sketch is not None:
    sketch_image = Image.open(uploaded_sketch)
    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)
if st.button("Generate Image"): | |
with st.spinner("Generating image..."): | |
# Run inference with sketch or prompt (or both) | |
result_image = run_inference(prompt, sketch_image, category, seed, precision) | |
st.image(result_image, caption="Generated Image", use_column_width=True) | |
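    # Optional: offer the result as a PNG download. st.download_button is part
    # of the standard Streamlit API; the PIL round-trip assumes display_image
    # is a PIL image or a uint8 HWC array.
    buf = io.BytesIO()
    Image.fromarray(np.asarray(display_image)).save(buf, format="PNG")
    st.download_button("Download image", data=buf.getvalue(), file_name="generated_fashion.png", mime="image/png")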