Spaces:
Running
Running
File size: 2,554 Bytes
6ae61a6 40cbb76 831b686 40cbb76 831b686 6ae61a6 40cbb76 6ae61a6 40cbb76 831b686 40cbb76 458eaee 40cbb76 6ae61a6 40cbb76 6ae61a6 40cbb76 6ae61a6 40cbb76 458eaee 40cbb76 831b686 40cbb76 831b686 40cbb76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe
# Initialize the model and other components
@st.cache_resource
def load_model(
    base_model="stable-diffusion-v1-5/stable-diffusion-inpainting",
    vae_model="stabilityai/sd-vae-ft-mse",
    dataset="dresscode",
):
    """Assemble the MGD pipeline and move it to the GPU when one is available.

    ``st.cache_resource`` caches the returned pipeline, so Streamlit reruns
    reuse it instead of reloading weights (one cache entry per distinct
    argument tuple).

    Args:
        base_model: Hub id of the Stable Diffusion inpainting checkpoint that
            supplies the tokenizer, text encoder and scheduler config.
            NOTE(review): MGD is derived from SD 1.x inpainting; this id is the
            mirror of ``runwayml/stable-diffusion-inpainting`` — confirm it is
            reachable from the deployment.
        vae_model: Hub id of the VAE. ``stabilityai/sd-vae-ft-mse`` stores its
            weights at the repository root, hence no ``subfolder`` argument.
        dataset: Forwarded to the MGD torch.hub entrypoint to pick the UNet
            checkpoint (``"dresscode"`` or ``"vitonhd"`` per the MGD README —
            verify against the pinned hubconf).

    Returns:
        The ``MGDPipe`` instance, moved to CUDA if available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = AutoencoderKL.from_pretrained(vae_model)
    # Tokenizer, text encoder and scheduler all come from the same SD base
    # checkpoint so their vocab/embedding width/config match what the UNet
    # was trained against.
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")
    # Extra kwargs (dataset, pretrained) are forwarded by torch.hub.load to
    # the repo's ``mgd`` entrypoint.
    unet = torch.hub.load(
        "aimagelab/multimodal-garment-designer",
        model="mgd",
        source="github",
        dataset=dataset,
        pretrained=True,
    )
    pipe = MGDPipe(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet.to(vae.dtype),  # keep UNet weights in the VAE's dtype
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    return pipe


pipe = load_model()
def _output_images(output):
    """Return the generated images carried by a pipeline result as a list."""
    images = getattr(output, "images", None)
    if images is None:
        # ``return_dict=False``-style pipelines return a tuple whose first
        # item holds the images; anything else is assumed to already be an
        # iterable of images.
        images = output[0] if isinstance(output, tuple) else output
    return list(images)


def generate_images(pipe, text_input=None, sketch=None):
    """Generate images from a text prompt, a sketch, or both.

    When both inputs are supplied they go into a single pipeline call so the
    result is conditioned on text and sketch jointly, rather than producing a
    text-only batch plus a separate sketch-only batch.

    Args:
        pipe: Callable pipeline (the ``MGDPipe`` built by ``load_model``).
        text_input: Optional prompt string; ``None``, ``""`` and
            whitespace-only strings are treated as "no prompt".
        sketch: Optional path or binary file-like object with the sketch
            image; it is opened with PIL and converted to RGB.

    Returns:
        A list of generated images; empty when neither input was provided.
    """
    pipe_kwargs = {}
    if text_input and text_input.strip():
        pipe_kwargs["prompt"] = [text_input.strip()]
    if sketch is not None:
        pipe_kwargs["sketch"] = Image.open(sketch).convert("RGB")
    if not pipe_kwargs:
        return []
    # NOTE(review): MGDPipe's signature is not visible here; verify it accepts
    # these kwargs on their own (it may also require image/mask_image/pose_map).
    return _output_images(pipe(**pipe_kwargs))
# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
needs_sketch = option in ("Sketch", "Both")
needs_text = option in ("Text", "Both")
# Bind both inputs up front: the selected mode hides one widget, and reading
# an unbound name below would raise NameError.
sketch_file = None
text_input = ""
if needs_sketch:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
if needs_text:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
if st.button("Generate"):
    if needs_sketch and sketch_file is None:
        st.error("Please upload a sketch.")
    elif needs_text and not (text_input and text_input.strip()):
        st.error("Please provide text input.")
    else:
        # Generate images based on user input
        with st.spinner("Generating images..."):
            # getvalue() returns the whole upload regardless of stream position,
            # unlike read(), which returns b"" once the stream was consumed.
            sketches = BytesIO(sketch_file.getvalue()) if sketch_file is not None else None
            images = generate_images(pipe, text_input=text_input, sketch=sketches)
        # Display results
        if not images:
            st.warning("No images were generated.")
        for i, img in enumerate(images, start=1):
            st.image(img, caption=f"Generated Image {i}")
|