File size: 3,061 Bytes
8689a3c
4198ed7
 
831b686
4198ed7
 
 
 
8689a3c
4198ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8689a3c
4198ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8689a3c
4198ed7
 
 
 
 
 
8689a3c
4198ed7
 
 
8689a3c
4198ed7
 
 
8689a3c
4198ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40cbb76
4198ed7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import numpy as np
import streamlit as st
import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled

# Function to load models
def load_models(pretrained_model_name_or_path, device):
    """Load the tokenizer, text encoder, VAE, DDIM scheduler, and MGD UNet.

    Args:
        pretrained_model_name_or_path: Hugging Face hub id or local path of the
            base checkpoint; must contain ``tokenizer``, ``text_encoder``,
            ``vae`` and ``scheduler`` subfolders.
        device: ``torch.device`` on which the networks are placed.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    scheduler.set_timesteps(50, device=device)

    # Pretrained MGD denoising UNet published by the paper authors on torch.hub.
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        pretrained=True,
        source="github",
    )

    # Inference-only app: place the networks on the target device and switch
    # them to eval mode here (otherwise they stay on CPU in train mode, which
    # keeps dropout active during generation).
    text_encoder = text_encoder.to(device).eval()
    vae = vae.to(device).eval()
    unet = unet.to(device).eval()
    return tokenizer, text_encoder, vae, scheduler, unet

# Function to generate images
def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
    """Run the MGD pipeline on one sketch / prompt pair and return the first image.

    Args:
        sketch: ``PIL.Image`` input sketch; resized to 512x384 (width x height)
            and converted to RGB.
        prompt: free-text description of the garment.
        tokenizer / text_encoder / vae / scheduler / unet: components as
            returned by ``load_models``.
        device: ``torch.device`` used for all tensors and modules.

    Returns:
        The first element of the pipeline's output.
    """
    # Preprocess: PIL -> float tensor in [0, 1], shape (1, 3, H, W).
    # NOTE: the previous code did ``torch.tensor([torch.tensor(sketch, ...)])``,
    # which fails twice: torch.tensor() cannot consume a PIL Image, and wrapping
    # a tensor in a Python list raises "only one element tensors can be
    # converted to Python scalars". Go through numpy instead.
    sketch = sketch.resize((512, 384)).convert("RGB")
    sketch_tensor = (
        torch.from_numpy(np.asarray(sketch).copy())  # HWC uint8
        .permute(2, 0, 1)                            # -> CHW
        .float()
        .div(255.0)
        .unsqueeze(0)                                # -> NCHW, batch of 1
        .to(device)
    )

    # Tokenize the prompt. No padding/truncation here — presumably MGDPipe
    # handles prompt length itself; verify against the pipeline if long
    # prompts misbehave.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Initialize pipeline.
    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)

    # Generate: attention slicing lowers peak memory; inference_mode skips
    # autograd bookkeeping entirely.
    pipeline.enable_attention_slicing()
    with torch.inference_mode():
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)

    return outputs[0]

# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# Inputs: a sketch image and a free-text garment description.
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if not (uploaded_file and text_prompt):
        # One or both inputs are missing — nothing to generate.
        st.error("Please upload a sketch and enter a text description.")
    else:
        # Prefer GPU when one is available. NOTE(review): the model path below
        # is a placeholder and must be edited before the app can actually run.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "your-pretrained-model-path"  # Replace with actual model path
        tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)

        # Read the uploaded sketch into a PIL image.
        sketch = Image.open(uploaded_file)

        # Run the pipeline and show the result.
        st.write("Generating the garment design...")
        output_image = generate_image(sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device)

        st.image(output_image, caption="Generated Garment Design", use_column_width=True)