File size: 1,871 Bytes
9145aca
 
64fa430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26b3975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64fa430
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import streamlit as st

# Page header: the demo title and a hint about cloning the space.
st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")

# Progress bar created at 0.0 with the label "Idle".
prog = st.progress(0.0, text="Idle")

# Tab labels, in display order. The unpacking below relies on this order:
# each name on the left receives the tab for the label at the same position.
_TAB_LABELS = [
    "Classification",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Retrieval w/ 3D",
    "Image Generation",
    "Captioning",
]
tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs(_TAB_LABELS)


def demo_classification():
    """Render the classification form ("clsform").

    The form holds one text input for comma-separated custom categories and
    two submit buttons (LVIS categories / custom categories). If more than
    64 non-empty categories are entered, an error is shown and the function
    returns early.

    Returns:
        None.
    """
    with st.form("clsform"):
        #load_data = misc_utils.input_3d_shape('cls')
        raw_cats = st.text_input("Custom Categories (64 max, separated with comma)")

        # Create the submit buttons BEFORE validating. Returning early ahead of
        # them would leave the form without any submit button on that rerun,
        # so the user could not correct the input and resubmit.
        # Nothing in this function acts on the button states yet.
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")

        # Drop empty entries so that an empty input, trailing commas, or
        # doubled commas do not count towards the 64-category limit.
        cats = [name for name in (part.strip() for part in raw_cats.split(',')) if name]
        if len(cats) > 64:
            st.error('Maximum 64 custom categories supported in the demo')
            return


def demo_captioning():
    """Render the captioning form ("capform") with its conditioning-scale slider.

    The slider ranges from 0.0 to 4.0 in steps of 0.1 and starts at 2.0.
    """
    # NOTE(review): this form declares no st.form_submit_button; Streamlit
    # forms expect one - confirm whether it was dropped from this version.
    with st.form("capform"):
        conditioning_scale = st.slider(
            'Conditioning Scale',
            min_value=0.0,
            max_value=4.0,
            value=2.0,
            step=0.1,
            key='capcondscl',
        )

def demo_pc2img():
    """Render the image-generation form ("sdform") with an optional prompt input."""
    # NOTE(review): this form declares no st.form_submit_button; Streamlit
    # forms expect one - confirm whether it was dropped from this version.
    generation_form = st.form("sdform")
    with generation_form:
        optional_prompt = st.text_input(
            "Prompt (Optional)",
            key='sdtprompt',
        )

def demo_retrieval():
    """Render the retrieval widgets in the 3D, image, and text tabs.

    Each tab gets a "Shapes to Retrieve" slider (1 to 100, starting at 16).
    The 3D tab places it directly in the tab; the image and text tabs wrap it
    in the forms "rimgform" and "rtextform". The text form also has a text
    input for the query.
    """
    def shapes_slider(widget_key):
        # Same slider in every tab; only the widget key differs.
        return st.slider(
            "Shapes to Retrieve",
            min_value=1,
            max_value=100,
            value=16,
            key=widget_key,
        )

    # NOTE(review): neither form below declares a st.form_submit_button;
    # Streamlit forms expect one - confirm whether they were dropped.
    with tab_pc:
        top_k_pc = shapes_slider('rpc')

    with tab_img, st.form("rimgform"):
        top_k_image = shapes_slider('rimage')

    with tab_text, st.form("rtextform"):
        top_k_text = shapes_slider('rtext')
        query_text = st.text_input("Input Text", key="inputrtext")
# Top-level driver: render every demo and show any failure on the page
# itself rather than letting the exception propagate.
try:
    # Demos that are rendered inside a tab chosen here, in this order.
    for _tab, _render in (
        (tab_cls, demo_classification),
        (tab_cap, demo_captioning),
        (tab_sd, demo_pc2img),
    ):
        with _tab:
            _render()
    # demo_retrieval enters its own tabs internally, so it is called bare.
    demo_retrieval()
except Exception:
    import traceback

    trace_text = traceback.format_exc()
    # Two spaces are inserted before each newline, presumably so that every
    # traceback line renders as its own line in the Markdown error box.
    st.error(trace_text.replace("\n", "  \n"))