# fyp-deploy / app.py
# Mairaaa's picture
# Update app.py
# f26e6e9 verified
# raw
# history blame
# 3.93 kB
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe
# Initialize the model and other components
@st.cache_resource
def load_model():
    """Assemble the MGD generation pipeline and hand it back to the app.

    Loads, in order, the VAE, the CLIP tokenizer and text encoder, the MGD
    UNet (through ``torch.hub``) and a DDIM scheduler, combines them into an
    ``MGDPipe``, moves that pipeline to CUDA when it is available (CPU
    otherwise) and enables attention slicing.

    Returns:
        The configured ``MGDPipe``, or ``None`` when any step raised. The
        error is printed, not propagated.
    """
    # NOTE(review): with ``st.cache_resource`` a ``None`` result is presumably
    # cached like any other return value -- confirm against the Streamlit docs.
    # NOTE(review): verify that the repo ids below really expose the
    # ``tokenizer`` / ``text_encoder`` / ``scheduler`` subfolders.
    try:
        target_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print("VAE model loaded successfully.")

        clip_tokenizer = CLIPTokenizer.from_pretrained(
            "microsoft/xclip-base-patch32", subfolder="tokenizer"
        )
        clip_text_encoder = CLIPTextModel.from_pretrained(
            "microsoft/xclip-base-patch32", subfolder="text_encoder"
        )
        mgd_unet = torch.hub.load(
            "aimagelab/multimodal-garment-designer", model="mgd", pretrained=True
        )
        ddim_scheduler = DDIMScheduler.from_pretrained(
            "stabilityai/sd-scheduler", subfolder="scheduler"
        )

        # The UNet is cast to the VAE's dtype before the pipeline is built.
        components = {
            "text_encoder": clip_text_encoder,
            "vae": autoencoder,
            "unet": mgd_unet.to(autoencoder.dtype),
            "tokenizer": clip_tokenizer,
            "scheduler": ddim_scheduler,
        }
        pipeline = MGDPipe(**components).to(target_device)

        # Enable memory-efficient inference.
        pipeline.enable_attention_slicing()
    except Exception as e:
        print(f"Error loading the model: {e}")
        return None
    return pipeline
# Module-level pipeline used by the UI below; load_model() returns None when
# loading raised, which the "Generate" handler checks for.
pipe = load_model()
def generate_images(pipe, text_input=None, sketch=None):
    """Run the pipeline once per supplied input and collect the results.

    Args:
        pipe: Callable pipeline. When falsy, nothing is generated.
        text_input: Optional prompt. When truthy, ``pipe`` is called with
            ``prompt=[text_input]``.
        sketch: Optional source accepted by ``PIL.Image.open``. When truthy,
            it is opened, converted to RGB and passed as ``sketch=``.

    Returns:
        A list holding whatever ``pipe`` returned for each input, text result
        first. Any exception is printed and the results gathered so far
        (possibly none) are returned.
    """
    outputs = []
    try:
        if not pipe:
            return outputs

        # Text-driven generation.
        if text_input:
            print(f"Generating image from text: {text_input}")
            text_result = pipe(prompt=[text_input])
            outputs.append(text_result)

        # Sketch-driven generation.
        if sketch:
            print("Generating image from sketch.")
            rgb_sketch = Image.open(sketch).convert("RGB")
            sketch_result = pipe(sketch=rgb_sketch)
            outputs.append(sketch_result)
    except Exception as e:
        print(f"Error during image generation: {e}")
    return outputs
# Streamlit UI: page header.
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

# Input mode selector; it decides which of the two widgets below is rendered.
option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))

# Sketch uploader, rendered only for "Sketch" / "Both"; otherwise stays None.
sketch_file = (
    st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
    if option in ("Sketch", "Both")
    else None
)

# Prompt box, rendered only for "Text" / "Both"; otherwise stays None.
text_input = (
    st.text_input(
        "Enter Text Prompt",
        placeholder="Describe the image you want to generate",
    )
    if option in ("Text", "Both")
    else None
)
# Generate button: validate the chosen inputs, run the pipeline, show results.
if st.button("Generate"):
    # Without a pipeline nothing can be generated; abort this script run.
    if pipe is None:
        st.error("Model failed to load. Please restart the application.")
        st.stop()

    # Validate inputs for the selected mode before doing any work.
    if option == "Sketch" and not sketch_file:
        st.error("Please upload a sketch.")
    elif option == "Text" and not text_input:
        st.error("Please provide text input.")
    elif option == "Both" and not (sketch_file and text_input):
        # "Both" needs a sketch AND a prompt. The earlier `or` only rejected
        # the case where neither was given, contradicting the message below.
        st.error("Please provide both a sketch and a text prompt.")
    else:
        # Read the upload into memory only once validation has passed.
        sketches = BytesIO(sketch_file.read()) if sketch_file else None

        with st.spinner("Generating images..."):
            images = generate_images(pipe, text_input=text_input, sketch=sketches)

        # Display results
        if images:
            for index, img in enumerate(images, start=1):
                if isinstance(img, torch.Tensor):
                    # Assumes a channel-first float tensor scaled to [0, 1]
                    # -- TODO confirm against the pipeline's actual output.
                    img = img.squeeze().permute(1, 2, 0).cpu().numpy()
                    img = Image.fromarray((img * 255).astype("uint8"))
                st.image(img, caption=f"Generated Image {index}")
        else:
            st.error("Failed to generate images. Please check the inputs or model configuration.")