File size: 4,051 Bytes
af894bb
831b686
6ae61a6
 
831b686
 
 
6ae61a6
 
 
 
831b686
d2f7af8
6ae61a6
831b686
6ae61a6
 
 
831b686
 
 
 
6ae61a6
 
831b686
 
 
 
 
6ae61a6
831b686
 
6ae61a6
831b686
 
 
6ae61a6
831b686
 
 
458eaee
831b686
 
 
 
6ae61a6
458eaee
831b686
 
 
 
 
 
458eaee
831b686
 
 
 
 
 
 
 
6ae61a6
 
831b686
6ae61a6
 
 
458eaee
831b686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os

# BUG FIX: was "import pandas as np" — the code below calls np.array(...) on a PIL
# image, which requires NumPy, not pandas, under the "np" alias.
import numpy as np
import streamlit as st
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from src.datasets.dresscode import DressCodeDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Set environment variables.
# TOKENIZERS_PARALLELISM silences HF tokenizer fork warnings; WANDB_START_METHOD
# makes Weights & Biases start in-thread (presumably for compatibility with
# Streamlit's script-rerun model — TODO confirm, no wandb call is visible here).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"

# Function to process inputs and run inference
def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
    """Generate a fashion image from a text prompt and/or a rough sketch.

    Loads the CLIP text encoder/tokenizer, a VAE, a DDIM scheduler and the MGD
    UNet, builds a DressCode test dataloader for *category*, and runs the
    disentangled MGD pipeline once, returning the first generated image.

    Args:
        prompt: Text description of the desired garment.
        sketch_image: Optional PIL.Image of a rough sketch; resized to 512x384
            and converted to a float tensor when given.
        category: Dataset category name (e.g. "dresses").
        seed: Optional int for reproducible generation.
        mixed_precision: "fp16" (default), "fp32", or any value Accelerator
            accepts ("no", "fp16", "bf16"). "fp32" is mapped to "no".

    Returns:
        The first image produced by generate_images_from_mgd_pipe.
    """
    # BUG FIX: Accelerator only accepts "no"/"fp16"/"bf16" for mixed_precision;
    # the UI offers "fp32", which previously raised ValueError. Map it to "no"
    # (i.e. full precision).
    if mixed_precision == "fp32":
        mixed_precision = "no"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device

    # Load models and datasets
    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
    val_scheduler = DDIMScheduler.from_pretrained("ptx0/pseudo-journey-v2", subfolder="scheduler")

    # Load the MGD UNet via torch.hub (assumed pretrained).
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)

    # Inference only: freeze VAE and text encoder so no gradients are tracked.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set seed for reproducibility
    if seed is not None:
        set_seed(seed)

    # DressCodeDataset expects a list of categories.
    category = [category]
    # NOTE(review): "path_to_dataset" is a placeholder — replace with the real
    # DressCode root before deploying.
    test_dataset = DressCodeDataset(
        dataroot_path="path_to_dataset", phase="test", category=category, size=(512, 384)
    )

    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Move models to the accelerator device; UNet in eval mode for inference.
    text_encoder.to(device)
    vae.to(device)
    unet.to(device).eval()

    # Handle the optional sketch input.
    sketch_tensor = None
    if sketch_image is not None:
        # Resize to the pipeline's working resolution and convert to a float
        # tensor with a leading batch dim.
        # NOTE(review): this yields a (1, H, W, C) uint8-range tensor — no
        # channel permutation or [0, 1] normalization is applied here; confirm
        # the pipeline expects that layout.
        sketch_image = sketch_image.resize((512, 384))
        sketch_tensor = torch.tensor(np.array(sketch_image)).unsqueeze(0).float().to(device)

    # Select pipeline (disentangled variant).
    val_pipe = MGDPipeDisentangled(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=val_scheduler,
    ).to(device)

    # Reduce peak memory during attention at a small speed cost.
    val_pipe.enable_attention_slicing()

    # Generate image
    generated_images = generate_images_from_mgd_pipe(
        test_dataloader=test_dataloader,
        pipe=val_pipe,
        guidance_scale=7.5,
        seed=seed,
        sketch_image=sketch_tensor,
        prompt=prompt
    )

    return generated_images[0]  # Assuming single image output

# Streamlit UI
st.title("Fashion Image Generator")
st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")

# Upload a sketch image
uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])

# Text input for prompt
prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")

# Input options
category = st.text_input("Enter category (optional):", "dresses")
# Explicit initial value: st.slider's documented default is min_value; passing
# None is not a documented value for a numeric slider.
seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=1)
precision = st.selectbox("Select precision:", ["fp16", "fp32"])

# BUG FIX: sketch_image was only assigned inside the upload branch, so clicking
# "Generate Image" without uploading a sketch raised NameError. Default to None
# (run_inference treats None as "no sketch").
sketch_image = None

# Show uploaded sketch image
if uploaded_sketch is not None:
    sketch_image = Image.open(uploaded_sketch)
    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)

# Button to generate image
if st.button("Generate Image"):
    with st.spinner("Generating image..."):
        # Run inference with sketch or prompt (or both)
        result_image = run_inference(prompt, sketch_image, category, seed, precision)
        st.image(result_image, caption="Generated Image", use_column_width=True)