File size: 2,554 Bytes
6ae61a6
40cbb76
831b686
40cbb76
831b686
6ae61a6
40cbb76
6ae61a6
40cbb76
 
 
 
 
 
831b686
 
40cbb76
 
458eaee
40cbb76
6ae61a6
 
40cbb76
6ae61a6
40cbb76
6ae61a6
40cbb76
458eaee
40cbb76
831b686
40cbb76
 
 
 
 
 
 
 
 
 
831b686
 
40cbb76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe

# Initialize the model and other components
@st.cache_resource
def load_model():
    """Assemble the MGDPipe used by the app and move it to the chosen device.

    The VAE, tokenizer, text encoder and scheduler are fetched with
    ``from_pretrained``; the UNet comes from ``torch.hub``. The UNet is cast
    to the VAE's dtype before the pipeline is built. CUDA is used when
    ``torch.cuda.is_available()`` reports it, otherwise the CPU.

    Returns:
        The constructed ``MGDPipe`` after ``.to(device)``.
    """
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # NOTE(review): the repo ids / subfolders below are kept exactly as found;
    # confirm that each checkpoint really exposes the requested subfolder.
    autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
    clip_tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    clip_text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    denoiser = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
    noise_scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")

    components = {
        "text_encoder": clip_text_encoder,
        "vae": autoencoder,
        # Keep the UNet weights in the same dtype as the VAE.
        "unet": denoiser.to(autoencoder.dtype),
        "tokenizer": clip_tokenizer,
        "scheduler": noise_scheduler,
    }
    return MGDPipe(**components).to(target_device)

# Module-level pipeline handle used by the UI code below. load_model is
# decorated with st.cache_resource, so Streamlit is expected to hand back the
# same pipeline on script reruns instead of rebuilding it.
pipe = load_model()

def generate_images(pipe, text_input=None, sketch=None):
    """Run the pipeline for whichever inputs were supplied and collect the results.

    Args:
        pipe: Callable pipeline; invoked with ``prompt=[text_input]`` and/or
            ``sketch=<RGB image>``.
        text_input: Optional prompt string; skipped when falsy.
        sketch: Optional path or file-like object accepted by
            ``PIL.Image.open``; skipped when falsy.

    Returns:
        A list holding the items yielded by the text call followed by the
        items yielded by the sketch call (empty when neither input is given).
    """
    generated = []

    if text_input:
        # The pipeline is given the prompt wrapped in a one-element list.
        generated += pipe(prompt=[text_input])

    if sketch:
        rgb_sketch = Image.open(sketch).convert("RGB")
        generated += pipe(sketch=rgb_sketch)

    return generated

# Streamlit UI
# The script is re-executed from the top on every interaction, so every name
# read inside the "Generate" handler must be bound on every run.
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))

# Bind both inputs up front: only the widgets matching the selected option are
# rendered, and the handler below reads both names regardless of the option.
sketch_file = None
text_input = None

if option in ("Sketch", "Both"):
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])

if option in ("Text", "Both"):
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    if option == "Sketch" and not sketch_file:
        st.error("Please upload a sketch.")
    elif option == "Text" and not text_input:
        st.error("Please provide text input.")
    elif option == "Both" and not (sketch_file or text_input):
        # Without this branch an empty "Both" request produced no output and
        # no feedback at all.
        st.error("Please upload a sketch or provide text input.")
    else:
        # Generate images based on user input
        with st.spinner("Generating images..."):
            # Copy the upload into an in-memory buffer for PIL to read.
            sketches = BytesIO(sketch_file.read()) if sketch_file else None
            images = generate_images(pipe, text_input=text_input, sketch=sketches)

        # Display results
        for i, img in enumerate(images, start=1):
            st.image(img, caption=f"Generated Image {i}")