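"""Scene-level mesh inference script.

For each object super-category present in a processed scene, load the matching
autoencoder (AE) and finetuned triplane diffusion model (DM) checkpoints,
sample latents conditioned on the scene's object images, decode a scalar field
on a dense grid, extract meshes with marching cubes, and export them as .ply
files in world coordinates.
"""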
import os
import sys

sys.path.append("..")

import numpy as np
import torch
from torch.utils.data import DataLoader
import mcubes
import trimesh
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from configs.config_utils import CONFIG
from models import get_model
from simple_dataset import InTheWild_Dataset, classname_remap

def image_transform(n_px):
    """CLIP-style image preprocessing: resize, center crop, and normalization
    with CLIP's statistics."""
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])

MAX_IMG_LENGTH = 5  # take up to 5 images as input per object

# Category-specific checkpoints: one autoencoder (AE) and one finetuned
# triplane diffusion model (DM) per super-category.
ae_paths = {
    "chair": "../checkpoint/ae/chair/best-checkpoint.pth",
    "table": "../checkpoint/ae/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/ae/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/ae/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/ae/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/ae/bed/best-checkpoint.pth",
}
dm_paths = {
    "chair": "../checkpoint/finetune_dm/chair/best-checkpoint.pth",
    "table": "../checkpoint/finetune_dm/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/finetune_dm/bed/best-checkpoint.pth",
}

def inference(ae_model, dm_model, data_batch, device, reso=256):
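    """Reconstruct meshes for one data batch.

    The diffusion model samples latents conditioned on `data_batch`, the
    autoencoder decodes them into a scalar field over a dense query grid, and
    marching cubes extracts a mesh that is then mapped to world coordinates.

    Returns a dict mapping model_id to a trimesh.Trimesh in world coordinates.
    """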
    # Build a dense (reso + 1)^3 query grid covering [-1.1, 1.1]^3.
    density = reso
    gap = 2.2 / density
    x = np.linspace(-1.1, 1.1, density + 1)
    y = np.linspace(-1.1, 1.1, density + 1)
    z = np.linspace(-1.1, 1.1, density + 1)
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32))
    grid = grid.view(3, -1).transpose(0, 1)[None].to(device, non_blocking=True)
    with torch.no_grad():
        # Sample latent features conditioned on the batch, then upsample the
        # feature planes 2x before decoding.
        sample_input = dm_model.prepare_sample_data(data_batch)
        sampled_array = dm_model.sample(sample_input, num_steps=36).float()
        sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")

    model_ids = data_batch['model_id']
    tran_mats = data_batch['tran_mat']

    output_meshes = {}

    for j in range(sampled_array.shape[0]):
        # Decode the field in chunks of 128^3 query points to bound memory use.
        grid_list = torch.split(grid, 128 ** 3, dim=1)
        output_list = []
        with torch.no_grad():
            for sub_grid in grid_list:
                output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
        output = torch.cat(output_list, dim=1)
        logits = output[0].detach()  # batch dimension is 1 (decoded from sampled_array[j:j+1])

        volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
        verts, faces = mcubes.marching_cubes(volume, 0)

        # Map vertices from grid-index space back to [-1.1, 1.1]^3.
        verts *= gap
        verts -= 1.1

        # Transform the mesh into world coordinates; the face order is flipped
        # to correct the winding produced by marching cubes.
        tran_mat = tran_mats[j].numpy()
        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
        verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
        m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1])

        output_meshes[model_ids[j]] = m_inwrd
    return output_meshes
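
# Note: with the default reso=256, inference() evaluates a 257^3 query grid
# (~17M points) per object, decoded in 128^3 chunks to bound peak memory;
# pass a smaller reso if GPU memory is tight.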

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="../example_process_data")
    parser.add_argument("--scene_id", type=str, default="all")
    parser.add_argument("--save_dir", type=str, default="../example_output_data")
    args = parser.parse_args()

    config_path = "../configs/finetune_triplane_diffusion.yaml"
    config = CONFIG(config_path).config

    '''create the save folder'''
    save_folder = os.path.join(args.save_dir, args.scene_id)
    os.makedirs(save_folder, exist_ok=True)

    '''prepare the models'''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ae_config = config['model']['ae']
    dm_config = config['model']['dm']
    dm_model = get_model(dm_config).to(device)
    ae_model = get_model(ae_config).to(device)
    dm_model.eval()
    ae_model.eval()

    '''prepare the data: find which object classes are present in the scene'''
    images_folder = os.path.join(args.data_dir, args.scene_id, "6_images")
    object_id_list = os.listdir(images_folder)
    object_class_list = [item.split("_")[0] for item in object_id_list]
    all_object_class = list(set(object_class_list))

    # Map detected classes to super-categories to decide which
    # category-specific models to load.
    exist_super_categories = []
    for object_class in all_object_class:
        if object_class not in classname_remap:
            continue
        exist_super_categories.append(classname_remap[object_class])
    exist_super_categories = list(set(exist_super_categories))
    for super_category in exist_super_categories:
        print("processing %s" % (super_category))
        ae_ckpt = torch.load(ae_paths[super_category], map_location="cpu")["model"]
        dm_ckpt = torch.load(dm_paths[super_category], map_location="cpu")["model"]
        ae_model.load_state_dict(ae_ckpt)
        dm_model.load_state_dict(dm_ckpt)
        dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id,
                                    category=super_category, max_n_imgs=MAX_IMG_LENGTH)
        dataloader = DataLoader(
            dataset=dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False,
        )
        for data_batch in dataloader:
            output_meshes = inference(ae_model, dm_model, data_batch, device)
            for model_id, mesh in output_meshes.items():
                save_path = os.path.join(save_folder, model_id + ".ply")
                print("saving to %s" % (save_path))
                mesh.export(save_path)
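
# Example usage (a sketch; the script filename is assumed, paths follow the
# defaults above):
#   python inference.py --data_dir ../example_process_data --scene_id <scene_id> \
#       --save_dir ../example_output_data
# Each reconstructed object is saved as <save_dir>/<scene_id>/<model_id>.ply.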