import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from src.mgd_pipelines.mgd_pipe import MGDPipe  # MGDPipe from the multimodal-garment-designer repo
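# NOTE: the import above assumes this script is run from the root of the
# multimodal-garment-designer checkout, so that the src/ package is importable.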


# Load models and pipeline
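# NOTE: in a long-running app, consider wrapping load_models in
# @st.cache_resource so the models load once instead of on every "Generate"
# click; torch.device arguments would then need to be passed as strings (or
# given a leading underscore) to be excluded from Streamlit's argument hashing.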
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.
    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load components of Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Load the MGD UNet via torch.hub (the hub entrypoint downloads the weights)
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    )

    # On CPU-only machines, reload the state dict with map_location="cpu" in
    # case the hub entrypoint assumed a GPU. This relies on the hub model
    # exposing its checkpoint URL via unet.config.
    if device.type == "cpu":
        checkpoint_url = unet.config.get("checkpoint")
        if checkpoint_url:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
            unet.load_state_dict(state_dict)

    # Move UNet to the appropriate device
    unet = unet.to(device)

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )

    return pipeline
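
# A possible memory optimization on CUDA devices (assumption: the MGD
# checkpoint tolerates half precision): pass torch_dtype=torch.float16 to the
# text_encoder and vae from_pretrained calls above and call unet.half()
# before moving the UNet to the GPU.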


# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.
    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch. PIL's resize takes (width, height); the 64x48
    # pose-map placeholder below implies a 512x384 (H x W) working resolution,
    # so the target size is width 384, height 512.
    sketch = sketch.resize((384, 512)).convert("RGB")
    sketch_array = np.array(sketch, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
    sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)  # 1x3xHxW

    # Run the pipeline without tracking gradients (inference only)
    with torch.inference_mode():
        output = pipeline(
            prompt=prompt,
            image=torch.zeros_like(sketch_tensor),  # Placeholder for the masked model image
            mask_image=torch.ones_like(sketch_tensor),  # Placeholder for the mask
            pose_map=torch.zeros((1, 3, 64, 48), device=device),  # Placeholder pose map
            sketch=sketch_tensor,
            guidance_scale=7.5,
            num_inference_steps=50,
        )

    return output.images[0]
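
# NOTE: the zero/one tensors above are stand-ins for the masked person image,
# inpainting mask, and pose map that the full MGD data pipeline would provide;
# with placeholders the result reflects only the sketch and text conditioning.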


# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Detect device (CPU or GPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required

        # Load the pipeline
        pipeline = load_models(pretrained_model_path, device)

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")
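
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py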