File size: 4,586 Bytes
6ae61a6
40cbb76
831b686
40cbb76
831b686
6ae61a6
40cbb76
6ae61a6
40cbb76
 
 
74d4e67
 
5b2f1b4
74d4e67
5b2f1b4
74d4e67
5b2f1b4
 
 
 
 
 
 
74d4e67
5b2f1b4
 
 
 
74d4e67
5b2f1b4
 
 
 
74d4e67
5b2f1b4
 
 
 
74d4e67
5b2f1b4
458eaee
f26e6e9
5b2f1b4
74d4e67
 
 
 
 
 
 
5b2f1b4
 
74d4e67
f26e6e9
74d4e67
 
458eaee
40cbb76
831b686
40cbb76
 
 
f26e6e9
 
 
 
 
 
 
 
 
 
 
 
 
 
40cbb76
831b686
 
40cbb76
 
 
f26e6e9
40cbb76
 
e27303d
 
 
f26e6e9
40cbb76
 
 
f26e6e9
40cbb76
 
 
f26e6e9
40cbb76
f26e6e9
 
 
 
 
 
e27303d
 
40cbb76
 
 
 
e27303d
 
40cbb76
f26e6e9
40cbb76
 
 
f26e6e9
74d4e67
 
f26e6e9
 
 
74d4e67
 
f26e6e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import traceback
from io import BytesIO

import streamlit as st
import torch
from PIL import Image
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe

# Initialize the model and other components
@st.cache_resource
def load_model(
    vae_id="stabilityai/sd-vae-ft-mse",
    text_model_id="microsoft/xclip-base-patch32",
    unet_hub_repo="aimagelab/multimodal-garment-designer",
    scheduler_id="stabilityai/sd-scheduler",
):
    """Build the MGD generation pipeline once per process.

    The result is memoized by ``st.cache_resource``, so Streamlit reruns reuse
    the same pipeline object. Any failure is reported (message + traceback) and
    ``None`` is returned; note that this ``None`` is cached too, so a retry
    needs the cache cleared or the app restarted.

    Args:
        vae_id: Id passed to ``AutoencoderKL.from_pretrained``.
        text_model_id: Id whose ``tokenizer`` / ``text_encoder`` subfolders
            provide the CLIP tokenizer and text encoder.
        unet_hub_repo: ``torch.hub`` repo exposing the ``mgd`` entry point.
        scheduler_id: Id whose ``scheduler`` subfolder holds the DDIM config.

    Returns:
        An ``MGDPipe`` moved to CUDA when available (CPU otherwise) with
        attention slicing enabled, or ``None`` if any loading step failed.
    """
    try:
        print("Initializing model loading...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device selected: {device}")

        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(vae_id)
        print("VAE loaded successfully.")

        # NOTE(review): X-CLIP checkpoints are not normally laid out with
        # tokenizer/ and text_encoder/ subfolders — confirm this id loads.
        print("Loading tokenizer...")
        tokenizer = CLIPTokenizer.from_pretrained(text_model_id, subfolder="tokenizer")
        print("Tokenizer loaded successfully.")

        print("Loading text encoder...")
        text_encoder = CLIPTextModel.from_pretrained(text_model_id, subfolder="text_encoder")
        print("Text encoder loaded successfully.")

        # NOTE(review): upstream MGD examples appear to also pass a
        # ``dataset=`` argument to this hub entry point — confirm.
        print("Loading UNet...")
        unet = torch.hub.load(unet_hub_repo, model="mgd", pretrained=True)
        print("UNet loaded successfully.")

        # NOTE(review): confirm that this scheduler repo id exists.
        print("Loading scheduler...")
        scheduler = DDIMScheduler.from_pretrained(scheduler_id, subfolder="scheduler")
        print("Scheduler loaded successfully.")

        print("Initializing pipeline...")
        pipe = MGDPipe(
            text_encoder=text_encoder,
            vae=vae,
            # Match the UNet's dtype to the VAE so both run in the same precision.
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=scheduler,
        ).to(device)
        pipe.enable_attention_slicing()
        print("Pipeline initialized successfully.")
        return pipe
    except Exception as e:
        # Top-level boundary for the whole load: keep the app alive (the UI
        # checks for None) but print the full traceback so the failing step
        # is identifiable.
        print(f"Error loading the model: {e}")
        traceback.print_exc()
        return None


pipe = load_model()

def generate_images(pipe, text_input=None, sketch=None):
    """Run the pipeline on the supplied text prompt and/or sketch.

    When both a prompt and a sketch are given they go into a single pipeline
    call, so the result is conditioned on both inputs. When only one is
    present, only that one is passed.

    Args:
        pipe: The generation pipeline (a callable), or ``None``.
        text_input: Text prompt; ``None``/blank/whitespace-only is ignored.
        sketch: Binary file-like object holding the sketch image (opened with
            PIL and converted to RGB), or ``None``.

    Returns:
        A list of generated images. It is empty when ``pipe`` is ``None``, no
        usable input was given, or generation raised.
    """
    images = []
    if pipe is None:
        return images

    kwargs = {}
    try:
        if text_input and text_input.strip():
            print(f"Generating image from text: {text_input}")
            kwargs["prompt"] = [text_input]

        if sketch is not None:
            print("Generating image from sketch.")
            kwargs["sketch"] = Image.open(sketch).convert("RGB")

        if not kwargs:
            return images

        output = pipe(**kwargs)
    except Exception as e:
        # Best-effort boundary for the UI: report (with traceback) and let the
        # caller show its "failed to generate" message.
        print(f"Error during image generation: {e}")
        traceback.print_exc()
        return images

    # Diffusers-style pipelines usually return an output object whose
    # ``images`` attribute holds the results. Unwrap it when present
    # (assumption about MGDPipe's return type — verify).
    produced = getattr(output, "images", output)
    if isinstance(produced, (list, tuple)):
        images.extend(produced)
    else:
        images.append(produced)
    return images

# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

# Input options
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))

sketch_file = None
text_input = None

# Only render the widgets relevant to the chosen mode; the other input stays None.
if option in ("Sketch", "Both"):
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])

if option in ("Text", "Both"):
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    # The pipeline is loaded once at module import; None means loading failed.
    if pipe is None:
        st.error("Model failed to load. Please restart the application.")
        st.stop()

    has_sketch = sketch_file is not None
    # Whitespace-only prompts are treated as missing.
    has_text = bool(text_input and text_input.strip())

    if option == "Sketch" and not has_sketch:
        st.error("Please upload a sketch.")
    elif option == "Text" and not has_text:
        st.error("Please provide text input.")
    elif option == "Both" and not (has_sketch and has_text):
        # "Both" requires each input, not just one of them.
        st.error("Please provide both a sketch and a text prompt.")
    else:
        # getvalue() returns the whole upload regardless of the current read position.
        sketches = BytesIO(sketch_file.getvalue()) if has_sketch else None

        with st.spinner("Generating images..."):
            images = generate_images(pipe, text_input=text_input, sketch=sketches)

        if images:
            for index, img in enumerate(images, start=1):
                if isinstance(img, torch.Tensor):
                    # Assumes a single CHW image with values in [0, 1] (TODO
                    # confirm the pipeline's output range). Clamp so
                    # out-of-range values do not wrap in the uint8 cast, and
                    # use float32 because numpy cannot represent bfloat16.
                    array = img.detach().cpu().float().clamp(0, 1).squeeze().permute(1, 2, 0).numpy()
                    img = Image.fromarray((array * 255).astype("uint8"))
                st.image(img, caption=f"Generated Image {index}")
        else:
            st.error("Failed to generate images. Please check the inputs or model configuration.")