# Streamlit app: garment image generation from a rough sketch, a text prompt,
# or both, using the MGD (Multimodal Garment Designer) diffusion pipeline.
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe
# Build the model pipeline once per Streamlit session (cached resource).
@st.cache_resource
def load_model():
    """Assemble the MGD diffusion pipeline and move it to the best device.

    Loads the VAE, CLIP tokenizer/text encoder, the MGD UNet (via torch.hub),
    and a DDIM scheduler, then wires them into an ``MGDPipe``.

    Returns:
        The ready pipeline on CUDA when available (CPU otherwise), or
        ``None`` if any component fails to load — callers must check.
    """
    try:
        print("Initializing model loading...")

        # Prefer a GPU when one is visible to torch.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device selected: {target_device}")

        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print("VAE loaded successfully.")

        print("Loading tokenizer...")
        # NOTE(review): "microsoft/xclip-base-patch32" may not ship a
        # "tokenizer" subfolder — confirm this repo/subfolder pair resolves.
        tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
        print("Tokenizer loaded successfully.")

        print("Loading text encoder...")
        # NOTE(review): same caveat as the tokenizer — verify "text_encoder"
        # exists under this repo.
        text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
        print("Text encoder loaded successfully.")

        print("Loading UNet...")
        unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
        print("UNet loaded successfully.")

        print("Loading scheduler...")
        # NOTE(review): "stabilityai/sd-scheduler" looks unusual as a hub id —
        # confirm it resolves (schedulers usually load from the base SD repo).
        scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
        print("Scheduler loaded successfully.")

        print("Initializing pipeline...")
        # Match the UNet's dtype to the VAE's before assembling the pipe.
        pipe = MGDPipe(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=scheduler,
        ).to(target_device)
        # Trade a little speed for lower peak memory.
        pipe.enable_attention_slicing()
        print("Pipeline initialized successfully.")
        return pipe
    except Exception as e:
        # Best-effort: surface the failure in logs and let the UI handle None.
        print(f"Error loading the model: {e}")
        return None
# Load (or fetch from the session cache) the pipeline at import time;
# None signals a failed load and is checked before generation in the UI below.
pipe = load_model()
def generate_images(pipe, text_input=None, sketch=None):
    """Invoke the pipeline on a text prompt and/or a sketch, best-effort.

    Args:
        pipe: the generation pipeline (falsy values yield no output).
        text_input: optional prompt string.
        sketch: optional file-like object containing the sketch image.

    Returns:
        A list of raw pipeline outputs — one entry per provided input —
        or an empty list. Any exception is logged and swallowed so the
        caller's UI never crashes.
    """
    outputs = []
    try:
        # Without a pipeline there is nothing to do.
        if not pipe:
            return outputs
        if text_input:
            print(f"Generating image from text: {text_input}")
            outputs.append(pipe(prompt=[text_input]))
        if sketch:
            print("Generating image from sketch.")
            rgb_sketch = Image.open(sketch).convert("RGB")
            outputs.append(pipe(sketch=rgb_sketch))
    except Exception as err:
        print(f"Error during image generation: {err}")
    return outputs
# --- Streamlit UI ---
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

# Input-mode selector and input widgets for the chosen mode.
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
sketch_file = None
text_input = None
if option in ["Sketch", "Both"]:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
if option in ["Text", "Both"]:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    # Bail out early if the model never loaded.
    if pipe is None:
        st.error("Model failed to load. Please restart the application.")
        st.stop()

    # Buffer the upload so it can be re-read downstream.
    sketches = BytesIO(sketch_file.read()) if sketch_file else None

    # Validate inputs for the selected mode.
    if option == "Sketch" and not sketch_file:
        st.error("Please upload a sketch.")
    elif option == "Text" and not text_input:
        st.error("Please provide text input.")
    # BUG FIX: previously `not (sketch_file or text_input)`, which only
    # complained when *neither* input was given — a single input slipped
    # through despite the "provide both" message. Require both.
    elif option == "Both" and not (sketch_file and text_input):
        st.error("Please provide both a sketch and a text prompt.")
    else:
        with st.spinner("Generating images..."):
            images = generate_images(pipe, text_input=text_input, sketch=sketches)
        # Display results; entries may be tensors or displayable images
        # (exact pipeline output type depends on MGDPipe — see its docs).
        if images:
            for i, img in enumerate(images):
                if isinstance(img, torch.Tensor):  # CHW tensor -> PIL image
                    img = img.squeeze().permute(1, 2, 0).cpu().numpy()
                    img = Image.fromarray((img * 255).astype("uint8"))
                st.image(img, caption=f"Generated Image {i+1}")
        else:
            st.error("Failed to generate images. Please check the inputs or model configuration.")