File size: 1,988 Bytes
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from typing import Literal, Optional, Tuple

import attrs


@attrs.define(slots=False)
class ModelConfig:
    # input
    use_salient_point: bool = True

    # random cutoff during training
    cutoff_fps_point: Tuple[int, ...] = (256, 512, 512, 512, 1024, 1024, 2048)
    cutoff_fps_salient_point: Tuple[int, ...] = (0, 0, 256, 512, 512, 1024, 2048)
    cutoff_fps_prob: Tuple[float, ...] = (0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.2)  # sum to 1.0

    # backbone transformer
    num_enc_layers: int = 0
    hidden_dim: int = 1024
    num_heads: int = 16
    num_dec_layers: int = 24
    dec_hidden_dim: int = 1024
    dec_num_heads: int = 16
    qknorm: bool = True
    qknorm_type: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"  # type of qknorm
    salient_attn_mode: Literal["dual_shared", "single", "dual"] = "dual"

    # query decoder
    fourier_version: Literal["v1", "v2", "v3"] = "v3"
    point_fourier_dim: int = 48  # must be divisible by 6 (sin/cos, x/y/z)
    query_hidden_dim: int = 1024
    query_num_heads: int = 16
    use_flash_query: bool = False

    # latent code
    latent_size: int = 4096  # == num_fps_point + num_fps_salient_point
    latent_dim: int = 64

    # loss
    use_ae: bool = False  # if true, variance will be ignored, and kl_weight is used as a L2 norm weight
    kl_weight: float = 1e-3

    # init weights from a pretrained checkpoint
    pretrain_path: Optional[str] = None