import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from src.mgd_pipelines.mgd_pipe import MGDPipe  # Use your implementation of MGDPipe


# Load models and pipeline, cached so Streamlit reruns don't reload everything
@st.cache_resource
def load_models(pretrained_model_path, _device):
    """
    Load the models required for the MGDPipe.
    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        _device (torch.device): Device to load the models on; the leading
            underscore keeps st.cache_resource from trying to hash it.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load components of Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)  # the pipeline re-sets this from num_inference_steps at call time

    # Load the UNet model
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    ).to(device)
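    # torch.hub downloads and caches the checkpoint on first use; "dresscode"
    # and "vitonhd" are the two weight variants published by the MGD repo.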

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )

    return pipeline


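# Optional memory saver for load_models (a sketch; assumes MGDPipe subclasses
# diffusers' DiffusionPipeline, as in the reference repo, so the helper exists):
#
#     pipe = load_models("runwayml/stable-diffusion-inpainting", torch.device("cuda"))
#     pipe.enable_attention_slicing()  # lower VRAM usage at some speed cost
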
# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.
    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch. PIL's resize takes (width, height); 384x512 yields
    # a 512x384 (HxW) image, matching the 64x48 latent pose map used below.
    sketch = sketch.convert("RGB").resize((384, 512))
    sketch_array = np.asarray(sketch, dtype=np.float32) / 255.0  # HWC, scaled to [0, 1]
    sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)  # 1x3xHxW

    # Run the pipeline without gradient tracking to save memory at inference.
    # The zero/one tensors are placeholders for the masked image, mask, and
    # pose map that the original MGD setup derives from dataset annotations.
    with torch.inference_mode():
        output = pipeline(
            prompt=prompt,
            image=torch.zeros_like(sketch_tensor),  # placeholder for the masked person image
            mask_image=torch.ones_like(sketch_tensor),  # placeholder inpainting mask
            pose_map=torch.zeros((1, 3, 64, 48), device=device),  # placeholder pose map
            sketch=sketch_tensor,
            guidance_scale=7.5,
            num_inference_steps=50,
        )

    return output.images[0]


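# End-to-end example outside Streamlit (illustrative: "sketch.png" and the
# prompt are stand-ins, not assets shipped with this repo):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     pipe = load_models("runwayml/stable-diffusion-inpainting", device)
#     design = generate_image(pipe, Image.open("sketch.png"), "a red sleeveless dress", device)
#     design.save("generated.png")
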
# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Load the pipeline
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required
        pipeline = load_models(pretrained_model_path, device)

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")