import sys
import threading
import streamlit as st
import numpy
import torch
import openshape
import transformers
from PIL import Image
from huggingface_hub import HfFolder, snapshot_download
from demo_support import retrieval
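
# Load the CLIP model and processor used by all retrieval modes. A lock is
# attached to `sys` so the model is moved to the GPU by only one thread at a
# time; `half` is defined at module level below and exists before this
# function is first called.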
def load_openclip():
    sys.clip_move_lock = threading.Lock()
    clip_model, clip_prep = transformers.CLIPModel.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        low_cpu_mem_usage=True, torch_dtype=half,
        offload_state_dict=True
    ), transformers.CLIPProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
    if torch.cuda.is_available():
        with sys.clip_move_lock:
            clip_model.cuda()
    return clip_model, clip_prep
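
# Render the "Filters" expander for a retrieval form and return the chosen
# similarity threshold together with a predicate that keeps only entries
# whose tag, face count, and animation count match the user's filters.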
def retrieval_filter_expand(key):
    with st.expander("Filters"):
        sim_th = st.slider("Similarity Threshold", 0.05, 0.5, 0.1, key=key + 'rtsimth')
        tag = st.text_input("Has Tag", "", key=key + 'rthastag')
        col1, col2 = st.columns(2)
        face_min = int(col1.text_input("Face Count Min", "0", key=key + 'rtfcmin'))
        face_max = int(col2.text_input("Face Count Max", "34985808", key=key + 'rtfcmax'))
        col1, col2 = st.columns(2)
        anim_min = int(col1.text_input("Animation Count Min", "0", key=key + 'rtacmin'))
        anim_max = int(col2.text_input("Animation Count Max", "563", key=key + 'rtacmax'))
        tag_n = not bool(tag.strip())
        anim_n = not (anim_min > 0 or anim_max < 563)
        face_n = not (face_min > 0 or face_max < 34985808)
        filter_fn = lambda x: (
            (anim_n or anim_min <= x['anims'] <= anim_max)
            and (face_n or face_min <= x['faces'] <= face_max)
            and (tag_n or tag in x['tags'])
        )
    return sim_th, filter_fn
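
# Lay out retrieved entries in rows of four columns, each showing a thumbnail
# and a link to the corresponding Objaverse page.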
def retrieval_results(results):
    st.caption("Click the link to view the 3D shape")
    # ceil(len(results) / 4) rows so a trailing partial row is still shown
    for i in range((len(results) + 3) // 4):
        cols = st.columns(4)
        for j in range(4):
            idx = i * 4 + j
            if idx >= len(results):
                continue
            entry = results[idx]
            with cols[j]:
                ext_link = f"https://objaverse.allenai.org/explore/?query={entry['u']}"
                st.image(entry['img'])
                # st.markdown(f"[![thumbnail {entry['desc'].replace('\n', ' ')}]({entry['img']})]({ext_link})")
                # st.text(entry['name'])
                quote_name = entry['name'].replace('[', '\\[').replace(']', '\\]').replace('\n', ' ')
                st.markdown(f"[{quote_name}]({ext_link})")
def demo_classification():
    with st.form("clsform"):
        # load_data = misc_utils.input_3d_shape('cls')
        cats = st.text_input("Custom Categories (64 max, separated with comma)")
        cats = [a.strip() for a in cats.split(',')]
        if len(cats) > 64:
            st.error('Maximum 64 custom categories supported in the demo')
            return
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")
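
# Captioning tab: exposes only the conditioning-scale slider here; the caption
# generation step is not included in this section.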
def demo_captioning():
    with st.form("capform"):
        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
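
# Point-cloud-to-image tab: exposes only the optional text prompt here.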
def demo_pc2img():
    with st.form("sdform"):
        prompt = st.text_input("Prompt (Optional)", key='sdtprompt')
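
# Retrieval tabs. Only the text branch runs end to end: it encodes the query
# with CLIP and passes the embedding to demo_support.retrieval.retrieve. The
# point-cloud and image branches stop after updating the progress bar, and
# `utils.load_3D_shape` is assumed to come from a helper module that is not
# imported in this section.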
def demo_retrieval():
    with tab_pc:
        with st.form("rpcform"):
            k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
            pc = utils.load_3D_shape('rpcinput')
            if st.form_submit_button("Retrieve with Point Cloud"):
                prog.progress(0.49, "Computing Embeddings")
    with tab_img:
        with st.form("rimgform"):
            k = st.slider("Number of items to retrieve", 1, 100, 16, key='rimage')
            img = st.file_uploader("Upload an Image", key='rimageinput')
            if st.form_submit_button("Retrieve with Image"):
                prog.progress(0.49, "Computing Embeddings")
    with tab_text:
        with st.form("rtextform"):
            k = st.slider("Number of items to retrieve", 1, 100, 16, key='rtext')
            text = st.text_input("Input Text", key='rtextinput')
            sim_th, filter_fn = retrieval_filter_expand('text')
            if st.form_submit_button("Retrieve with Text"):
                prog.progress(0.49, "Computing Embeddings")
                device = clip_model.device
                tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
                enc = clip_model.get_text_features(**tn).float().cpu()
                prog.progress(0.7, "Running Retrieval")
                retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
                prog.progress(1.0, "Idle")
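
# Module-level UI wiring: build the page, load CLIP, and render each demo tab.
# Any exception is shown on the page instead of crashing the app.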
st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
prog = st.progress(0.0, "Idle")
tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
    "Classification",
    "Retrieval w/ 3D",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Image Generation",
    "Captioning",
])
f32 = numpy.float32
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
clip_model, clip_prep = load_openclip()
try:
    with tab_cls:
        demo_classification()
    with tab_cap:
        demo_captioning()
    with tab_sd:
        demo_pc2img()
    demo_retrieval()
except Exception:
    import traceback
    st.error(traceback.format_exc().replace("\n", "  \n"))  # two trailing spaces force Markdown line breaks