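"""Gradio demo: text-driven 3D model retrieval on the LD-T3D benchmark.

A textual query is encoded with Uni3D and matched, via dot product over
L2-normalized embeddings, against precomputed gallery embeddings of the
candidate Objaverse models in one or more modalities: text captions,
rendered images from several viewing angles, and the 3D objects themselves.
"""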
import os
import gradio as gr
import numpy as np
import torch
import functools
from datasets import load_dataset
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder

# os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
# os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'

MAX_BATCH_SIZE = 16
MAX_QUEUE_SIZE = 10
MAX_K_RETRIEVAL = 20
cache_dir = "./.cache"

encoder = Uni3dEmbeddingEncoder(cache_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
source_id_list = torch.load("data/source_id_list.pt")
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}
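# Rendered preview images (diagonal-above view) for each gallery object, shown as retrieval results.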
dataset = load_dataset("VAST-AI/LD-T3D", name="rendered_imgs_diag_above", split="base", cache_dir=cache_dir)

@functools.lru_cache()
def get_embedding(option, modality, angle=None):
    # lru_cache keeps loaded tensors in memory so repeated queries skip disk I/O
    suffix = f"_{angle}" if angle is not None else ""
    save_path = f"data/objaverse_{option}_{modality}{suffix}_embeddings.pt"
    if not os.path.exists(save_path):
        raise gr.Error(f"Embedding file not found: {save_path}")
    # map_location lets embeddings saved on GPU load on CPU-only machines
    return torch.load(save_path, map_location=device)

def predict(xb, xq, top_k):
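    # Embeddings are L2-normalized upstream, so the dot product below ranks
    # gallery items by cosine similarity to the query.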
    xb = xb.to(xq.device)
    sim = xq @ xb.T # (nq, nb)
    _, indices = sim.topk(k=top_k, largest=True)
    return indices

def get_image(index):
    return dataset[index]["image"]

def retrieve_3D_models(textual_query, top_k, modality_list):
    if not textual_query.strip():
        raise gr.Error("Please enter a textual query")
    if len(textual_query.split()) > 20:
        gr.Warning("Retrieval results may be inaccurate for long textual queries")
    if len(modality_list) == 0:
        raise gr.Error("Please select at least one modality")
    
    def _retrieve_3D_models(query, top_k, modals: list):
        option = "uni3d"
        op = "add"  # fusion op for multi-modality gallery embeddings: "add" or "concat"
        modals = list(modals)  # copy so the caller's checkbox list is not mutated
        is_text = "text" in modals
        is_3D = "3D" in modals
        if is_text:
            modals.remove("text")
        if is_3D:
            modals.remove("3D")
        angles = modals  # remaining entries are rendered-image view angles

        # get base embeddings
        embeddings = []
        if is_text:
            embeddings.append(get_embedding(option, "text"))
        if len(angles) > 0:
            for angle in angles:
                embeddings.append(get_embedding(option, "image", angle=angle))
        if is_3D:
            embeddings.append(get_embedding(option, "3D"))
            
        # fuse base embeddings across the selected modalities
        if len(embeddings) > 1:
            if op == "concat":
                embeddings = torch.cat(embeddings, dim=-1)
            elif op == "add":
                embeddings = sum(embeddings)
            else:
                raise ValueError(f"Unsupported operation: {op}")
            embeddings /= embeddings.norm(dim=-1, keepdim=True)  # re-normalize after fusion
        else:
            embeddings = embeddings[0]

        # encode query embeddings
        xq = encoder.encode_query(query)
        if op == "concat":
            xq = xq.repeat(1, embeddings.shape[-1] // xq.shape[-1]) # repeat to be aligned with the xb
            xq /= xq.norm(dim=-1, keepdim=True)
        
        pred_ind_list = predict(embeddings, xq, top_k)
        return pred_ind_list[0].cpu().tolist() # we have only one query

    indices = _retrieve_3D_models(textual_query, top_k, modality_list)
    return [get_image(index) for index in indices]

def launch():
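    # Build the Gradio Blocks UI: a query box and modality checkboxes feed the
    # search, with a slider for top-k and an image gallery for the results.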
    with gr.Blocks() as demo:
        with gr.Row():
            textual_query = gr.Textbox(label="Textual Query", autofocus=True,
                                       placeholder="A chair with a wooden frame and a cushioned seat")
            modality_list = gr.CheckboxGroup(label="Modality List", value=[],
                                             choices=["text", "front", "back", "left", "right", "above", 
                                                      "below", "diag_above", "diag_below", "3D"])
        with gr.Row():
            top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result", 
                              value=5, scale=2)
            run = gr.Button("Search", scale=1)
            clear_button = gr.ClearButton(scale=1)
        with gr.Row():
            output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")
        run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output, 
                #   batch=True, max_batch_size=MAX_BATCH_SIZE
                  )
        clear_button.click(lambda: ["", 5, [], []], outputs=[textual_query, top_k, modality_list, output])
        examples = gr.Examples(examples=[["An ice cream with a cherry on top", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
                                         ["A mid-age castle", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
                                         ["A coke", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]]],
                            inputs=[textual_query, top_k, modality_list],
                            # cache_examples=True,
                            outputs=output,
                            fn=retrieve_3D_models)
        
    demo.queue(max_size=MAX_QUEUE_SIZE)

    # os.environ.pop('HTTP_PROXY')
    # os.environ.pop('HTTPS_PROXY')

    demo.launch(server_name="0.0.0.0")

if __name__ == "__main__":
    launch()
    # print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))