# Hugging Face Spaces status: Runtime error
import re  # NOTE(review): not referenced anywhere in the visible code

import streamlit as st  # HF spaces at v1.2.0

from demo import embed, generate, get_dataset, load_model

# TODO: add a short markdown readme / project intro.

# Page chrome shown on every screen: sidebar logo, then the title block.
st.sidebar.image("assets/logo.png", use_column_width=True)
st.header("ButterflyGAN")
st.caption("This butterfly does not exist! ")
st.write("Demo prep still in progress!!")
@st.experimental_singleton
def load_model_intocache(model_name):
    """Load the GAN identified by *model_name*, reusing it across reruns.

    Streamlit re-executes this whole script on every widget interaction, and
    this function is called at module level. Without a cache, the model would
    be reloaded on every click. ``st.experimental_singleton`` (available in the
    pinned Streamlit 1.2.0) keeps one shared object per distinct
    ``model_name`` for the lifetime of the server process.

    Args:
        model_name: checkpoint identifier passed straight to ``demo.load_model``.

    Returns:
        Whatever ``demo.load_model`` returns. The script reads
        ``.latent_dim`` and ``.image_size`` from it and passes it to
        ``generate``.
    """
    # model_name='ceyda/butterfly_512_base'
    return load_model(model_name)
@st.experimental_singleton
def load_dataset():
    """Return the dataset used for nearest-neighbour lookups, loaded once.

    This is called at module level, so without a cache it would run again on
    every Streamlit rerun. In the visible code the returned object is only
    queried (``get_nearest_examples``), so sharing one instance is safe.
    NOTE(review): assumes nothing in ``demo`` mutates the dataset after
    loading. Confirm before relying on the shared instance.
    """
    return get_dataset()
# Checkpoint served by this demo. The footer also displays this name.
model_name = "ceyda/butterfly_cropped_uniq1K_512"

# Resources shared by the screens below: the generator and the dataset
# that backs the "Find Nearest" lookups.
model = load_model_intocache(model_name)
dataset = load_dataset()
# Sidebar navigation. Each label is defined once and reused by the dispatch
# below, so a renamed option cannot silently orphan its branch. Previously the
# mosaic branch compared against "Input data mosaic", a value the radio never
# returns, so that screen was unreachable.
SCREEN_MAKE = "Make butterflies"
SCREEN_WALK = "Take a latent walk"
SCREEN_MOSAIC = "See the data mosaic"
screen = st.sidebar.radio("Pick a destination", [SCREEN_MAKE, SCREEN_WALK, SCREEN_MOSAIC])

if screen == SCREEN_MAKE:
    # Generated images are kept in session_state so they survive the rerun
    # caused by any widget click. None means "nothing generated yet".
    if "ims" not in st.session_state:
        st.session_state["ims"] = None
    ims = st.session_state["ims"]
    batch_size = 4  # generate 4 butterflies

    def run():
        """Generate a new batch and store it in session_state."""
        with st.spinner("Generating..."):
            st.session_state["ims"] = generate(model, batch_size)

    # Streamlit runs on_click callbacks before re-executing the script, so the
    # rerun triggered by this click already reads the new batch above.
    st.button("Generate", on_click=run)

    if ims is not None:
        cols = st.columns(batch_size)
        picks = []
        for i, im in enumerate(ims):
            cols[i].image(im)
            # st.button returns True only on the rerun caused by its own click.
            picks.append(cols[i].button("Find Nearest", key=f"pick_{i}"))

        # Under the clicked sample, show the 5 closest dataset images by
        # their 'beit_embeddings' vectors.
        for i, picked in enumerate(picks):
            if picked:
                _scores, retrieved = dataset.get_nearest_examples(
                    "beit_embeddings", embed(ims[i]), k=5
                )
                for r in retrieved["image"]:
                    cols[i].image(r)

    # NOTE(review): the original nesting of this line was lost in the source;
    # it is placed at screen level (shown even before generating). Confirm.
    st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
elif screen == SCREEN_WALK:
    st.write("Take a latent walk")
elif screen == SCREEN_MOSAIC:
    st.markdown("Todo add explanation about data")
    st.image("assets/training_data_lowres.png")
# Footer: tell the user which checkpoint this page is serving.
st.sidebar.info(f"Model {model_name} is loaded")