import sys
import threading
import traceback
import urllib.parse

import streamlit as st
import numpy
import torch
import openshape
import transformers
from PIL import Image
from huggingface_hub import HfFolder, snapshot_download

from demo_support import retrieval

# Checkpoint used for both the CLIP model and its processor.
CLIP_CHECKPOINT = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# Largest face / animation counts offered by the retrieval filters. A range of
# [0, *_MAX] (the default) means "do not filter on this attribute".
FACE_COUNT_MAX = 34985808
ANIM_COUNT_MAX = 563

# Maximum number of user-supplied classification categories.
MAX_CUSTOM_CATEGORIES = 64


@st.cache_resource
def load_openclip():
    """Load the CLIP model and processor once per Streamlit server process.

    Reads the module-level ``half`` dtype, which must be defined before this
    function is first called.

    Returns:
        tuple: ``(clip_model, clip_prep)`` -- a ``transformers.CLIPModel``
        (moved to CUDA when available) and the matching
        ``transformers.CLIPProcessor``.
    """
    # Stored on ``sys`` so the lock is reachable from any module; it guards
    # the move of the model onto the GPU.
    sys.clip_move_lock = threading.Lock()
    clip_model = transformers.CLIPModel.from_pretrained(
        CLIP_CHECKPOINT,
        low_cpu_mem_usage=True,
        torch_dtype=half,
        offload_state_dict=True,
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
    if torch.cuda.is_available():
        with sys.clip_move_lock:
            clip_model.cuda()
    return clip_model, clip_prep


def _int_input(col, label, default, key):
    """Render a text box in ``col`` and parse its contents as an ``int``.

    Falls back to ``default`` (with a warning) when the text is not an
    integer, instead of letting ``ValueError`` abort the rest of the page.
    """
    raw = col.text_input(label, str(default), key=key)
    try:
        return int(raw)
    except ValueError:
        st.warning(f"{label}: {raw!r} is not an integer; using {default}.")
        return default


def retrieval_filter_expand(key):
    """Render the retrieval filter controls and build the matching predicate.

    Args:
        key: Prefix that makes the widget keys unique per retrieval tab.

    Returns:
        tuple: ``(sim_th, filter_fn)`` -- the similarity threshold from the
        slider, and a predicate over one retrieved entry that reads its
        ``'anims'``, ``'faces'`` and ``'tags'`` fields.
    """
    with st.expander("Filters"):
        sim_th = st.slider("Similarity Threshold", 0.05, 0.5, 0.1, key=key + 'rtsimth')
        # Strip so stray whitespace neither enables the filter nor breaks matching.
        tag = st.text_input("Has Tag", "", key=key + 'rthastag').strip()
        col1, col2 = st.columns(2)
        face_min = _int_input(col1, "Face Count Min", 0, key + 'rtfcmin')
        face_max = _int_input(col2, "Face Count Max", FACE_COUNT_MAX, key + 'rtfcmax')
        col1, col2 = st.columns(2)
        anim_min = _int_input(col1, "Animation Count Min", 0, key + 'rtacmin')
        anim_max = _int_input(col2, "Animation Count Max", ANIM_COUNT_MAX, key + 'rtacmax')

    # While a filter's bounds equal the full default range it is skipped
    # entirely, without reading the corresponding field of the entry.
    use_tag = bool(tag)
    use_anim = anim_min > 0 or anim_max < ANIM_COUNT_MAX
    use_face = face_min > 0 or face_max < FACE_COUNT_MAX

    def filter_fn(x):
        return (
            (not use_anim or anim_min <= x['anims'] <= anim_max)
            and (not use_face or face_min <= x['faces'] <= face_max)
            and (not use_tag or tag in x['tags'])
        )

    return sim_th, filter_fn


def _escape_link_text(text):
    """Escape ``text`` for use inside Markdown link brackets ``[...]``."""
    # Backslashes first, so the escapes added afterwards stay intact.
    return (
        text.replace('\\', '\\\\')
        .replace('[', '\\[')
        .replace(']', '\\]')
        .replace('\n', ' ')
    )


def retrieval_results(results, ncols=4):
    """Show retrieved shapes as a grid of thumbnails with Objaverse links.

    Args:
        results: Indexable sequence of entries; each entry must provide
            ``'u'``, ``'img'`` and ``'name'`` fields.
        ncols: Number of grid columns (default 4).
    """
    st.caption("Click the link to view the 3D shape")
    n = len(results)
    # Step by whole rows so a trailing partial row (n % ncols entries) is
    # still rendered.
    for row_start in range(0, n, ncols):
        cols = st.columns(ncols)
        for offset, col in enumerate(cols):
            idx = row_start + offset
            if idx >= n:
                break
            entry = results[idx]
            # Percent-encode so characters such as ')' or spaces cannot break
            # the query string or the Markdown link below.
            query = urllib.parse.quote(str(entry['u']), safe='')
            ext_link = f"https://objaverse.allenai.org/explore/?query={query}"
            with col:
                st.image(entry['img'])
                st.text(entry['name'])
                st.markdown(f"[{_escape_link_text(entry['name'])}]({ext_link})")


def demo_classification():
    """Render the classification form (currently input validation only)."""
    with st.form("clsform"):
        raw_cats = st.text_input("Custom Categories (64 max, separated with comma)")
        # Drop empty items so "" or trailing commas do not count as categories.
        cats = [c.strip() for c in raw_cats.split(',') if c.strip()]
        # Create the submit buttons before any early return: a form without a
        # submit button can never be resubmitted, leaving the user stuck.
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")
        if len(cats) > MAX_CUSTOM_CATEGORIES:
            st.error('Maximum 64 custom categories supported in the demo')
            return
    # NOTE(review): lvis_run / custom_run are never acted on and no 3D shape
    # input is collected, so classification is not implemented in this file.


def demo_captioning():
    """Render the captioning form.

    NOTE(review): Streamlit requires every ``st.form`` to contain a
    ``st.form_submit_button``; this form has none and ``cond_scale`` is never
    used, so captioning appears unimplemented here.
    """
    with st.form("capform"):
        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')


def demo_pc2img():
    """Render the point-cloud-to-image form.

    NOTE(review): this form has no ``st.form_submit_button`` (required by
    Streamlit) and ``prompt`` is never used, so image generation appears
    unimplemented here.
    """
    with st.form("sdform"):
        prompt = st.text_input("Prompt (Optional)", key='sdtprompt')


def _show_traceback():
    """Render the exception currently being handled as an error box."""
    # Markdown only starts a new line after two trailing spaces.
    st.error(traceback.format_exc().replace("\n", "  \n"))


def _retrieval_by_point_cloud():
    """Point-cloud retrieval form (embedding/retrieval not implemented)."""
    with st.form("rpcform"):
        k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
        # NOTE(review): ``utils`` is not imported anywhere in this file, so this
        # line raises NameError as written -- confirm the intended module.
        pc = utils.load_3D_shape('rpcinput')
        if st.form_submit_button("Retrieve with Point Cloud"):
            prog.progress(0.49, "Computing Embeddings")
            # NOTE(review): processing stops here; ``k`` and ``pc`` are unused
            # and the progress bar is left at 0.49.


def _retrieval_by_image():
    """Image retrieval form (embedding/retrieval not implemented)."""
    with st.form("rimgform"):
        k = st.slider("Number of items to retrieve", 1, 100, 16, key='rimage')
        img = st.file_uploader("Upload an Image", key='rimageinput')
        if st.form_submit_button("Retrieve with Image"):
            prog.progress(0.49, "Computing Embeddings")
            # NOTE(review): processing stops here; ``k`` and ``img`` are unused
            # and the progress bar is left at 0.49.


def _retrieval_by_text():
    """Text retrieval: encode the query with CLIP and show the top-k matches."""
    with st.form("rtextform"):
        k = st.slider("Number of items to retrieve", 1, 100, 16, key='rtext')
        text = st.text_input("Input Text", key='rtextinput')
        sim_th, filter_fn = retrieval_filter_expand('text')
        if st.form_submit_button("Retrieve with Text"):
            prog.progress(0.49, "Computing Embeddings")
            device = clip_model.device
            tn = clip_prep(
                text=[text], return_tensors='pt', truncation=True, max_length=76
            ).to(device)
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                enc = clip_model.get_text_features(**tn).float().cpu()
            prog.progress(0.7, "Running Retrieval")
            retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
            prog.progress(1.0, "Idle")


def demo_retrieval():
    """Render the three retrieval tabs (point cloud, image, text).

    Each tab is rendered independently: an exception raised while rendering
    one tab is shown inside that tab and does not stop the others.
    """
    for tab, render in (
        (tab_pc, _retrieval_by_point_cloud),
        (tab_img, _retrieval_by_image),
        (tab_text, _retrieval_by_text),
    ):
        with tab:
            try:
                render()
            except Exception:
                _show_traceback()
def _run_demo(demo):
    """Run one demo section, showing any exception's traceback in place.

    Isolating each section means a failure in one demo no longer prevents the
    remaining demos from rendering.
    """
    try:
        demo()
    except Exception:
        import traceback
        # Markdown only starts a new line after two trailing spaces.
        st.error(traceback.format_exc().replace("\n", "  \n"))


st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
prog = st.progress(0.0, "Idle")
tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
    "Classification",
    "Retrieval w/ 3D",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Image Generation",
    "Captioning",
])

f32 = numpy.float32
# Model dtype read by load_openclip(): fp16 with CUDA, bf16 otherwise.
# Must be defined before load_openclip() is called.
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
clip_model, clip_prep = load_openclip()

with tab_cls:
    _run_demo(demo_classification)
with tab_cap:
    _run_demo(demo_captioning)
with tab_sd:
    _run_demo(demo_pc2img)
# demo_retrieval enters its own tabs (tab_pc, tab_img, tab_text).
_run_demo(demo_retrieval)