# fyp-deploy / app.py
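"""Streamlit front end for the Multimodal Garment Designer (MGD) pipeline.

Generates garment images from a rough sketch, a text prompt, or both, using a
diffusers-style pipeline assembled in load_model() below.
"""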
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe
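# MGDPipe is the project-local, diffusers-style pipeline class
# (src/mgd_pipelines/mgd_pipe.py in this repository).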
# Initialize the model and other components
@st.cache_resource
def load_model():
    try:
        # Select the compute device
        print("Initializing model loading...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device selected: {device}")

        # Load the VAE
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print("VAE loaded successfully.")

        # Load the tokenizer
        print("Loading tokenizer...")
        tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
        print("Tokenizer loaded successfully.")

        # Load the text encoder
        print("Loading text encoder...")
        text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
        print("Text encoder loaded successfully.")

        # Load the UNet model
        print("Loading UNet...")
        unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
        print("UNet loaded successfully.")

        # Load the scheduler
        print("Loading scheduler...")
        scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
        print("Scheduler loaded successfully.")

        # Initialize the pipeline
        print("Initializing pipeline...")
        pipe = MGDPipe(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=scheduler,
        ).to(device)
        pipe.enable_attention_slicing()
        print("Pipeline initialized successfully.")
        return pipe
    except Exception as e:
        print(f"Error loading the model: {e}")
        return None
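# st.cache_resource caches the return value of load_model() for the lifetime of
# the server process, so the pipeline is built once and reused across Streamlit
# reruns. A None return signals a load failure, which the UI checks before
# generating.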
pipe = load_model()
def generate_images(pipe, text_input=None, sketch=None):
    """Generate images from a text prompt, a sketch, or both."""
    images = []
    try:
        if pipe:
            # Generate from text
            if text_input:
                print(f"Generating image from text: {text_input}")
                images.append(pipe(prompt=[text_input]))
            # Generate from sketch
            if sketch:
                print("Generating image from sketch.")
                sketch_image = Image.open(sketch).convert("RGB")
                # NOTE: the exact keyword arguments and return type depend on the
                # MGDPipe implementation; outputs are assumed to be PIL images or
                # CHW tensors, matching the display logic further below.
                images.append(pipe(sketch=sketch_image))
    except Exception as e:
        print(f"Error during image generation: {e}")
    return images
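# Hypothetical programmatic usage (outside the Streamlit UI), assuming the
# pipeline accepts a plain text prompt as generate_images() does above:
#
#   outputs = generate_images(pipe, text_input="a sleeveless floral summer dress")
#   print(f"Generated {len(outputs)} output(s)")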
# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")
# Input options
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
sketch_file = None
text_input = None
# Get sketch input
if option in ["Sketch", "Both"]:
sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
# Get text input
if option in ["Text", "Both"]:
text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
# Generate button
if st.button("Generate"):
    # Ensure the model is loaded
    if pipe is None:
        st.error("Model failed to load. Please restart the application.")
        st.stop()

    # Read the uploaded sketch (if any) into an in-memory buffer
    sketches = BytesIO(sketch_file.read()) if sketch_file else None

    # Validate inputs
    if option == "Sketch" and not sketch_file:
        st.error("Please upload a sketch.")
    elif option == "Text" and not text_input:
        st.error("Please provide text input.")
    elif option == "Both" and not (sketch_file and text_input):
        st.error("Please provide both a sketch and a text prompt.")
    else:
        # Generate images
        with st.spinner("Generating images..."):
            images = generate_images(pipe, text_input=text_input, sketch=sketches)

        # Display results
        if images:
            for i, img in enumerate(images):
                if isinstance(img, torch.Tensor):
                    # Convert a CHW tensor in [0, 1] to a PIL image for display
                    img = img.squeeze().permute(1, 2, 0).cpu().numpy()
                    img = Image.fromarray((img * 255).astype("uint8"))
                st.image(img, caption=f"Generated Image {i+1}")
        else:
            st.error("Failed to generate images. Please check the inputs or model configuration.")