File size: 2,250 Bytes
f37a8fc 51dac85 294a607 02b850c f37a8fc 02b850c f37a8fc 02b850c f37a8fc 294a607 f37a8fc 294a607 51dac85 f37a8fc 51dac85 294a607 f37a8fc 294a607 f37a8fc 294a607 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh
index dbf328a..c393191 100755
--- a/data/envs/metaworld/train_all.sh
+++ b/data/envs/metaworld/train_all.sh
@@ -4,7 +4,7 @@ ENVS=(
assembly
basketball
bin-picking
- box-close
+ #box-close
button-press-topdown
button-press-topdown-wall
button-press
diff --git a/gia/eval/callback.py b/gia/eval/callback.py
index 5c3a080..4b6198f 100644
--- a/gia/eval/callback.py
+++ b/gia/eval/callback.py
@@ -2,10 +2,10 @@ import glob
import json
import subprocess
-import wandb
from accelerate import Accelerator
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+import wandb
from gia.config import Arguments
from gia.eval.utils import is_slurm_available
diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
index ec5e5b2..3294471 100644
--- a/gia/eval/rl/envs/core.py
+++ b/gia/eval/rl/envs/core.py
@@ -180,7 +180,7 @@ def make(task_name: str, num_envs: int = 1):
import metaworld
env_id = TASK_TO_ENV_MAPPING[task_name]
- env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
+ env = gym.make(env_id)
else:
raise ValueError(f"Unknown task name: {task_name}")
diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
index f0d0b9b..255beda 100644
--- a/gia/eval/rl/gia_agent.py
+++ b/gia/eval/rl/gia_agent.py
@@ -54,7 +54,7 @@ class GiaAgent:
self.action_space = action_space
self.deterministic = deterministic
self.device = next(model.parameters()).device
- self._max_length = self.model.config.max_position_embeddings - 10
+ self._max_length = self.model.config.max_position_embeddings - 100
if isinstance(observation_space, spaces.Box):
self._observation_key = "continuous_observations"
diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
index f8531ee..71e0fdc 100644
--- a/gia/eval/rl/gym_evaluator.py
+++ b/gia/eval/rl/gym_evaluator.py
@@ -1,7 +1,6 @@
import gym
from gym.vector.vector_env import VectorEnv
-from gia.eval.mappings import TASK_TO_ENV_MAPPING
from gia.eval.rl.rl_evaluator import RLEvaluator
|