File size: 2,382 Bytes
9fbe234
b0b9e1f
9fbe234
b0b9e1f
 
 
cb5f8d1
 
 
 
 
b0b9e1f
 
cb5f8d1
b0b9e1f
 
 
 
 
 
 
9fbe234
 
 
 
 
b0b9e1f
 
9fbe234
b0b9e1f
cb5f8d1
 
 
 
9fbe234
cb5f8d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b9e1f
 
cb5f8d1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import re
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed

# TODOs
# - Add a short markdown README-style project intro


# Page chrome: logo in the sidebar, then title, tagline and status note in the main area.
with st.sidebar:
    st.image("assets/logo.png", use_column_width=True)

st.header("ButterflyGAN")
st.caption("This butterfly does not exist! ")
st.write("Demo prep still in progress!!")


@st.experimental_singleton
def load_model_intocache(model_name):
    """Return the GAN loaded by ``load_model(model_name)``, cached by Streamlit.

    ``st.experimental_singleton`` memoizes the result per ``model_name`` so the
    (presumably expensive) model load is not repeated on every script rerun.
    """
    return load_model(model_name)

@st.experimental_singleton
def load_dataset():
    """Return the dataset produced by ``get_dataset()``, cached by Streamlit.

    Wrapped in ``st.experimental_singleton`` so the dataset is built once and
    reused across script reruns instead of being reloaded on each interaction.
    """
    return get_dataset()

model_name = 'ceyda/butterfly_cropped_uniq1K_512'
model = load_model_intocache(model_name)
dataset = load_dataset()

# Screen labels are defined once and shared by the sidebar radio and the
# dispatch below. Previously the mosaic branch compared against a label the
# radio never offered ("Input data mosaic"), so that screen could never render.
SCREEN_MAKE = "Make butterflies"
SCREEN_WALK = "Take a latent walk"
SCREEN_MOSAIC = "See the data mosaic"

screen = st.sidebar.radio("Pick a destination", [SCREEN_MAKE, SCREEN_WALK, SCREEN_MOSAIC])

if screen == SCREEN_MAKE:
    # Generated images live in session state so they survive the rerun that
    # Streamlit performs after every widget interaction.
    if 'ims' not in st.session_state:
        st.session_state['ims'] = None

    ims = st.session_state["ims"]
    batch_size = 4  # number of butterflies generated per click

    def run():
        """Button callback: generate a fresh batch and store it in session state."""
        with st.spinner("Generating..."):
            st.session_state['ims'] = generate(model, batch_size)

    st.button("Generate", on_click=run)

    if ims is not None:
        # One column per generated image, each with its own "Find Nearest" button.
        cols = st.columns(batch_size)
        picks = [False] * batch_size
        for i, im in enumerate(ims):
            cols[i].image(im)
            picks[i] = cols[i].button("Find Nearest", key="pick_" + str(i))

        # For every picked butterfly, show its 5 nearest dataset images under it.
        # NOTE(review): assumes `dataset` has a nearest-neighbour index named
        # 'beit_embeddings' (get_nearest_examples) and that embed() returns a
        # query vector compatible with it — confirm against demo.get_dataset.
        for i, pick in enumerate(picks):
            if not pick:
                continue
            _scores, retrieved_examples = dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
            for r in retrieved_examples["image"]:
                cols[i].image(r)

    st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")

elif screen == SCREEN_WALK:
    st.write("Take a latent walk")

elif screen == SCREEN_MOSAIC:
    st.markdown("Todo add explanation about data")
    st.image("assets/training_data_lowres.png")


# footer stuff
st.sidebar.info(f"Model {model_name} is loaded")