Snake-V1

This model was trained on the Jumanji Snake environment.

Developed by: InstaDeep

Model Sources

How to use

See the Jumanji repo for the primary model and its requirements. Clone the repo and navigate to its root directory.

# Run from the repository root: upgrade pip, then install the training requirements and the package itself.
pip install --quiet -U pip -r requirements/requirements-train.txt .

Below is an example script for loading and running the Jumanji model:

import pickle
import joblib

import jax
from hydra import compose, initialize
from huggingface_hub import hf_hub_download


from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the Hydra config, selecting the Snake environment and the A2C agent.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=snake", "agent=a2c"])

# Fetch the saved training state from the Hugging Face Hub.
REPO_ID = "d-byrne/snake-v1_training_state"
FILENAME = "Snake-v1_training_state"

model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# SECURITY: unpickling can execute arbitrary code. Only load this file if you
# trust the repository it was downloaded from.
with open(model_weights, "rb") as f:
    training_state = pickle.load(f)

# Presumably the saved params are replicated across devices and this keeps a
# single copy -- confirm against jumanji.training.utils.first_from_device.
params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

# Wrap the environment functions once, outside the loops, instead of calling
# jax.jit again on every reset and every step.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Roll out a few episodes, recording every state for rendering.
NUM_EPISODES = 10
# Number of extra copies of the terminal frame appended after each episode.
NUM_FREEZE_FRAMES = 10

states = []
key = jax.random.PRNGKey(cfg.seed)
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # The policy is called on a batch, so add a leading axis of size 1 to
        # every leaf of the observation and strip it from the action again.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    states.extend([state] * NUM_FREEZE_FRAMES)

# Animate a GIF of all recorded states.
env.animate(states, interval=150).save("./snake.gif")

# Save a single frame as a PNG.
# NOTE: in a Jupyter notebook, run `%matplotlib inline` in a cell first; the
# magic is not valid syntax in a plain Python script, so it is omitted here.
import matplotlib.pyplot as plt

# Frame to export; clamped so a short rollout cannot raise IndexError.
PNG_FRAME_INDEX = 117
env.render(states[min(PNG_FRAME_INDEX, len(states) - 1)])
# Named after this environment (was "connector.png", a leftover from another
# model card), matching the "snake.gif" written above.
plt.savefig("snake.png", dpi=300)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.