import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe

# Initialize the model and other components
@st.cache_resource  # cache the loaded pipeline across Streamlit reruns
def load_model():
    try:
        # Select the device, then load each pipeline component
        print("Initializing model loading...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device selected: {device}")

        # Load the VAE
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print("VAE loaded successfully.")

        # Load the tokenizer
        print("Loading tokenizer...")
        tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
        print("Tokenizer loaded successfully.")

        # Load the text encoder
        print("Loading text encoder...")
        text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
        print("Text encoder loaded successfully.")

        # Load the UNet model
        print("Loading UNet...")
        unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
        print("UNet loaded successfully.")

        # Load the scheduler
        print("Loading scheduler...")
        scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
        print("Scheduler loaded successfully.")

        # Assemble the MGD pipeline from the loaded components
        print("Initializing pipeline...")
        pipe = MGDPipe(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=scheduler,
        ).to(device)
        pipe.enable_attention_slicing()
        print("Pipeline initialized successfully.")
        return pipe
    except Exception as e:
        print(f"Error loading the model: {e}")
        return None
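
# The display loop at the bottom of this file expects either torch tensors or
# PIL images. What MGDPipe.__call__ actually returns is not pinned down here,
# so the helper below is only a hedged sketch: it assumes the result is either
# a diffusers-style output object with an `.images` attribute, a plain
# list/tuple, a PIL image, or a CHW image tensor in [0, 1], and normalizes any
# of those into a list of PIL images. Adjust it if the real output differs;
# nothing else in this file calls it by default.
def outputs_to_pil(output):
    """Best-effort conversion of a pipeline result into a list of PIL images."""
    if output is None:
        return []
    # diffusers pipelines typically wrap results in an object exposing `.images`
    if hasattr(output, "images"):
        output = output.images
    if not isinstance(output, (list, tuple)):
        output = [output]
    pil_images = []
    for item in output:
        if isinstance(item, Image.Image):
            pil_images.append(item)
        elif isinstance(item, torch.Tensor):
            # Assumed layout: CHW (or 1xCHW) float tensor with values in [0, 1]
            array = item.squeeze().permute(1, 2, 0).cpu().numpy()
            pil_images.append(Image.fromarray((array * 255).astype("uint8")))
    return pil_images
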
pipe = load_model()

def generate_images(pipe, text_input=None, sketch=None):
    """Generate images from a text prompt, a sketch, or both."""
    images = []
    try:
        if pipe:
            # Generate from text
            if text_input:
                print(f"Generating image from text: {text_input}")
                images.append(pipe(prompt=[text_input]))
            # Generate from sketch
            if sketch:
                print("Generating image from sketch.")
                sketch_image = Image.open(sketch).convert("RGB")
                images.append(pipe(sketch=sketch_image))
    except Exception as e:
        print(f"Error during image generation: {e}")
    return images
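
# Diffusion pipelines generally expect inputs at a fixed resolution, so an
# uploaded sketch will usually need resizing before it reaches MGDPipe.
# generate_images() above passes the sketch through unchanged; the helper
# below is a hedged sketch of that preprocessing step. The 384x512
# (width x height) target is an assumption, not a value confirmed against the
# MGD checkpoint, so adjust it to whatever resolution the model was trained on.
def preprocess_sketch(sketch_file, size=(384, 512)):
    """Load an uploaded sketch and resize it to a (width, height) target."""
    sketch_image = Image.open(sketch_file).convert("RGB")
    return sketch_image.resize(size, Image.Resampling.BICUBIC)
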
# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

# Input options
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))

sketch_file = None
text_input = None

# Get sketch input
if option in ["Sketch", "Both"]:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])

# Get text input
if option in ["Text", "Both"]:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

# Generate button
if st.button("Generate"):
    # Ensure the model is loaded
    if pipe is None:
        st.error("Model failed to load. Please restart the application.")
        st.stop()

    # Validate inputs
    sketches = BytesIO(sketch_file.read()) if sketch_file else None
if option == "Sketch" and not sketch_file: | |
st.error("Please upload a sketch.") | |
elif option == "Text" and not text_input: | |
st.error("Please provide text input.") | |
elif option == "Both" and not (sketch_file or text_input): | |
st.error("Please provide both a sketch and a text prompt.") | |
    else:
        # Generate images
        with st.spinner("Generating images..."):
            images = generate_images(pipe, text_input=text_input, sketch=sketches)

        # Display results
        if images:
            for i, img in enumerate(images):
                if isinstance(img, torch.Tensor):  # Convert tensor output to a PIL image
                    img = img.squeeze().permute(1, 2, 0).cpu().numpy()
                    img = Image.fromarray((img * 255).astype("uint8"))
                st.image(img, caption=f"Generated Image {i+1}")
        else:
            st.error("Failed to generate images. Please check the inputs or model configuration.")