import json
from dataclasses import dataclass
from typing import List, Optional

from genie.factorization_utils import nth_root


@dataclass
class GenieConfig:
    num_layers: int
    num_heads: int
    d_model: int
    T: int = 12  # temporal sequence length
    S: int = 256  # spatial sequence length, e.g. 256 for a 16x16 token grid
    image_vocab_size: int = 262144  # number of distinct image tokens;
    # the actual model vocab size is larger to include special (e.g. mask) tokens.
    use_mup: bool = False
    dataloader_apply_mask: bool = True  # apply mask in the dataloader
    dataloader_apply_corruption: bool = True
    dataloader_mask_ratio_min: float = 0.2
    drop_action_ratio: float = 0.0  # applied at the dataset level
    arch: str = "STTransformerDecoder"
    random_dummy_action: bool = True  # applied in the model

    # Factorization for large vocabs (e.g. Open-MAGVIT2)
    num_factored_vocabs: int = 1
    factored_vocab_size: Optional[int] = None  # computed in __post_init__
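    # Illustration: Open-MAGVIT2's vocabulary of 262144 = 512 ** 2 tokens can be
    # factored into num_factored_vocabs = 2 sub-vocabularies of 512 entries each,
    # i.e. factored_vocab_size = nth_root(262144, 2) == 512 (assuming nth_root
    # returns the integer n-th root of its first argument).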

    # MaskGIT training (all values are arbitrary choices)
    max_corrupt_rate: float = 0.2  # corrupt all tokens at a rate drawn uniformly from [0, max_corrupt_rate]
    # Case 1: standard MLM training.
    # Case 2: not standard MLM, `non_mlm`: some earlier frames are left unmasked, as in Copilot4D.
    non_mlm_ratio: float = 0.2
    num_prompt_frames: int = 4
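    # Sketch of how these fields plausibly interact (illustrative only; the
    # actual sampling logic lives in the dataloader/model code):
    #
    #   corrupt_rate = random.uniform(0, config.max_corrupt_rate)
    #   if random.random() < config.non_mlm_ratio:
    #       # non-MLM: the first num_prompt_frames frames stay unmasked
    #       maskable_frames = range(config.num_prompt_frames, config.T)
    #   else:
    #       # standard MLM: tokens in any frame may be masked
    #       maskable_frames = range(config.T)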

    # Action-related settings
    init_actions: bool = False
    d_action: int = 28  # action dimensionality
    use_actions: bool = True
    action_domains: Optional[List[str]] = None
    d_actions: Optional[List[int]] = None
    action_stats: Optional[List[List[float]]] = None  # TODO: is this actually three nested lists?
    action_network: str = "mlp"
    shared_action_mlps: bool = True
    action_contrastive_loss: bool = False
    jointly_predict_actions: bool = False
    jointly_predict_states: bool = True
    action_token_size: int = 64  # images are 16x16
    label_drop_prob: float = 0.5  # drop ratio for action tokens
    action_loss_weight: float = 0.5  # weight for the action loss

    # Attention
    qkv_bias: bool = False
    proj_bias: bool = True
    attn_drop: float = 0.0
    qk_norm: bool = True

    # MLP
    mlp_ratio: float = 4.0
    mlp_drop: float = 0.0
    mlp_bias: bool = True

    def save_pretrained(self, json_path):
        with open(json_path, "w") as f:
            json.dump(vars(self), f)

    @classmethod
    def from_pretrained(cls, json_path):
        with open(json_path, "r") as f:
            config = json.load(f)

        return cls(**config)

    def shallow_copy(self):
        return GenieConfig(**vars(self))
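
    # Note: GenieConfig(**vars(self)) re-runs __init__ and __post_init__, but
    # vars() returns the instance dict without copying its values, so
    # list-valued fields (e.g. action_domains) are shared between the original
    # and the copy rather than duplicated.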

    def __post_init__(self):
        if self.image_vocab_size is None:
            self.factored_vocab_size = 64  # dummy value
        else:
            self.factored_vocab_size = nth_root(self.image_vocab_size, self.num_factored_vocabs)


@dataclass
class DiffusionGenieConfig(GenieConfig):
    Diffusion: bool = True

    # Attention
    dim: int = 512

    # Dataloader overrides
    dataloader_apply_mask: bool = True  # apply mask inside the model
    dataloader_apply_corruption: bool = False  # random corruption is not needed here
    dataloader_mask_ratio_min: float = 0.1

    # Diffusion head / VAE settings
    vae_stride: int = 1
    patch_size: int = 1
    vae_embed_dim: int = 4
    mask_ratio_min: float = 0.7
    attn_dropout: float = 0.1
    proj_dropout: float = 0.1
    buffer_size: int = 64
    diffloss_d: int = 4
    diffloss_w: int = 1024
    num_sampling_steps: str = '100'
    diffusion_batch_mul: int = 1
    grad_checkpointing: bool = False

    # Action handling (overrides GenieConfig defaults where values differ)
    use_actions: bool = True
    jointly_predict_actions: bool = False
    jointly_predict_states: bool = True
    action_token_size: int = 64  # images are 16x16
    label_drop_prob: float = 0.5  # drop ratio for action tokens
    action_loss_weight: float = 1.0  # weight for the action loss
    predict_unmask: bool = False  # also predict tokens in unmasked regions
    maskgit_steps: int = 16  # number of mask iterations during inference
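
    # Illustrative note: MaskGIT-style inference typically runs maskgit_steps
    # refinement rounds, predicting every masked token, keeping the most
    # confident predictions, and re-masking the remainder on a decaying
    # (e.g. cosine) schedule until no masked tokens remain.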

    def shallow_copy(self):
        return DiffusionGenieConfig(**vars(self))


@dataclass
class CogVideoGenieConfig(GenieConfig):
    CogVideo: bool = True

    # Attention
    dim: int = 512
    num_attention_heads: int = 30
    attention_head_dim: int = 16
    time_embed_dim: int = 128

    # MLP
    mlp_ratio: float = 4.0
    mlp_drop: float = 0.0
    mlp_bias: bool = True
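

if __name__ == "__main__":
    # Illustrative smoke test: round-trip a GenieConfig through JSON and print
    # the derived factored vocab size. Field values are placeholders, and this
    # assumes nth_root(v, n) returns the integer n-th root of v.
    config = GenieConfig(num_layers=12, num_heads=8, d_model=512)
    config.save_pretrained("/tmp/genie_config.json")
    restored = GenieConfig.from_pretrained("/tmp/genie_config.json")
    assert vars(restored) == vars(config)
    print(f"factored_vocab_size = {config.factored_vocab_size}")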