regent-creators
commited on
Upload model
Browse files- config.json +72 -0
- configuration_jat.py +134 -0
- generation_config.json +6 -0
- modeling_regent.py +751 -0
- pytorch_model.bin +3 -0
config.json
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"ONLY_RL_TASKS": true,
|
3 |
+
"_name_or_path": "checkpoints/jat-regent-medium-10.0lamda-1.0MDM-1.0ADM-p95DN-resnet18_512ADT_embeddings/checkpoint-27726",
|
4 |
+
"action_loss_coef": 1.0,
|
5 |
+
"action_vocab_size": 18,
|
6 |
+
"activation_function": "gelu_new",
|
7 |
+
"architectures": [
|
8 |
+
"JatRegentModel"
|
9 |
+
],
|
10 |
+
"atari_dist_multiplier": 1.0,
|
11 |
+
"atari_dist_type": "resnet18_512",
|
12 |
+
"attention_dropout": 0.0,
|
13 |
+
"attention_layers": [
|
14 |
+
"global",
|
15 |
+
"local",
|
16 |
+
"global",
|
17 |
+
"local",
|
18 |
+
"global",
|
19 |
+
"local",
|
20 |
+
"global",
|
21 |
+
"local",
|
22 |
+
"global",
|
23 |
+
"local",
|
24 |
+
"global",
|
25 |
+
"local"
|
26 |
+
],
|
27 |
+
"attention_types": [
|
28 |
+
[
|
29 |
+
[
|
30 |
+
"global",
|
31 |
+
"local"
|
32 |
+
],
|
33 |
+
6
|
34 |
+
]
|
35 |
+
],
|
36 |
+
"auto_map": {
|
37 |
+
"AutoConfig": "configuration_jat.JatConfig",
|
38 |
+
"AutoModelForCausalLM": "modeling_regent.JatRegentModel"
|
39 |
+
},
|
40 |
+
"bos_token_id": 50256,
|
41 |
+
"classifier_dropout": 0.1,
|
42 |
+
"dist_normalizer": "p95",
|
43 |
+
"embed_dropout": 0.0,
|
44 |
+
"eos_token_id": 50256,
|
45 |
+
"finetune_num_demos": null,
|
46 |
+
"hidden_size": 768,
|
47 |
+
"image_size": 224,
|
48 |
+
"initializer_range": 0.02,
|
49 |
+
"intermediate_size": null,
|
50 |
+
"lamda": 10.0,
|
51 |
+
"layer_norm_epsilon": 1e-05,
|
52 |
+
"max_continuous_size": 513,
|
53 |
+
"max_discrete_value": 212,
|
54 |
+
"max_position_embeddings": 40,
|
55 |
+
"model_type": "jat",
|
56 |
+
"mujoco_dist_multiplier": 1.0,
|
57 |
+
"num_channels": 3,
|
58 |
+
"num_contexts": 20,
|
59 |
+
"num_heads": 12,
|
60 |
+
"num_layers": 12,
|
61 |
+
"observation_loss_coef": 0.0,
|
62 |
+
"patch_size": 16,
|
63 |
+
"resid_dropout": 0.0,
|
64 |
+
"tokenizer_class": "GPT2TokenizerFast",
|
65 |
+
"torch_dtype": "float32",
|
66 |
+
"transformers_version": "4.41.2",
|
67 |
+
"use_atari_embeddings": true,
|
68 |
+
"use_cache": true,
|
69 |
+
"use_global_atari_actions": true,
|
70 |
+
"vocab_size": 50257,
|
71 |
+
"window_size": 256
|
72 |
+
}
|
configuration_jat.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPTNeoConfig
|
2 |
+
|
3 |
+
|
4 |
+
class JatConfig(GPTNeoConfig):
|
5 |
+
r"""
|
6 |
+
This is the configuration class to store the configuration of a [`JatModel`]. It is used to instantiate a Jat
|
7 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with
|
8 |
+
the defaults will yield a similar configuration to that of the ... (TODO)
|
9 |
+
|
10 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
11 |
+
documentation from [`PretrainedConfig`] for more information.
|
12 |
+
|
13 |
+
|
14 |
+
Args:
|
15 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
16 |
+
Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
|
17 |
+
`inputs_ids` passed when calling [`GPTNeoModel`]. Vocabulary size of the model. Defines the different
|
18 |
+
tokens that can be represented by the *inputs_ids* passed to the forward method of [`GPTNeoModel`].
|
19 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
20 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
21 |
+
just in case (e.g., 512 or 1024 or 2048).
|
22 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
23 |
+
Dimensionality of the encoder layers and the pooler layer.
|
24 |
+
num_layers (`int`, *optional*, defaults to 24):
|
25 |
+
Number of hidden layers in the Transformer encoder.
|
26 |
+
attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
|
27 |
+
The type of attention for each layer in a `List` of the following format `[[["attention_type"],
|
28 |
+
num_layerss]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]` Choose the
|
29 |
+
value of `attention_type` from `["global", "local"]`
|
30 |
+
num_heads (`int`, *optional*, defaults to 16):
|
31 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
32 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
34 |
+
window_size (`int`, *optional*, defaults to 256):
|
35 |
+
The size of the sliding window for local attention.
|
36 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
|
37 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
38 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
39 |
+
resid_dropout (`float`, *optional*, defaults to 0.0):
|
40 |
+
Residual dropout used in the attention pattern.
|
41 |
+
embed_dropout (`float`, *optional*, defaults to 0.0):
|
42 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
43 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
44 |
+
The dropout ratio for the attention probabilities.
|
45 |
+
classifier_dropout (`float`, *optional*, defaults to 0.1):
|
46 |
+
Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
|
47 |
+
dropout ratio for the hidden layer.
|
48 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
49 |
+
The epsilon used by the layer normalization layers.
|
50 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
51 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
52 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
53 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
54 |
+
relevant if `config.is_decoder=True`.
|
55 |
+
bos_token_id (`int`, *optional*, defaults to 50256):
|
56 |
+
The id of the beginning of sentence token in the vocabulary.
|
57 |
+
eos_token_id (`int`, *optional*, defaults to 50256):
|
58 |
+
The id of the end of sentence token in the vocabulary.
|
59 |
+
max_continuous_size (`int`, *optional*, default to 376):
|
60 |
+
The maximum size of the continuous values.
|
61 |
+
max_discrete_value (`int`, *optional*, default to 18):
|
62 |
+
The maximum value of the discrete values.
|
63 |
+
image_size (`int`, *optional*, defaults to 224):
|
64 |
+
The size (resolution) of each image.
|
65 |
+
patch_size (`int`, *optional*, defaults to 16):
|
66 |
+
The size (resolution) of each patch.
|
67 |
+
observation_loss_coef (`float`, *optional*, defaults to 0.005):
|
68 |
+
The coefficient for the observation loss. When set to 0.0, the observation is not even predicted.
|
69 |
+
action_loss_coef (`float`, *optional*, defaults to 0.995):
|
70 |
+
The coefficient for the action loss.
|
71 |
+
"""
|
72 |
+
|
73 |
+
model_type = "jat"
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
vocab_size=50257,
|
78 |
+
max_position_embeddings=2048,
|
79 |
+
hidden_size=2048,
|
80 |
+
num_layers=24,
|
81 |
+
attention_types=[[["global", "local"], 12]],
|
82 |
+
num_heads=16,
|
83 |
+
intermediate_size=None,
|
84 |
+
window_size=256,
|
85 |
+
activation_function="gelu_new",
|
86 |
+
resid_dropout=0.0,
|
87 |
+
embed_dropout=0.0,
|
88 |
+
attention_dropout=0.0,
|
89 |
+
classifier_dropout=0.1,
|
90 |
+
layer_norm_epsilon=1e-5,
|
91 |
+
initializer_range=0.02,
|
92 |
+
use_cache=True,
|
93 |
+
bos_token_id=50256,
|
94 |
+
eos_token_id=50256,
|
95 |
+
max_continuous_size=377,
|
96 |
+
max_discrete_value=18,
|
97 |
+
image_size=224,
|
98 |
+
num_channels=3,
|
99 |
+
patch_size=16,
|
100 |
+
observation_loss_coef=0.005,
|
101 |
+
action_loss_coef=0.995,
|
102 |
+
**kwargs,
|
103 |
+
):
|
104 |
+
super().__init__(
|
105 |
+
vocab_size,
|
106 |
+
max_position_embeddings,
|
107 |
+
hidden_size,
|
108 |
+
num_layers,
|
109 |
+
attention_types,
|
110 |
+
num_heads,
|
111 |
+
intermediate_size,
|
112 |
+
window_size,
|
113 |
+
activation_function,
|
114 |
+
resid_dropout,
|
115 |
+
embed_dropout,
|
116 |
+
attention_dropout,
|
117 |
+
classifier_dropout,
|
118 |
+
layer_norm_epsilon,
|
119 |
+
initializer_range,
|
120 |
+
use_cache,
|
121 |
+
bos_token_id,
|
122 |
+
eos_token_id,
|
123 |
+
**kwargs,
|
124 |
+
)
|
125 |
+
self.max_continuous_size = max_continuous_size
|
126 |
+
self.max_discrete_value = max_discrete_value
|
127 |
+
self.image_size = image_size
|
128 |
+
self.num_channels = num_channels
|
129 |
+
self.patch_size = patch_size
|
130 |
+
self.observation_loss_coef = observation_loss_coef
|
131 |
+
self.action_loss_coef = action_loss_coef
|
132 |
+
|
133 |
+
|
134 |
+
JatConfig.register_for_auto_class()
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 50256,
|
4 |
+
"eos_token_id": 50256,
|
5 |
+
"transformers_version": "4.41.2"
|
6 |
+
}
|
modeling_regent.py
ADDED
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from gymnasium import spaces
|
9 |
+
from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
|
10 |
+
from transformers import GPTNeoModel, GPTNeoPreTrainedModel
|
11 |
+
from transformers.modeling_outputs import ModelOutput
|
12 |
+
from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
|
16 |
+
from jat.configuration_jat import JatConfig
|
17 |
+
from jat.processing_jat import JatProcessor
|
18 |
+
from jat.modeling_jat import JatModel, compute_mse_loss, cyclic_expand_dim, JatOutput
|
19 |
+
from regent.utils import build_index_vector, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_vector, myprint, L2dist, get_dist_stats, get_images_of_retrieved_obs, get_emb_transform_model_dim, get_optional_suffix
|
20 |
+
from regent.atari_utils import convert_local_to_global_action, convert_global_to_local_action
|
21 |
+
from regent.eval.rl import SEEN_TASK_NAME_TO_ENV_ID, UNSEEN_TASK_NAME_TO_ENV_ID
|
22 |
+
from PIL import Image
|
23 |
+
import os
|
24 |
+
from copy import deepcopy
|
25 |
+
from pytorch_msssim import ssim
|
26 |
+
import json
|
27 |
+
|
28 |
+
|
29 |
+
def cross_entropy_from_softmax(softmax_probs, targets, reduction="mean", epsilon=1e-9):
|
30 |
+
"""
|
31 |
+
Calculate the cross entropy loss given softmax_probs and targets.
|
32 |
+
|
33 |
+
:param softmax_probs: tensor containing softmax probabilities
|
34 |
+
:param targets: tensor containing the target classes (not one-hot encoded)
|
35 |
+
:return: cross entropy loss
|
36 |
+
"""
|
37 |
+
assert len(softmax_probs.shape) == 2, "softmax_probs should be of shape (batch_size, num_classes)"
|
38 |
+
assert len(targets.shape) == 1, "targets should be of shape (batch_size,)"
|
39 |
+
|
40 |
+
# Convert targets to one-hot encoding
|
41 |
+
targets_one_hot = F.one_hot(targets, num_classes=softmax_probs.shape[1]).float() # shape: (batch_size, num_classes)
|
42 |
+
|
43 |
+
# Calculate the cross entropy loss
|
44 |
+
softmax_probs = softmax_probs.clamp(min=epsilon, max=1-epsilon) # to avoid NaNs from log(0) and instabilities from log(1)
|
45 |
+
log_softmax_probs = softmax_probs.log() # safe to take log as softmax_probs are non-zero
|
46 |
+
loss = -torch.sum(targets_one_hot * log_softmax_probs, dim=1)
|
47 |
+
|
48 |
+
if reduction == "mean":
|
49 |
+
return loss.mean()
|
50 |
+
elif reduction == "sum":
|
51 |
+
return loss.sum()
|
52 |
+
elif reduction == "none":
|
53 |
+
return loss
|
54 |
+
else:
|
55 |
+
raise ValueError("reduction should be one of 'mean', 'sum', or 'none'")
|
56 |
+
|
57 |
+
|
58 |
+
def compute_ce_loss_from_softmax(
|
59 |
+
logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
|
60 |
+
) -> FloatTensor:
|
61 |
+
"""
|
62 |
+
Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
|
66 |
+
Predicted logits at the output of the model.
|
67 |
+
labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
|
68 |
+
Ground truth class labels.
|
69 |
+
mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
|
70 |
+
Boolean mask indicating valid timesteps.
|
71 |
+
weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
|
72 |
+
Weights to be applied to the loss.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
loss (`FloatTensor` of shape `(,)`):
|
76 |
+
CE loss between predicted logits and true class labels.
|
77 |
+
"""
|
78 |
+
if mask is not None:
|
79 |
+
logits = logits[mask.bool()] # (Y, X, C)
|
80 |
+
labels = labels[mask.bool()] # (Y, X)
|
81 |
+
if weights is not None:
|
82 |
+
weights = weights[mask.bool()] # (Y,)
|
83 |
+
else:
|
84 |
+
logits = logits.flatten(end_dim=2) # (B, L, X, C) -> (B*L, X, C)
|
85 |
+
labels = labels.flatten(end_dim=1) # (B, L, X) -> (B*L, X)
|
86 |
+
if weights is not None:
|
87 |
+
weights = weights.flatten(end_dim=1) # (B, L) -> (B*L,)
|
88 |
+
|
89 |
+
loss = cross_entropy_from_softmax(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none") # (Y*X,) # we don't use F.cross_entropy here to avoid double softmax
|
90 |
+
loss = loss.view(labels.size()) # (Y, X)
|
91 |
+
loss = loss.mean(-1) # (Y,)
|
92 |
+
|
93 |
+
# Multiply the loss by the weights
|
94 |
+
if weights is not None:
|
95 |
+
loss = loss * weights # (Y,)
|
96 |
+
|
97 |
+
# Average the loss
|
98 |
+
loss = loss.mean()
|
99 |
+
|
100 |
+
return loss
|
101 |
+
|
102 |
+
|
103 |
+
def crazy_relu(x, beta):
|
104 |
+
return nn.LeakyReLU(beta)(x) - (1-beta) * nn.ReLU()(x-1)
|
105 |
+
|
106 |
+
|
107 |
+
class JatRegentModel(JatModel):
|
108 |
+
"""
|
109 |
+
Jat Regent model.
|
110 |
+
"""
|
111 |
+
def __init__(self, config: JatConfig) -> None:
|
112 |
+
super().__init__(config)
|
113 |
+
hidden_size = config.hidden_size
|
114 |
+
action_vocab_size = config.action_vocab_size
|
115 |
+
|
116 |
+
if config.ONLY_RL_TASKS:
|
117 |
+
self.single_discrete_decoder = nn.Linear(hidden_size, action_vocab_size, bias=False)
|
118 |
+
self.N = config.action_vocab_size
|
119 |
+
else:
|
120 |
+
self.N = config.vocab_size
|
121 |
+
self.multi_discrete_decoder = None # not needed
|
122 |
+
self.image_decoder = None # not needed
|
123 |
+
self.num_contexts = config.num_contexts # used in get_next_action() at evaluation in an env only
|
124 |
+
self.lamda = config.lamda # used in get_next_action() at evaluation in an env only
|
125 |
+
self.use_global_atari_actions = config.use_global_atari_actions
|
126 |
+
self.dist_multipliers = {'mujoco': config.mujoco_dist_multiplier, 'atari': config.atari_dist_multiplier}
|
127 |
+
self.dist_normalizer = config.dist_normalizer
|
128 |
+
self.atari_dist_type = config.atari_dist_type
|
129 |
+
self.use_atari_embeddings = config.use_atari_embeddings
|
130 |
+
self.finetune_num_demos = config.finetune_num_demos if hasattr(config, 'finetune_num_demos') else None
|
131 |
+
if self.use_atari_embeddings:
|
132 |
+
self.image_encoder = None
|
133 |
+
self.emb_dim_full = (512,)
|
134 |
+
|
135 |
+
# print number of parameters
|
136 |
+
num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
137 |
+
myprint(f"number of parameters: {num_params / 1e6:.4f}M")
|
138 |
+
|
139 |
+
def retrieval_setup(self,
|
140 |
+
task,
|
141 |
+
dataset,
|
142 |
+
num_demos, # to retrieve from
|
143 |
+
device,
|
144 |
+
batch_size_retrieval=16, # for atari envs on gpu
|
145 |
+
nb_cores_autofaiss=8, # for vector obs envs on cpu cores
|
146 |
+
):
|
147 |
+
# setup
|
148 |
+
rew_key, attn_key, obs_key, act_key, B, obs_dim, act_dim = get_task_info(task)
|
149 |
+
extra_key = 'discrete_RandP_action_logits' if task.startswith("atari") or task.startswith("babyai") else 'continuous_RandP_actions'
|
150 |
+
optional_suffix = get_optional_suffix(task, self.atari_dist_type, self.finetune_num_demos)
|
151 |
+
mean_dist, std_dist, max_dist, p80, p85, p90, p95, p99 = get_dist_stats(task=task, optional_suffix=optional_suffix)
|
152 |
+
|
153 |
+
# get embedding model
|
154 |
+
if task.startswith("atari"):
|
155 |
+
self.emb_transform, self.emb_model, emb_dim, self.emb_model_full = get_emb_transform_model_dim(self.atari_dist_type, self.device, return_emb_weights=True)
|
156 |
+
obs_dim = emb_dim # overwrite for atari_dist_type
|
157 |
+
|
158 |
+
kwargs = {'B': B,
|
159 |
+
'obs_dim': obs_dim,
|
160 |
+
'attn_key': attn_key,
|
161 |
+
'obs_key': obs_key,
|
162 |
+
'device': device,
|
163 |
+
'task': task,
|
164 |
+
'batch_size_retrieval': batch_size_retrieval,
|
165 |
+
'nb_cores_autofaiss': nb_cores_autofaiss,
|
166 |
+
'verbose': False,
|
167 |
+
'atari_dist_type': self.atari_dist_type,
|
168 |
+
}
|
169 |
+
raw_obs_dim = obs_dim
|
170 |
+
if task.startswith("atari"): # overwrite raw_obs_dim because raw obs in atari are (4, 84, 84) and raw obs in babyai have 64 extra dim
|
171 |
+
raw_obs_dim = (4, 84, 84)
|
172 |
+
elif task.startswith("babyai"):
|
173 |
+
raw_obs_dim = (obs_dim[0]+64,)
|
174 |
+
|
175 |
+
# save
|
176 |
+
self.task = task
|
177 |
+
self.dataset = dataset
|
178 |
+
self.obs_key = obs_key
|
179 |
+
self.act_key = act_key
|
180 |
+
self.rew_key = rew_key
|
181 |
+
self.attn_key = attn_key
|
182 |
+
self.obs_dim = obs_dim
|
183 |
+
self.act_dim = act_dim
|
184 |
+
self.extra_key = extra_key
|
185 |
+
self.kwargs = kwargs
|
186 |
+
self.raw_obs_dim = raw_obs_dim
|
187 |
+
self.max_dist = max_dist
|
188 |
+
self.mean_dist = mean_dist
|
189 |
+
self.std_dist = std_dist
|
190 |
+
self.p80, self.p85, self.p90, self.p95, self.p99 = p80, p85, p90, p95, p99
|
191 |
+
self.dist_normalizer_value = {'std': std_dist, 'max': max_dist, 'p80': p80, 'p85': p85, 'p90': p90, 'p95': p95, 'p99': p99}[self.dist_normalizer]
|
192 |
+
if self.dist_normalizer_value == 0.0: self.dist_normalizer_value = 1.0
|
193 |
+
|
194 |
+
# for retrieval,
|
195 |
+
all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs, all_datarows_dict = collect_all_data(dataset, task, obs_key, num_demos, return_datarows_dict=True, atari_dist_type=self.atari_dist_type)
|
196 |
+
if task.startswith("babyai"):
|
197 |
+
# for each mission in task,
|
198 |
+
self.all_indices = {}
|
199 |
+
self.knn_index = {}
|
200 |
+
for mission_idx, mission in enumerate(all_row_idxs.keys()):
|
201 |
+
# create index, collect subset of data that we can retrieve from
|
202 |
+
myprint(('*'*50) + f'{mission=} - {mission_idx+1}/{len(all_row_idxs.keys())}')
|
203 |
+
self.all_indices[mission], self.knn_index[mission] = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG[mission],
|
204 |
+
all_attn_masks_OG=all_attn_masks_OG[mission],
|
205 |
+
all_row_idxs=all_row_idxs[mission],
|
206 |
+
kwargs=kwargs)
|
207 |
+
else:
|
208 |
+
# create index, collect subset of data that we can retrieve from
|
209 |
+
self.all_indices, self.knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
|
210 |
+
all_attn_masks_OG=all_attn_masks_OG,
|
211 |
+
all_row_idxs=all_row_idxs,
|
212 |
+
kwargs=kwargs)
|
213 |
+
|
214 |
+
# for retrieval inside retrieve()
|
215 |
+
self.datarows = all_datarows_dict
|
216 |
+
|
217 |
+
|
218 |
+
# # for checking if first env state is similar to retrieval episode's first states
|
219 |
+
# if task.startswith("mujoco"):
|
220 |
+
# local_path = f"dataset_jat_regent/{task}"
|
221 |
+
# with open(f"{local_path}/eps_2_rows_tokenized.json", 'r') as f:
|
222 |
+
# eps_2_rows_tokenized = json.load(f)
|
223 |
+
# eps_2_rows_tokenized = {int(k): v for k, v in eps_2_rows_tokenized.items()}
|
224 |
+
# row_idxs_of_first_state_of_demos = [eps_2_rows_tokenized[eps][0] for eps in range(num_demos)]
|
225 |
+
# self.first_states_of_demos = [np.array(dataset['train'][row_idx][obs_key][0]) for row_idx in row_idxs_of_first_state_of_demos]
|
226 |
+
# else:
|
227 |
+
# self.first_states_of_demos = None
|
228 |
+
|
229 |
+
def output_rl(
|
230 |
+
self,
|
231 |
+
transformer_outputs,
|
232 |
+
continuous_observations: Optional[FloatTensor] = None,
|
233 |
+
discrete_observations: Optional[LongTensor] = None,
|
234 |
+
image_observations: Optional[FloatTensor] = None,
|
235 |
+
continuous_actions: Optional[FloatTensor] = None,
|
236 |
+
discrete_actions: Optional[LongTensor] = None,
|
237 |
+
rewards: Optional[FloatTensor] = None,
|
238 |
+
attention_mask: Optional[BoolTensor] = None,
|
239 |
+
return_loss: bool = True,
|
240 |
+
return_dict: Optional[bool] = None,
|
241 |
+
loss_weight: Optional[FloatTensor] = None,
|
242 |
+
exp_lamda_distances: Optional[FloatTensor] = None,
|
243 |
+
continuous_RandP_actions: Optional[FloatTensor] = None,
|
244 |
+
discrete_RandP_action_logits: Optional[FloatTensor] = None,
|
245 |
+
):
|
246 |
+
hidden_states = transformer_outputs.last_hidden_state
|
247 |
+
loss, observation_loss, action_loss = None, None, None
|
248 |
+
|
249 |
+
# Observations
|
250 |
+
assert rewards is not None
|
251 |
+
observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
|
252 |
+
assert self.observation_loss_coef == 0.0, f'{self.observation_loss_coef=} should be 0.0 as we are not predicting observations!'
|
253 |
+
# warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
|
254 |
+
pred_observations = None
|
255 |
+
observation_loss = 0.0
|
256 |
+
|
257 |
+
# Actions
|
258 |
+
actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
|
259 |
+
if continuous_actions is not None:
|
260 |
+
act_size = continuous_actions.shape[-1]
|
261 |
+
continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
|
262 |
+
continuous_RandP_actions = cyclic_expand_dim(continuous_RandP_actions, self.config.max_continuous_size)
|
263 |
+
init_pred_actions = self.continuous_decoder(hidden_states[:, ::2])
|
264 |
+
pred_actions = self.continuous_action_interpolation(init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0)
|
265 |
+
if return_loss:
|
266 |
+
action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight) # loss_weight is usually 50 for metaworld, 10 for mujoco (except two tasks where it is 20, 50), 1 for the rest!
|
267 |
+
pred_actions = pred_actions[..., :act_size]
|
268 |
+
elif discrete_actions is not None:
|
269 |
+
init_pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
|
270 |
+
pred_actions = self.discrete_action_interpolation(init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0)
|
271 |
+
if return_loss:
|
272 |
+
action_loss = compute_ce_loss_from_softmax(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
|
273 |
+
|
274 |
+
# Return output
|
275 |
+
if return_loss:
|
276 |
+
loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss
|
277 |
+
|
278 |
+
if not return_dict:
|
279 |
+
output = (pred_observations, pred_actions) + transformer_outputs[1:]
|
280 |
+
return ((loss, observation_loss, action_loss) + output) if loss is not None else output
|
281 |
+
|
282 |
+
return JatOutput(
|
283 |
+
loss=loss,
|
284 |
+
observation_loss=observation_loss,
|
285 |
+
action_loss=action_loss,
|
286 |
+
pred_observations=pred_observations,
|
287 |
+
pred_actions=pred_actions,
|
288 |
+
past_key_values=transformer_outputs.past_key_values,
|
289 |
+
hidden_states=transformer_outputs.hidden_states,
|
290 |
+
attentions=transformer_outputs.attentions,
|
291 |
+
)
|
292 |
+
|
293 |
+
def shifted_crazy_relu(self, x, beta):
|
294 |
+
return 2 * crazy_relu(0.5*(x+1), beta) - 1
|
295 |
+
|
296 |
+
def continuous_action_interpolation(self, init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0):
|
297 |
+
batch_size, max_seq_len, act_size = init_pred_actions.shape
|
298 |
+
assert (init_pred_actions.shape == (batch_size, max_seq_len, act_size) and
|
299 |
+
exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
|
300 |
+
continuous_RandP_actions.shape == (batch_size, max_seq_len, act_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {continuous_RandP_actions.shape=}, {(batch_size, max_seq_len, act_size)=}'
|
301 |
+
|
302 |
+
""" MCNN interpolation (https://arxiv.org/abs/2310.06171) """
|
303 |
+
act_fn = self.shifted_crazy_relu
|
304 |
+
final_actions = exp_lamda_distances * continuous_RandP_actions + 10.0 * (1 - exp_lamda_distances) * act_fn(init_pred_actions, beta=beta)
|
305 |
+
return final_actions
|
306 |
+
|
307 |
+
def discrete_action_interpolation(self, init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0):
|
308 |
+
batch_size, max_seq_len, action_vocab_size = init_pred_actions.shape
|
309 |
+
assert (init_pred_actions.shape == (batch_size, max_seq_len, action_vocab_size) and
|
310 |
+
exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
|
311 |
+
discrete_RandP_action_logits.shape == (batch_size, max_seq_len, action_vocab_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {discrete_RandP_action_logits.shape=}, {(batch_size, max_seq_len, action_vocab_size)=}'
|
312 |
+
|
313 |
+
""" MCNN-like interpolation """
|
314 |
+
# print(f'{torch.round(discrete_RandP_action_logits[:, -1],decimals=2)=}')
|
315 |
+
# print(f'{torch.round(F.softmax(init_pred_actions, dim=-1)[:, -1],decimals=2)=}')
|
316 |
+
# print(f'{torch.round(exp_lamda_distances[:, -1],decimals=2)=}')
|
317 |
+
# print(f'first term: {torch.round((exp_lamda_distances * discrete_RandP_action_logits)[:, -1],decimals=2)}')
|
318 |
+
# print(f'second term: {torch.round(((1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1))[:, -1],decimals=2)}')
|
319 |
+
final_actions = exp_lamda_distances * discrete_RandP_action_logits + (1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1)
|
320 |
+
return final_actions
|
321 |
+
|
322 |
+
# Copied the forward function from the Parent class with the addition of the last 3 args in the input args and in output_rl args
|
323 |
+
def forward(
|
324 |
+
self,
|
325 |
+
input_ids: Optional[LongTensor] = None,
|
326 |
+
pixel_values: Optional[FloatTensor] = None,
|
327 |
+
continuous_observations: Optional[FloatTensor] = None,
|
328 |
+
discrete_observations: Optional[LongTensor] = None,
|
329 |
+
image_observations: Optional[FloatTensor] = None,
|
330 |
+
continuous_actions: Optional[FloatTensor] = None,
|
331 |
+
discrete_actions: Optional[LongTensor] = None,
|
332 |
+
rewards: Optional[FloatTensor] = None,
|
333 |
+
past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
|
334 |
+
attention_mask: Optional[BoolTensor] = None,
|
335 |
+
token_type_ids: Optional[LongTensor] = None,
|
336 |
+
position_ids: Optional[LongTensor] = None,
|
337 |
+
return_loss: bool = True,
|
338 |
+
use_cache: Optional[bool] = None,
|
339 |
+
output_attentions: Optional[bool] = None,
|
340 |
+
output_hidden_states: Optional[bool] = None,
|
341 |
+
return_dict: Optional[bool] = None,
|
342 |
+
loss_weight: Optional[FloatTensor] = None,
|
343 |
+
exp_lamda_distances: Optional[FloatTensor] = None,
|
344 |
+
continuous_RandP_actions: Optional[FloatTensor] = None,
|
345 |
+
discrete_RandP_action_logits: Optional[FloatTensor] = None,
|
346 |
+
) -> JatOutput:
|
347 |
+
|
348 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
349 |
+
|
350 |
+
# Textual tasks
|
351 |
+
if input_ids is not None or pixel_values is not None:
|
352 |
+
inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
|
353 |
+
# RL tasks
|
354 |
+
elif (
|
355 |
+
continuous_observations is not None or discrete_observations is not None or image_observations is not None
|
356 |
+
):
|
357 |
+
inputs_embeds, attention_mask = self.embed_rl(
|
358 |
+
continuous_observations,
|
359 |
+
discrete_observations,
|
360 |
+
image_observations,
|
361 |
+
continuous_actions,
|
362 |
+
discrete_actions,
|
363 |
+
rewards,
|
364 |
+
attention_mask,
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
raise ValueError("Input not provided.")
|
368 |
+
|
369 |
+
# Pass through transformer
|
370 |
+
transformer_outputs = self.transformer(
|
371 |
+
past_key_values=past_key_values,
|
372 |
+
attention_mask=attention_mask,
|
373 |
+
token_type_ids=token_type_ids,
|
374 |
+
position_ids=position_ids,
|
375 |
+
inputs_embeds=inputs_embeds,
|
376 |
+
use_cache=use_cache,
|
377 |
+
output_attentions=output_attentions,
|
378 |
+
output_hidden_states=output_hidden_states,
|
379 |
+
return_dict=return_dict,
|
380 |
+
)
|
381 |
+
|
382 |
+
if input_ids is not None or pixel_values is not None:
|
383 |
+
return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
|
384 |
+
else:
|
385 |
+
return self.output_rl(
|
386 |
+
transformer_outputs,
|
387 |
+
continuous_observations,
|
388 |
+
discrete_observations,
|
389 |
+
image_observations,
|
390 |
+
continuous_actions,
|
391 |
+
discrete_actions,
|
392 |
+
rewards,
|
393 |
+
attention_mask,
|
394 |
+
return_loss,
|
395 |
+
return_dict,
|
396 |
+
loss_weight,
|
397 |
+
exp_lamda_distances,
|
398 |
+
continuous_RandP_actions,
|
399 |
+
discrete_RandP_action_logits,
|
400 |
+
)
|
401 |
+
|
402 |
+
|
403 |
+
def reset_rl(self):
|
404 |
+
self.steps = 0
|
405 |
+
|
406 |
+
def process(
|
407 |
+
self,
|
408 |
+
processor: JatProcessor,
|
409 |
+
continuous_observation: Optional[List[float]] = None,
|
410 |
+
discrete_observation: Optional[List[int]] = None,
|
411 |
+
text_observation: Optional[str] = None,
|
412 |
+
image_observation: Optional[np.ndarray] = None,
|
413 |
+
action_space: Union[spaces.Box, spaces.Discrete] = None,
|
414 |
+
reward: Optional[float] = None,
|
415 |
+
deterministic: bool = True,
|
416 |
+
context_window: Optional[int] = None,
|
417 |
+
):
|
418 |
+
# Get the maximum sequence length
|
419 |
+
max_length = self.config.max_position_embeddings // 2
|
420 |
+
|
421 |
+
# Get the maximum sequence length
|
422 |
+
### see script/train_jat.py > L161.
|
423 |
+
### None ==> value set to 512 in jat/processing_jat.py > L354 and then // 2 in L355.
|
424 |
+
### weirdly, the value in script/eval_jat.py is set as 256 so it will be // 2 again in L355.
|
425 |
+
# max_length = 64 if self.task.startswith("atari") else None
|
426 |
+
|
427 |
+
# Convert everything to lists
|
428 |
+
def to_list(x):
|
429 |
+
return x.tolist() if isinstance(x, np.ndarray) else x
|
430 |
+
|
431 |
+
continuous_observation = to_list(continuous_observation)
|
432 |
+
discrete_observation = to_list(discrete_observation)
|
433 |
+
|
434 |
+
# get babyai mission within task
|
435 |
+
if self.task.startswith("babyai"):
|
436 |
+
mission = deepcopy(text_observation)
|
437 |
+
assert mission in self.knn_index.keys(), f'{mission=} should be in {self.knn_index.keys()=}'
|
438 |
+
|
439 |
+
# Add a fake action to the end of the sequence
|
440 |
+
if isinstance(action_space, spaces.Box):
|
441 |
+
fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
|
442 |
+
fake_discrete_action = None
|
443 |
+
elif isinstance(action_space, spaces.Discrete):
|
444 |
+
fake_continuous_action = None
|
445 |
+
fake_discrete_action = 0
|
446 |
+
|
447 |
+
continuous_observations = [continuous_observation] if continuous_observation is not None else None
|
448 |
+
discrete_observations = [discrete_observation] if discrete_observation is not None else None
|
449 |
+
text_observations = [text_observation] if text_observation is not None else None
|
450 |
+
image_observations = [image_observation] if image_observation is not None else None
|
451 |
+
continuous_actions = [fake_continuous_action] if fake_continuous_action is not None else None
|
452 |
+
discrete_actions = [fake_discrete_action] if fake_discrete_action is not None else None
|
453 |
+
rewards = [reward] if reward is not None else [0.0]
|
454 |
+
|
455 |
+
# Add the batch dimension
|
456 |
+
continuous_observations = [continuous_observations] if continuous_observations is not None else None
|
457 |
+
discrete_observations = [discrete_observations] if discrete_observations is not None else None
|
458 |
+
text_observations = [text_observations] if text_observations is not None else None
|
459 |
+
image_observations = [image_observations] if image_observations is not None else None
|
460 |
+
continuous_actions = [continuous_actions] if continuous_actions is not None else None
|
461 |
+
discrete_actions = [discrete_actions] if discrete_actions is not None else None
|
462 |
+
rewards = [rewards]
|
463 |
+
|
464 |
+
# Process the inputs
|
465 |
+
processed = processor(
|
466 |
+
continuous_observations=continuous_observations,
|
467 |
+
discrete_observations=discrete_observations,
|
468 |
+
text_observations=text_observations,
|
469 |
+
image_observations=image_observations,
|
470 |
+
continuous_actions=continuous_actions,
|
471 |
+
discrete_actions=discrete_actions,
|
472 |
+
rewards=rewards,
|
473 |
+
truncation=True,
|
474 |
+
truncation_side="left",
|
475 |
+
max_length=max_length,
|
476 |
+
return_tensors="pt",
|
477 |
+
)
|
478 |
+
|
479 |
+
assert (((self.act_key == 'continuous_actions' and processed[self.act_key].shape == (1, 1, self.act_dim)) or # zeros
|
480 |
+
(self.act_key == 'discrete_actions' and processed[self.act_key].shape == (1, 1))) and
|
481 |
+
processed[self.obs_key].shape == (1, 1, *self.raw_obs_dim) and
|
482 |
+
processed[self.rew_key].shape == (1, 1)), f'{processed[self.act_key].shape=}, {processed[self.obs_key].shape=}, {processed[self.rew_key].shape=}, {self.act_dim=}, {self.raw_obs_dim=}'
|
483 |
+
|
484 |
+
# save babyai mission
|
485 |
+
if self.task.startswith("babyai"):
|
486 |
+
processed['mission'] = mission
|
487 |
+
|
488 |
+
# save action_space and deterministic
|
489 |
+
processed['action_space'] = action_space
|
490 |
+
processed['deterministic'] = deterministic
|
491 |
+
|
492 |
+
return processed
|
493 |
+
|
494 |
+
def retrieve(
|
495 |
+
self,
|
496 |
+
all_processed: List[dict],
|
497 |
+
num_to_retrieve: int,
|
498 |
+
):
|
499 |
+
self.steps += 1
|
500 |
+
# Set num envs
|
501 |
+
num_envs = len(all_processed)
|
502 |
+
|
503 |
+
# Get obs from processed and make batch
|
504 |
+
row_of_obs = [all_processed[idx][self.obs_key][0].numpy() for idx in range(num_envs)]
|
505 |
+
row_of_obs = np.concatenate(row_of_obs)
|
506 |
+
assert row_of_obs.shape == (num_envs, *self.raw_obs_dim) and isinstance(row_of_obs, np.ndarray)
|
507 |
+
if self.task.startswith("atari"):
|
508 |
+
row_of_obs = process_row_of_obs_atari_full_without_mask(row_of_obs)
|
509 |
+
row_of_obs = torch.from_numpy(row_of_obs).to(self.device)
|
510 |
+
with torch.no_grad():
|
511 |
+
row_of_obs = self.emb_model(self.emb_transform(row_of_obs)).cpu().numpy()
|
512 |
+
elif self.task.startswith("babyai"):
|
513 |
+
row_of_obs = row_of_obs[:, :148] # removing last 64 text tokens
|
514 |
+
assert row_of_obs.shape == (num_envs, *self.obs_dim) and isinstance(row_of_obs, np.ndarray)
|
515 |
+
|
516 |
+
# Retrieve indices
|
517 |
+
if self.task.startswith("babyai"):
|
518 |
+
retrieved_indices = []
|
519 |
+
for idx in range(num_envs):
|
520 |
+
mission = all_processed[idx]['mission']
|
521 |
+
retrieved_indices_mission = retrieve_vector(row_of_obs=row_of_obs[idx:idx+1],
|
522 |
+
knn_index=self.knn_index[mission],
|
523 |
+
all_indices=self.all_indices[mission],
|
524 |
+
num_to_retrieve=num_to_retrieve,
|
525 |
+
kwargs=self.kwargs)
|
526 |
+
retrieved_indices.append(retrieved_indices_mission) # appending (1, 1, 2)
|
527 |
+
retrieved_indices = np.concatenate(retrieved_indices, axis=0)
|
528 |
+
assert retrieved_indices.shape == (num_envs, num_to_retrieve, 2)
|
529 |
+
else:
|
530 |
+
retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
|
531 |
+
knn_index=self.knn_index,
|
532 |
+
all_indices=self.all_indices,
|
533 |
+
num_to_retrieve=num_to_retrieve,
|
534 |
+
kwargs=self.kwargs)
|
535 |
+
|
536 |
+
# Return action
|
537 |
+
all_retrieved_act = []
|
538 |
+
all_retrieved_obs = []
|
539 |
+
all_retrieved_rew = []
|
540 |
+
env_idx = 0
|
541 |
+
for all_row_idx_and_i in retrieved_indices:
|
542 |
+
all_retrieved_act.append([])
|
543 |
+
all_retrieved_obs.append([])
|
544 |
+
all_retrieved_rew.append([])
|
545 |
+
for row_idx, i in all_row_idx_and_i:
|
546 |
+
if self.task.startswith("babyai"):
|
547 |
+
mission = all_processed[env_idx]['mission']
|
548 |
+
datarow = self.datarows[mission][int(row_idx)]
|
549 |
+
else:
|
550 |
+
datarow = self.datarows[int(row_idx)]
|
551 |
+
temp_a = datarow[self.act_key][int(i)]
|
552 |
+
if self.task.startswith("atari") and self.use_global_atari_actions:
|
553 |
+
temp_a = convert_local_to_global_action( temp_a, self.task )
|
554 |
+
all_retrieved_act[-1].append(temp_a)
|
555 |
+
all_retrieved_obs[-1].append(datarow[self.obs_key][int(i)])
|
556 |
+
all_retrieved_rew[-1].append(datarow[self.rew_key][int(i)])
|
557 |
+
env_idx += 1
|
558 |
+
|
559 |
+
return all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs
|
560 |
+
|
561 |
+
def get_distances(
|
562 |
+
self,
|
563 |
+
all_retrieved_obs: np.ndarray,
|
564 |
+
all_processed: List[dict],
|
565 |
+
query_obs: np.ndarray,
|
566 |
+
):
|
567 |
+
num_envs = len(all_processed)
|
568 |
+
|
569 |
+
# Process retrieved obs like in retrieve
|
570 |
+
num_contexts = all_retrieved_obs.shape[1] + 1
|
571 |
+
assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim) and isinstance(all_retrieved_obs, np.ndarray)
|
572 |
+
if self.task.startswith("atari"):
|
573 |
+
all_retrieved_obs = all_retrieved_obs.reshape(num_envs * (num_contexts - 1), *self.raw_obs_dim)
|
574 |
+
all_retrieved_obs = process_row_of_obs_atari_full_without_mask(all_retrieved_obs)
|
575 |
+
all_retrieved_obs = torch.from_numpy(all_retrieved_obs).to(self.device)
|
576 |
+
with torch.no_grad():
|
577 |
+
all_retrieved_obs = self.emb_model(self.emb_transform(all_retrieved_obs)).cpu().numpy()
|
578 |
+
all_retrieved_obs = all_retrieved_obs.reshape(num_envs, num_contexts - 1, *self.obs_dim)
|
579 |
+
elif self.task.startswith("babyai"):
|
580 |
+
all_retrieved_obs = all_retrieved_obs[:, :, :148]
|
581 |
+
assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.obs_dim) and isinstance(all_retrieved_obs, np.ndarray)
|
582 |
+
|
583 |
+
# Compute distances
|
584 |
+
all_distances = []
|
585 |
+
for idx in range(num_envs):
|
586 |
+
first_state = all_retrieved_obs[idx, 0:1]
|
587 |
+
distances = [0.0]
|
588 |
+
for i in range(1, num_contexts - 1):
|
589 |
+
curr_state = all_retrieved_obs[idx, i:i+1]
|
590 |
+
dist = L2dist(first_state, curr_state)
|
591 |
+
distances.append(dist)
|
592 |
+
curr_state = query_obs[idx:idx+1]
|
593 |
+
dist = L2dist(first_state, curr_state)
|
594 |
+
distances.append(dist)
|
595 |
+
all_distances.append(distances)
|
596 |
+
all_distances = np.array(all_distances)
|
597 |
+
assert all_distances.shape == (num_envs, num_contexts), f'{all_distances.shape=}, {num_envs=}, {num_contexts=}'
|
598 |
+
|
599 |
+
# distances: divide by std
|
600 |
+
all_distances = all_distances / self.dist_normalizer_value
|
601 |
+
if self.task.startswith("mujoco"):
|
602 |
+
all_distances = all_distances * self.dist_multipliers['mujoco']
|
603 |
+
elif self.task.startswith("atari"):
|
604 |
+
all_distances = all_distances * self.dist_multipliers['atari']
|
605 |
+
print(f'{self.dist_normalizer_value=}')
|
606 |
+
print(f'{all_distances=}')
|
607 |
+
|
608 |
+
return all_distances
|
609 |
+
|
610 |
+
@torch.no_grad()
|
611 |
+
def get_next_action(
|
612 |
+
self,
|
613 |
+
all_processed: List[dict],
|
614 |
+
return_retrieved_obs: bool = False,
|
615 |
+
):
|
616 |
+
num_envs = len(all_processed)
|
617 |
+
num_contexts = self.num_contexts
|
618 |
+
|
619 |
+
# Get the retrieved data
|
620 |
+
all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs = self.retrieve(all_processed, num_to_retrieve=num_contexts - 1)
|
621 |
+
if return_retrieved_obs:
|
622 |
+
all_retrieved_images = get_images_of_retrieved_obs(deepcopy(all_retrieved_obs), self.task)
|
623 |
+
|
624 |
+
# Get the distances
|
625 |
+
all_retrieved_obs = np.stack(all_retrieved_obs).astype(np.int32 if self.obs_key == 'discrete_observations' else np.float32)
|
626 |
+
assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim), f'{all_retrieved_obs.shape=}, {num_envs=}, {self.raw_obs_dim=}, {num_contexts-1=}'
|
627 |
+
all_distances = self.get_distances(all_retrieved_obs=all_retrieved_obs, all_processed=all_processed, query_obs=row_of_obs)
|
628 |
+
|
629 |
+
# Batch retrieved data
|
630 |
+
all_retrieved_act = np.stack(all_retrieved_act).astype(np.int32 if self.act_key == 'discrete_actions' else np.float32)
|
631 |
+
all_retrieved_rew = np.stack(all_retrieved_rew).astype(np.float32)
|
632 |
+
assert (((self.act_key == 'continuous_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1, self.act_dim)) or
|
633 |
+
(self.act_key == 'discrete_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1))) and
|
634 |
+
all_retrieved_rew.shape == (num_envs, num_contexts - 1)), f'{all_retrieved_act.shape=}, {all_retrieved_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}, {num_contexts-1=}'
|
635 |
+
|
636 |
+
# Batch query data (already tensors) # query data is already int32/float32 after processing
|
637 |
+
all_query_act = torch.stack([all_processed[idx][self.act_key][0] for idx in range(num_envs)])
|
638 |
+
all_query_obs = np.stack([all_processed[idx][self.obs_key][0] for idx in range(num_envs)])
|
639 |
+
all_query_rew = torch.stack([all_processed[idx][self.rew_key][0] for idx in range(num_envs)])
|
640 |
+
assert (((self.act_key == 'continuous_actions' and all_query_act.shape == (num_envs, 1, self.act_dim)) or
|
641 |
+
(self.act_key == 'discrete_actions' and all_query_act.shape == (num_envs, 1))) and
|
642 |
+
all_query_obs.shape == (num_envs, 1, *self.raw_obs_dim) and
|
643 |
+
all_query_rew.shape == (num_envs, 1)), f'{all_query_act.shape=}, {all_query_obs.shape=}, {all_query_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}'
|
644 |
+
|
645 |
+
# Collect attn
|
646 |
+
attn_weights = np.ones((num_envs, num_contexts)).astype(np.float32)
|
647 |
+
|
648 |
+
# Compute exp_lamda_distances
|
649 |
+
exp_lamda_distances = np.exp(-self.lamda * all_distances)[:, :, np.newaxis]
|
650 |
+
assert exp_lamda_distances.shape == (num_envs, num_contexts, 1), f'{exp_lamda_distances.shape=}, {num_envs=}, {num_contexts=}'
|
651 |
+
|
652 |
+
# Compute extra_key
|
653 |
+
all_extra_key = []
|
654 |
+
for idx in range(num_envs):
|
655 |
+
RandP_action = all_retrieved_act[idx, 0]
|
656 |
+
if self.extra_key == 'continuous_RandP_actions':
|
657 |
+
extra_key = [RandP_action for _ in range(num_contexts)]
|
658 |
+
elif self.extra_key == 'discrete_RandP_action_logits':
|
659 |
+
extra_key = []
|
660 |
+
for d in all_distances[idx]:
|
661 |
+
d = min(1.0, max(0.0, d))
|
662 |
+
curr_logits = [1.0/self.N * d for _ in range(self.N)]
|
663 |
+
curr_logits[RandP_action] = (1.0 + (self.N - 1.0)*(1.0 - d))/self.N
|
664 |
+
extra_key.append(curr_logits)
|
665 |
+
extra_key = np.stack(extra_key)
|
666 |
+
all_extra_key.append(extra_key)
|
667 |
+
all_extra_key = np.stack(all_extra_key).astype(np.float32)
|
668 |
+
|
669 |
+
if self.extra_key == 'continuous_RandP_actions':
|
670 |
+
assert all_extra_key.shape == (num_envs, num_contexts, self.act_dim), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.act_dim=}'
|
671 |
+
elif self.extra_key == 'discrete_RandP_action_logits':
|
672 |
+
assert all_extra_key.shape == (num_envs, num_contexts, self.N), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.N=}'
|
673 |
+
|
674 |
+
# Tensorify
|
675 |
+
all_retrieved_act = torch.from_numpy(all_retrieved_act)
|
676 |
+
all_retrieved_rew = torch.from_numpy(all_retrieved_rew)
|
677 |
+
attn_weights = torch.from_numpy(attn_weights).to(self.device)
|
678 |
+
exp_lamda_distances = torch.from_numpy(exp_lamda_distances).to(self.device)
|
679 |
+
all_extra_key = torch.from_numpy(all_extra_key).to(self.device)
|
680 |
+
|
681 |
+
# Concat retrieved and query batches
|
682 |
+
all_act = torch.cat([all_retrieved_act, all_query_act], dim=1).to(self.device)
|
683 |
+
all_obs = np.concatenate([all_retrieved_obs, all_query_obs], axis=1)
|
684 |
+
if self.use_atari_embeddings and self.task.startswith("atari"):
|
685 |
+
all_obs = all_obs.reshape(num_envs * num_contexts, *self.raw_obs_dim)
|
686 |
+
all_obs = process_row_of_obs_atari_full_without_mask(all_obs)
|
687 |
+
all_obs = torch.from_numpy(all_obs).to(self.device)
|
688 |
+
with torch.no_grad():
|
689 |
+
all_obs = self.emb_model_full(self.emb_transform(all_obs)).reshape(num_envs, num_contexts, *self.emb_dim_full)
|
690 |
+
else:
|
691 |
+
all_obs = torch.from_numpy(all_obs).to(self.device)
|
692 |
+
all_rew = torch.cat([all_retrieved_rew, all_query_rew], dim=1).to(self.device)
|
693 |
+
|
694 |
+
# Collect action_space, deterministic from all_processed
|
695 |
+
all_action_space = [all_processed[idx]['action_space'] for idx in range(num_envs)]
|
696 |
+
all_deterministic = [all_processed[idx]['deterministic'] for idx in range(num_envs)]
|
697 |
+
## assert that all action_space and deterministic are same for all envs
|
698 |
+
assert all([action_space == all_action_space[0] for action_space in all_action_space]), f'{all_action_space=}'
|
699 |
+
assert all([deterministic == all_deterministic[0] for deterministic in all_deterministic]), f'{all_deterministic=}'
|
700 |
+
## then just use first one!
|
701 |
+
action_space = all_action_space[0]
|
702 |
+
deterministic = all_deterministic[0]
|
703 |
+
|
704 |
+
# Forward pass
|
705 |
+
if self.use_atari_embeddings and self.task.startswith("atari"):
|
706 |
+
final_obs_key = 'continuous_observations'
|
707 |
+
else:
|
708 |
+
final_obs_key = self.obs_key
|
709 |
+
outputs = self.forward(**{final_obs_key: all_obs,
|
710 |
+
self.act_key: all_act,
|
711 |
+
self.rew_key: all_rew,
|
712 |
+
self.attn_key: attn_weights,
|
713 |
+
'exp_lamda_distances': exp_lamda_distances,
|
714 |
+
self.extra_key: all_extra_key,
|
715 |
+
}, return_loss=False)
|
716 |
+
|
717 |
+
# Return the predicted action
|
718 |
+
if self.act_key == 'continuous_actions':
|
719 |
+
self.last_continuous_action = outputs.pred_actions[:, -1].cpu().numpy()
|
720 |
+
|
721 |
+
assert self.last_continuous_action.shape == (num_envs, self.act_dim), f'{self.last_continuous_action.shape=}, {num_envs=}, {self.act_dim=}'
|
722 |
+
|
723 |
+
myprint(f'L2dist(RandP action, Pred action): {[L2dist(all_retrieved_act[idx, 0].cpu().numpy(), self.last_continuous_action[idx]) for idx in range(num_envs)]}')
|
724 |
+
self.last_continuous_action = list(self.last_continuous_action) # list of arrays
|
725 |
+
return self.last_continuous_action if not return_retrieved_obs else (self.last_continuous_action, all_retrieved_images)
|
726 |
+
|
727 |
+
elif self.act_key == 'discrete_actions':
|
728 |
+
act_n = self.config.action_vocab_size if (self.task.startswith('atari') and self.use_global_atari_actions) else action_space.n
|
729 |
+
logits = outputs.pred_actions[:, -1, : act_n]
|
730 |
+
assert logits.shape == (num_envs, act_n), f'{logits.shape=}, {num_envs=}, {act_n=}'
|
731 |
+
if deterministic:
|
732 |
+
# myprint(f'{all_extra_key[:, -1, : action_space.n]=}')
|
733 |
+
# myprint(f'{logits=}')
|
734 |
+
self.last_discrete_action = logits.argmax(dim=-1, keepdim=True).cpu().numpy().reshape(-1)
|
735 |
+
else: # sample
|
736 |
+
self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1).cpu().numpy().reshape(-1)
|
737 |
+
|
738 |
+
assert self.last_discrete_action.shape == (num_envs,), f'{self.last_discrete_action.shape=}, {num_envs=}'
|
739 |
+
|
740 |
+
self.last_discrete_action = list(self.last_discrete_action) # list of ints
|
741 |
+
myprint(f'RandP action: {all_retrieved_act[:, 0].cpu().numpy().tolist()} vs Pred action: {self.last_discrete_action}')
|
742 |
+
|
743 |
+
if self.task.startswith("atari") and self.use_global_atari_actions:
|
744 |
+
self.last_discrete_action = [convert_global_to_local_action(a, self.task) for a in self.last_discrete_action]
|
745 |
+
myprint(f'[IN LOCAL ACTION] RandP action: {[convert_global_to_local_action(a, self.task) for a in all_retrieved_act[:, 0].cpu().numpy().tolist()]} vs Pred action: {self.last_discrete_action}')
|
746 |
+
myprint(f'[IN LOCAL ACTION] diff: {[convert_global_to_local_action(a, self.task) - b for a, b in zip(all_retrieved_act[:, 0].cpu().numpy().tolist(), self.last_discrete_action)]}')
|
747 |
+
|
748 |
+
return self.last_discrete_action if not return_retrieved_obs else (self.last_discrete_action, all_retrieved_images)
|
749 |
+
|
750 |
+
|
751 |
+
JatRegentModel.register_for_auto_class("AutoModelForCausalLM")
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f64fcecd190c7ed6c8c913e44d0ecc47ceaca4d52a84d1e96d18ebe985db8ef5
|
3 |
+
size 510060470
|