# fyp-deploy / app.py
# Last commit: 40cbb76 "Update app.py" by Mairaaa (verified) — 2.55 kB
import streamlit as st
import torch
from PIL import Image
from io import BytesIO
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from src.mgd_pipelines.mgd_pipe import MGDPipe
# Initialize the model and other components
@st.cache_resource
def load_model(
    base_model: str = "runwayml/stable-diffusion-inpainting",
    dataset: str = "dresscode",
):
    """Build the MGD pipeline; cached by Streamlit so it is built once per process.

    ``st.cache_resource`` returns the same pipeline object on every rerun, so the
    weights are downloaded and moved to the device only on the first call.

    Args:
        base_model: Hub id of the Stable Diffusion checkpoint that supplies the
            scheduler, tokenizer, text encoder and VAE. All four come from the
            same repository so their configurations are mutually consistent.
            NOTE(review): if this repo is unavailable on the Hub, a mirror such
            as ``stable-diffusion-v1-5/stable-diffusion-inpainting`` may be
            substituted — confirm.
        dataset: MGD checkpoint variant forwarded to the ``mgd`` torch.hub entry
            point (presumably ``"dresscode"`` or ``"vitonhd"`` — confirm against
            the repo's ``hubconf.py``).

    Returns:
        An ``MGDPipe`` moved to CUDA when available, otherwise to CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Previously each component came from a different, mismatched repo/subfolder
    # (an X-CLIP model as the CLIP text encoder, a non-existent "vae" subfolder
    # in sd-vae-ft-mse, ...). Load them all from one SD checkpoint instead.
    scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")

    # The MGD denoising UNet is published through torch.hub; its entry point
    # selects pretrained weights by dataset name.
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        source="github",
        pretrained=True,
        dataset=dataset,
    )

    pipe = MGDPipe(
        text_encoder=text_encoder,
        vae=vae,
        # Cast the UNet to the VAE's dtype (presumably to avoid dtype mismatches
        # between the two models during decoding — confirm).
        unet=unet.to(vae.dtype),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    return pipe


pipe = load_model()
def _as_image_list(result):
    """Normalize a pipeline call result into a plain list of images.

    Handles an output object exposing ``.images``, a tuple whose first element
    is the image list, or any other iterable of images. Extending a list with
    the output object itself would iterate the object rather than its images.
    NOTE(review): assumes MGDPipe follows the diffusers output convention —
    confirm.
    """
    images = getattr(result, "images", None)
    if images is not None:
        return list(images)
    if isinstance(result, tuple):
        return list(result[0]) if result else []
    return list(result)


def generate_images(pipe, text_input=None, sketch=None):
    """Run the pipeline once, conditioned on whichever inputs were provided.

    When both a prompt and a sketch are given they are passed together in a
    single call, so the output is conditioned on both modalities at once
    (instead of two unrelated single-modality generations).

    Args:
        pipe: Callable pipeline, e.g. the object returned by ``load_model``.
        text_input: Optional text prompt; sent as a one-element ``prompt`` list.
            Empty strings are treated as "no prompt".
        sketch: Optional file path or file-like object holding the sketch; it
            is opened with PIL and converted to RGB.

    Returns:
        A list of generated images; empty (and the pipeline is not called) when
        neither input is provided.

    NOTE(review): MGD's reference pipeline also consumes a model image, mask and
    pose map — confirm MGDPipe accepts prompt/sketch on their own.
    """
    pipe_kwargs = {}
    if text_input:
        pipe_kwargs["prompt"] = [text_input]
    if sketch is not None:
        pipe_kwargs["sketch"] = Image.open(sketch).convert("RGB")
    if not pipe_kwargs:
        return []
    return _as_image_list(pipe(**pipe_kwargs))
# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
needs_sketch = option in ["Sketch", "Both"]
needs_text = option in ["Text", "Both"]

# Bind both inputs up front: only the widgets for the selected option are
# rendered, but the generation code below reads both names. Previously the
# "Text" option raised NameError on sketch_file, and "Sketch" on text_input.
sketch_file = None
text_input = None
if needs_sketch:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
if needs_text:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    # Validate every input the selected mode requires ("Both" needs both).
    if needs_sketch and not sketch_file:
        st.error("Please upload a sketch.")
    elif needs_text and not text_input:
        st.error("Please provide text input.")
    else:
        # Generate images based on user input
        with st.spinner("Generating images..."):
            # getvalue() returns the whole upload regardless of the stream's
            # current read position, unlike read().
            sketches = BytesIO(sketch_file.getvalue()) if sketch_file else None
            images = generate_images(pipe, text_input=text_input, sketch=sketches)
        # Display results
        if not images:
            st.warning("The model returned no images.")
        for i, img in enumerate(images):
            st.image(img, caption=f"Generated Image {i+1}")