File size: 2,607 Bytes
9145aca
 
64fa430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st

# Page header and shared widgets. Streamlit renders elements in the order they
# are called, so the sequence of calls below fixes the page layout.
st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
# Module-level progress bar, created at 0.0 with the label "Idle".
prog = st.progress(0.0, "Idle")
# One tab container per demo. At present only `tab_cls` is populated by the
# top-level runner; the other tabs belong to demos that are currently disabled.
tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs([
    "Classification",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Retrieval w/ 3D",
    "Image Generation",
    "Captioning",
])


def demo_classification():
    """Render the 3D-shape classification form in the current container.

    The form collects a comma-separated list of custom category names and
    offers two submit buttons: one for classification against the LVIS
    category set and one for the user's custom categories. The inference code
    that consumes the buttons is currently disabled and kept below as
    comments, so for now a submission only re-validates the category list.

    Returns:
        None. Widgets are rendered as a side effect.
    """
    with st.form("clsform"):
        # load_data = misc_utils.input_3d_shape('cls')
        raw_cats = st.text_input("Custom Categories (64 max, separated with comma)")
        # Drop blank entries so that "" or "a,,b," does not produce empty
        # category names (and does not count them toward the limit).
        cats = [a.strip() for a in raw_cats.split(',') if a.strip()]
        # Create the submit buttons before any early return: Streamlit requires
        # every form to contain a submit button, and without one the user could
        # not resubmit the form to correct an over-long category list.
        # The results are only consumed by the disabled inference code below.
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")
        if len(cats) > 64:
            st.error('Maximum 64 custom categories supported in the demo')
            return
        # Disabled inference code. It is kept as `#` comments rather than a bare
        # triple-quoted string, because Streamlit "magic" renders non-docstring
        # string literals on the page.
        # if lvis_run or auto_submit("clsauto"):
        #     pc = load_data(prog)
        #     col2 = misc_utils.render_pc(pc)
        #     prog.progress(0.5, "Running Classification")
        #     pred = classification.pred_lvis_sims(model_g14, pc)
        #     with col2:
        #         for i, (cat, sim) in zip(range(5), pred.items()):
        #             st.text(cat)
        #             st.caption("Similarity %.4f" % sim)
        #     prog.progress(1.0, "Idle")
        # if custom_run:
        #     pc = load_data(prog)
        #     col2 = misc_utils.render_pc(pc)
        #     prog.progress(0.5, "Computing Category Embeddings")
        #     device = clip_model.device
        #     tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
        #     feats = clip_model.get_text_features(**tn).float().cpu()
        #     prog.progress(0.5, "Running Classification")
        #     pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
        #     with col2:
        #         for i, (cat, sim) in zip(range(5), pred.items()):
        #             st.text(cat)
        #             st.caption("Similarity %.4f" % sim)
        #     prog.progress(1.0, "Idle")
    # Disabled example picker (same reason for `#` comments as above):
    # if image_examples(samples_index.classification, 3, example_text="Examples (Choose one of the following 3D shapes)"):
    #     queue_auto_submit("clsauto")






try:
    with tab_cls:
        demo_classification()
    # The remaining demos are disabled for now. They are kept as `#` comments
    # rather than a bare triple-quoted string, because Streamlit "magic" would
    # otherwise render that string as text on the page.
    # with tab_cap:
    #     demo_captioning()
    # with tab_sd:
    #     demo_pc2img()
    # demo_retrieval()
except Exception:
    # Top-level boundary: show any failure in the app instead of a broken page.
    import traceback
    # st.error renders Markdown; two trailing spaces before each newline force
    # a hard line break so the traceback keeps its line structure.
    st.error(traceback.format_exc().replace("\n", "  \n"))