File size: 3,210 Bytes
8c02843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import argparse
import torch
import numpy as np
import pytorch_lightning as pl
from omegaconf import OmegaConf

from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs


def main(args, cfg):
    """Run sampling with a saved checkpoint and visualize the resulting scenes.

    Args:
        args: parsed command-line namespace. Reads ``eval_random_seed``,
            ``checkpoint_id``, ``eval_mode`` and ``num_samples``.
        cfg: merged configuration. Reads ``WANDB.save_dir``, ``WANDB.project``
            and ``DATASET``; ``DATASET.ignore_rgb`` is overwritten in place.

    Only ``eval_mode == "infer"`` does any work; any other value returns
    after the checkpoint path has been resolved.
    """
    pl.seed_everything(args.eval_random_seed)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Checkpoints are looked up under <save_dir>/<project>/<checkpoint_id>/checkpoints.
    ckpt_dir = os.path.join(
        cfg.WANDB.save_dir,
        cfg.WANDB.project,
        args.checkpoint_id,
        "checkpoints",
    )
    ckpt_path = get_checkpoint_path_from_dir(ckpt_dir)

    if args.eval_mode != "infer":
        return

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    # RGB is forced on regardless of the config, because it is used for visualization.
    cfg.DATASET.ignore_rgb = False
    dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

    sampler = Sampler(ConditionalPoseDiffusionModel, ckpt_path, device)

    # Visit the test examples in a random order (reproducible via the seed above).
    for example_idx in np.random.permutation(len(dataset)):
        raw_example = dataset.get_raw_data(example_idx)
        sentence = raw_example["sentence"]
        print(tokenizer.convert_structure_params_to_natural_language(sentence))

        example = dataset.convert_to_tensors(raw_example, tokenizer)
        batch = dataset.single_datum_to_batch(
            example, args.num_samples, device, inference_mode=True
        )

        pose_count = example["goal_poses"].shape[0]
        samples = sampler.sample(batch, pose_count)

        # Only the first entry of the sampler output is converted and shown.
        struct_pose, pc_poses_in_struct = get_struct_objs_poses(samples[0])
        moved_xyzs = move_pc_and_create_scene_simple(
            batch["pcs"], struct_pose, pc_poses_in_struct
        )
        visualize_batch_pcs(moved_xyzs, args.num_samples, limit_B=10, trimesh=True)


if __name__ == "__main__":
    # Command-line options: config file locations plus evaluation settings.
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument(
        "--base_config_file",
        type=str,
        default='../configs/base.yaml',
        help='base config yaml file',
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default='../configs/conditional_pose_diffusion.yaml',
        help='config yaml file',
    )
    parser.add_argument("--checkpoint_id", type=str, default="ConditionalPoseDiffusion")
    parser.add_argument("--eval_mode", type=str, default="infer")
    parser.add_argument("--eval_random_seed", type=int, default=42)
    parser.add_argument("--num_samples", type=int, default=10)
    args = parser.parse_args()

    # The specific config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    override_cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, override_cfg)

    main(args, cfg)