# fyp-deploy / app.py
# NOTE(review): the lines below are Hugging Face web-page artifacts that were
# scraped together with the file ("Mairaaa's picture / Update app.py /
# d2f7af8 verified / raw / history blame / 4.05 kB"). They are kept here as a
# comment so the module remains valid Python.
import os

# BUG FIX: was "import pandas as np" — the code calls np.array(...) on a PIL
# image, which is numpy's API, not pandas'.
import numpy as np
import streamlit as st
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from src.datasets.dresscode import DressCodeDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Set environment variables (tokenizer threading + W&B process start method).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"
# Function to process inputs and run inference
# Function to process inputs and run inference
def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
    """Generate a fashion image from a text prompt and/or a rough sketch.

    Args:
        prompt: Text description of the desired garment.
        sketch_image: Optional PIL image of a rough sketch.
        category: DressCode garment category (e.g. "dresses").
        seed: Optional integer seed for reproducible generation.
        mixed_precision: "fp16", "bf16", or "fp32"/"no" for full precision.

    Returns:
        The first image produced by the MGD pipeline.
    """
    # BUG FIX: Accelerator only accepts "no"/"fp16"/"bf16"; the UI offers
    # "fp32", which would raise ValueError — map it to "no".
    if mixed_precision == "fp32":
        mixed_precision = "no"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device

    # Load tokenizer/text encoder (CLIP), VAE, and DDIM scheduler.
    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
    val_scheduler = DDIMScheduler.from_pretrained("ptx0/pseudo-journey-v2", subfolder="scheduler")

    # Load UNet (assumed pretrained) from torch hub.
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)

    # Inference only: freeze VAE and text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set seed for reproducibility
    if seed is not None:
        set_seed(seed)

    # Dataset expects a list of categories.
    test_dataset = DressCodeDataset(
        dataroot_path="path_to_dataset", phase="test", category=[category], size=(512, 384)
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Move models to the device; UNet in eval mode.
    text_encoder.to(device)
    vae.to(device)
    unet.to(device).eval()

    # BUG FIX: default to None so the generate call below never references an
    # unassigned variable when no sketch was provided.
    sketch_tensor = None
    if sketch_image is not None:
        # NOTE(review): PIL.Image.resize takes (width, height); confirm this
        # matches the dataset's size=(512, 384) convention.
        sketch_image = sketch_image.resize((512, 384))
        # NOTE(review): tensor is (1, H, W, C) uint8->float; verify the
        # pipeline's expected layout/normalization.
        sketch_tensor = torch.tensor(np.array(sketch_image)).unsqueeze(0).float().to(device)

    # Select pipeline (disentangled variant).
    val_pipe = MGDPipeDisentangled(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=val_scheduler,
    ).to(device)
    val_pipe.enable_attention_slicing()

    # Generate image(s).
    generated_images = generate_images_from_mgd_pipe(
        test_dataloader=test_dataloader,
        pipe=val_pipe,
        guidance_scale=7.5,
        seed=seed,
        sketch_image=sketch_tensor,
        prompt=prompt,
    )
    return generated_images[0]  # Assuming single image output
# Streamlit UI
# ---- Streamlit UI ----
st.title("Fashion Image Generator")
st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")

# Upload a sketch image (optional).
uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])

# Text prompt and generation settings.
prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
category = st.text_input("Enter category (optional):", "dresses")
# NOTE(review): value=None on st.slider falls back to min_value on older
# Streamlit versions — confirm the intended default seed.
seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=None)
precision = st.selectbox("Select precision:", ["fp16", "fp32"])

# BUG FIX: initialize to None so run_inference does not hit a NameError when
# the Generate button is pressed without an uploaded sketch.
sketch_image = None
if uploaded_sketch is not None:
    sketch_image = Image.open(uploaded_sketch)
    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)

# Button to generate image.
if st.button("Generate Image"):
    with st.spinner("Generating image..."):
        # Run inference with sketch or prompt (or both).
        result_image = run_inference(prompt, sketch_image, category, seed, precision)
        st.image(result_image, caption="Generated Image", use_column_width=True)