# Streamlit front-end for the TripletMix demo.
#
# Only the "Classification" tab is populated, and within it only the category
# input form is live. The inference code paths are kept as `#` comments rather
# than bare string literals: Streamlit "magic" writes any standalone string
# that is not a leading docstring to the page, which would print the disabled
# code in the app.

import traceback

import streamlit as st

# Upper bound on the number of user-supplied categories accepted by the form.
MAX_CUSTOM_CATEGORIES = 64

st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
# Shared progress bar; handed to the data loader in the disabled code below.
prog = st.progress(0.0, "Idle")
tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs([
    "Classification",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Retrieval w/ 3D",
    "Image Generation",
    "Captioning",
])


def _parse_categories(raw: str) -> list:
    """Split a comma-separated string into trimmed, non-empty category names."""
    stripped = (part.strip() for part in raw.split(","))
    # Drop blanks so empty input or a trailing comma does not count as a category.
    return [name for name in stripped if name]


def demo_classification():
    """Render the classification form in the current Streamlit container.

    Shows a text box for comma-separated custom categories and two submit
    buttons. When more than ``MAX_CUSTOM_CATEGORIES`` non-empty categories
    are entered, an error is displayed and the function returns early. The
    inference steps that would consume the form are currently disabled.
    """
    with st.form("clsform"):
        # load_data = misc_utils.input_3d_shape('cls')
        raw_cats = st.text_input("Custom Categories (64 max, separated with comma)")

        # Create both submit buttons before validating: a form must always
        # contain a submit button, otherwise the user cannot resubmit a
        # corrected category list after hitting the limit.
        # The return values are only consumed by the disabled code below.
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")

        cats = _parse_categories(raw_cats)
        if len(cats) > MAX_CUSTOM_CATEGORIES:
            st.error('Maximum 64 custom categories supported in the demo')
            return

        # --- Disabled: LVIS / custom classification ------------------------
        # NOTE(review): relies on names not imported in this file
        # (auto_submit, misc_utils, classification, model_g14, clip_model,
        # clip_prep) -- confirm they exist before re-enabling.
        #
        # if lvis_run or auto_submit("clsauto"):
        #     pc = load_data(prog)
        #     col2 = misc_utils.render_pc(pc)
        #     prog.progress(0.5, "Running Classification")
        #     pred = classification.pred_lvis_sims(model_g14, pc)
        #     with col2:
        #         for i, (cat, sim) in zip(range(5), pred.items()):
        #             st.text(cat)
        #             st.caption("Similarity %.4f" % sim)
        #     prog.progress(1.0, "Idle")
        # if custom_run:
        #     pc = load_data(prog)
        #     col2 = misc_utils.render_pc(pc)
        #     prog.progress(0.5, "Computing Category Embeddings")
        #     device = clip_model.device
        #     tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
        #     feats = clip_model.get_text_features(**tn).float().cpu()
        #     prog.progress(0.5, "Running Classification")
        #     pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
        #     with col2:
        #         for i, (cat, sim) in zip(range(5), pred.items()):
        #             st.text(cat)
        #             st.caption("Similarity %.4f" % sim)
        #     prog.progress(1.0, "Idle")

    # --- Disabled: example picker that queues an automatic submit ----------
    # if image_examples(samples_index.classification, 3, example_text="Examples (Choose one of the following 3D shapes)"):
    #     queue_auto_submit("clsauto")


try:
    with tab_cls:
        demo_classification()
    # --- Disabled: remaining tabs -------------------------------------------
    # with tab_cap:
    #     demo_captioning()
    # with tab_sd:
    #     demo_pc2img()
    # demo_retrieval()
except Exception:
    # Top-level UI boundary: surface the traceback on the page instead of
    # letting the script die. Markdown needs two trailing spaces before a
    # newline to render a hard line break, so each traceback line stays on
    # its own row.
    st.error(traceback.format_exc().replace("\n", "  \n"))