Spaces:
Runtime error
Runtime error
File size: 2,416 Bytes
e83adaa dc6157e e83adaa d1ca92a e83adaa be419c1 e83adaa be419c1 e83adaa fc3e1a1 e83adaa 1e6b2bf e83adaa 2f8a733 e83adaa 2f8a733 e83adaa 7f9ce57 e83adaa 7f9ce57 e83adaa 2f8a733 e83adaa d1ca92a e83adaa 2f8a733 e83adaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import diffusers
import torch
import os
import time
import accelerate
import streamlit as st
from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
# Identifier handed to ``from_pretrained`` for both the UNet and the pipeline.
MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
# Location passed to ``unet.load_attn_procs`` to load the LoRA attention weights.
LoRa_DIR = 'weights'
# Dataset name; only interpolated into the descriptive text shown on the page.
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
# Path of the example image opened with PIL and displayed on the page.
SAMPLE_IMAGE = 'weights/Sample.png'
def load_pipeline_w_lora():
    """Build the diffusion pipeline around a UNet carrying the LoRA weights.

    The UNet is loaded from the ``unet`` subfolder of ``MODEL_REPO``, the
    attention processors found in ``LoRa_DIR`` are loaded into it, and the
    resulting UNet is injected into a ``DiffusionPipeline`` created from the
    same repo in float32. The pipeline's own progress bar is disabled.

    Returns:
        The configured pipeline object.
    """
    # Pretrained UNet pulled from the model repo.
    lora_unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO,
        subfolder="unet",
        revision=None,
    )

    # Attach the LoRA attention-layer weights to the UNet.
    lora_unet.load_attn_procs(LoRa_DIR)
    print('loaded')

    # Assemble the full pipeline, swapping in the LoRA-equipped UNet.
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=lora_unet,
        revision=None,
        torch_dtype=torch.float32,
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and report how long the call took.

    Args:
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        An ``(elapsed, output)`` tuple: ``elapsed`` is the duration in
        seconds formatted as a string with two decimal places, and
        ``output`` is whatever ``fn`` returned.
    """
    # perf_counter is monotonic, so a system clock adjustment during the call
    # cannot skew the measurement or make it negative (time.time() can).
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return f'{elapsed:.2f}', output
def _parse_int(raw):
    """Return ``int(raw)``, or ``None`` when *raw* is not a valid integer."""
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


def main():
    """Render the page and generate images once all inputs are provided.

    Loads the pipeline, moves it to CUDA when available (CPU otherwise),
    shows the sample image, then reads a prompt, an image count and a seed
    from three text inputs. Nothing is generated until all three fields are
    filled in; non-numeric or non-positive counts are reported with
    ``st.error`` instead of raising.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # NOTE(review): Streamlit presumably re-runs this script on each widget
    # interaction, so the pipeline is rebuilt every time — consider caching
    # the loader if the installed Streamlit version offers a resource cache.
    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    pipeline = pipeline.to(device)
    st.write(f"Model is loaded in {elapsed} seconds!")

    st.title("BAYC Text to IMAGE generator")
    st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")
    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")

    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images_raw = st.text_input(label="Number of images to generate")
    seed_raw = st.text_input(label="Seed for images")

    # Untouched text inputs come back empty; calling int() on them directly
    # would raise ValueError before the user has typed anything.
    if not (prompt and num_images_raw and seed_raw):
        return

    num_images = _parse_int(num_images_raw)
    seed = _parse_int(seed_raw)
    if num_images is None or seed is None:
        st.error("Number of images and seed must be whole numbers.")
        return
    if num_images < 1:
        st.error("Number of images must be at least 1.")
        return

    # The seed is validated by parsing, not truthiness, so a seed of 0 works.
    st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")
    generator = torch.Generator(device=device).manual_seed(seed)
    images = [
        pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in stqdm(range(num_images))
    ]
    st.write("Done!")

    # st.image expects one caption per image when given a list of images.
    caption = f"Generated Images with <{prompt}>"
    st.image(images, width=150, caption=[caption] * len(images))
# Build the page only when this file is executed as the main module,
# not when it is imported.
if __name__ == '__main__':
    main()
|