# app.py (commit 0185ba1, "Update app.py" by treasuraid)
import diffusers
import torch
import os
import time
import accelerate
import streamlit as st
from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
# Base model id passed to DiffusionPipeline.from_pretrained.
MODEL_REPO = "OFA-Sys/small-stable-diffusion-v0"
# Directory holding the LoRA attention weights loaded into the pipeline's UNet.
LoRa_DIR = "weights"
# Dataset the LoRA weights were fine-tuned on (shown in the page header).
DATASET_REPO = "VESSL/Bored_Ape_NFT_text"
# Example output image displayed above the prompt form.
SAMPLE_IMAGE = "weights/Sample.png"
# @st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy()})
def load_pipeline_w_lora(model_repo=None, lora_dir=None):
    """Load the base diffusion pipeline and attach the LoRA attention weights.

    Args:
        model_repo: Model id or path passed to
            ``DiffusionPipeline.from_pretrained``. Defaults to ``MODEL_REPO``.
        lora_dir: Directory of LoRA attention-processor weights passed to
            ``pipeline.unet.load_attn_procs``. Defaults to ``LoRa_DIR``.

    Returns:
        The loaded pipeline with float32 weights and its progress bar
        disabled. The pipeline is not moved to any device here; the caller
        is responsible for that.
    """
    # Resolve defaults at call time so the module constants are only looked
    # up when the function actually runs (same as the original behaviour).
    if model_repo is None:
        model_repo = MODEL_REPO
    if lora_dir is None:
        lora_dir = LoRa_DIR

    pipeline = DiffusionPipeline.from_pretrained(
        model_repo,
        revision=None,
        torch_dtype=torch.float32,
    )

    # Load the LoRA attention-layer weights into the UNet's attention layers.
    print('LoRa layers loading...')
    pipeline.unet.load_attn_procs(lora_dir)
    print('LoRa layers loaded')

    # Silence diffusers' per-step console progress bar; the app shows its own
    # per-image progress with stqdm instead.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn`` and measure how long the call took.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(elapsed, output)`` tuple. ``elapsed`` is the call duration in
        seconds, formatted as a string with two decimals (e.g. ``'1.23'``).
        ``output`` is whatever ``fn`` returned.
    """
    # perf_counter is monotonic and high-resolution, so the measured duration
    # is not skewed by system clock adjustments the way time.time() can be.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    duration = time.perf_counter() - start
    return f'{duration:.2f}', output
# Choose the compute device first; this has no UI side effects.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Page header.
st.title("BAYC Text to IMAGE generator")
st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

# Load (and time) the LoRA-patched pipeline, then move it onto the device.
# NOTE(review): the loader is not cached, so under Streamlit's rerun model
# this load repeats on every script rerun (e.g. each form submission).
st.write("Loading models...")
elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
st.write(f"Model is loaded in {elapsed} seconds!")
pipeline = pipeline.to(device)

# Show a reference output so users can see what the model produces.
sample = Image.open(SAMPLE_IMAGE)
st.image(
    sample,
    caption="Example image with prompt <An ape with solid gold fur and beanie>",
)
# Prompt form: inputs are submitted together when "Submit" is pressed, and
# clear_on_submit resets the widgets to their defaults afterwards.
with st.form(key="information", clear_on_submit=True):
    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=10)
    seed = st.number_input(label="Seed for images", min_value=1, max_value=10000)
    submitted = st.form_submit_button(label="Submit")

if submitted:
    noun = "image" if num_images == 1 else "images"
    st.write(f"Generating {num_images} BAYC {noun} with prompt <{prompt}>...")
    # A single generator is seeded once and reused for every call. Each call
    # advances its state, so the images differ from one another, while the
    # whole batch is reproducible for a given seed (on the same device).
    generator = torch.Generator(device=device).manual_seed(seed)
    # One pipeline call per image so stqdm can report per-image progress.
    images = [
        pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in stqdm(range(num_images))
    ]
    st.write("Done!")
    caption = f"Generated Images with <{prompt}>"
    st.image(images, width=300, caption=[caption] * len(images))