Spaces:
Runtime error
Runtime error
File size: 2,382 Bytes
9fbe234 b0b9e1f 9fbe234 b0b9e1f cb5f8d1 b0b9e1f cb5f8d1 b0b9e1f 9fbe234 b0b9e1f 9fbe234 b0b9e1f cb5f8d1 9fbe234 cb5f8d1 b0b9e1f cb5f8d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import re
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed
# TODO: add a short markdown README-style introduction to the project.

# Page chrome: sidebar logo, then the title, tagline and a work-in-progress
# notice shown at the top of every page.
st.sidebar.image("assets/logo.png", use_column_width=True)
st.header("ButterflyGAN")
st.caption("This butterfly does not exist! ")
st.write("Demo prep still in progress!!")
# NOTE(review): st.experimental_singleton appeared in a newer Streamlit release
# than the v1.2.0 mentioned next to the streamlit import — confirm the deployed
# Streamlit version actually provides it.
@st.experimental_singleton
def load_model_intocache(model_name):
    """Load the GAN checkpoint ``model_name`` via ``load_model`` and cache it.

    The singleton decorator makes Streamlit reuse the returned object on later
    script reruns instead of reloading the weights for the same ``model_name``.
    """
    # Other checkpoint tried during development: 'ceyda/butterfly_512_base'
    return load_model(model_name)
@st.experimental_singleton
def load_dataset():
    """Return the reference dataset from ``get_dataset``, cached across reruns.

    The app queries it with ``get_nearest_examples`` to find real butterflies
    resembling a generated one.
    """
    return get_dataset()
# Checkpoint identifier handed to load_model (presumably a Hugging Face Hub
# repo id — confirm against demo.load_model).
model_name='ceyda/butterfly_cropped_uniq1K_512'
# Both loaders are singleton-cached, so after the first run these calls return
# the already-loaded objects. Evaluated left to right: model first, then data.
model, dataset = load_model_intocache(model_name), load_dataset()
# Sidebar navigation. Each page label is defined exactly once and reused by the
# dispatch below, so a branch can never test for a label the radio does not
# offer (the mosaic branch used to compare against "Input data mosaic", which
# was never selectable, so that page was unreachable).
PAGE_MAKE = "Make butterflies"
PAGE_WALK = "Take a latent walk"
PAGE_MOSAIC = "See the data mosaic"
screen = st.sidebar.radio("Pick a destination", [PAGE_MAKE, PAGE_WALK, PAGE_MOSAIC])

if screen == PAGE_MAKE:
    # Generated images live in session state so they survive the rerun that
    # every button click triggers.
    if "ims" not in st.session_state:
        st.session_state["ims"] = None
    ims = st.session_state["ims"]

    batch_size = 4  # number of butterflies generated per click

    def run():
        """Generate a new batch of ``batch_size`` images into session state.

        Registered as the Generate button's ``on_click`` callback, so the new
        batch is already stored when the script reruns and reads it above.
        """
        with st.spinner("Generating..."):
            st.session_state["ims"] = generate(model, batch_size)

    st.button("Generate", on_click=run)

    if ims is not None:
        cols = st.columns(batch_size)
        picks = [False] * batch_size
        for i, im in enumerate(ims):
            cols[i].image(im)
            picks[i] = cols[i].button("Find Nearest", key="pick_" + str(i))

        # Under each clicked image, show its 5 nearest dataset neighbours
        # by BEiT embedding.
        for i, pick in enumerate(picks):
            if pick:
                _scores, retrieved_examples = dataset.get_nearest_examples(
                    "beit_embeddings", embed(ims[i]), k=5
                )
                for r in retrieved_examples["image"]:
                    cols[i].image(r)

    st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
elif screen == PAGE_WALK:
    st.write("Take a latent walk")
elif screen == PAGE_MOSAIC:
    st.markdown("Todo add explanation about data")
    st.image("assets/training_data_lowres.png")
# Footer: sidebar note naming the model checkpoint that was loaded.
st.sidebar.info(f"Model {model_name} is loaded")
|