EasyDeL

EasyDeL is an open-source framework that streamlines the training of machine learning models. Built primarily on JAX, it provides convenient, efficient tools for training and serving Flax/JAX models on TPUs and GPUs.

Usage Examples

Loading from EasyDeLState (*.easy files)

from easydel import EasyDeLState, AutoShardAndGatherFunctions
from jax import numpy as jnp, lax

shard_fns, gather_fns = AutoShardAndGatherFunctions.from_pretrained(
    "REPO_ID",  # Save the PyTorch state to this repo so the shard/gather functions can be found automatically; otherwise, see the docs.
    backend="gpu",
    depth_target=["params", "params"],
    flatten=False
)

state = EasyDeLState.load_state(
    "*.easy",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    verbose=True,
    state_shard_fns=shard_fns
)
# The state is ready to use.
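The `state_shard_fns` argument takes a tree of per-parameter functions that mirrors the structure of the parameters, and `depth_target=["params", "params"]` appears to select the nested sub-tree where those parameters live. The plain-Python sketch below illustrates that idea with hypothetical helpers (`apply_shard_fns`, `select_depth`, `tag`) and toy data; it is not EasyDeL's implementation, and real shard functions place arrays on devices rather than tagging them.

```python
from typing import Any, Callable, Dict, List


# Hypothetical helper: walk a nested dict of shard functions in lockstep with a
# nested dict of parameters and apply each function to its matching leaf.
def apply_shard_fns(shard_fns: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in params.items():
        fn = shard_fns[key]
        if isinstance(value, dict):
            # Sub-tree: recurse; both trees must share the same structure.
            out[key] = apply_shard_fns(fn, value)
        else:
            # Leaf: apply the per-parameter function.
            out[key] = fn(value)
    return out


# Hypothetical helper: descend into a nested dict along a list of keys,
# similar in spirit to depth_target=["params", "params"].
def select_depth(tree: Dict[str, Any], path: List[str]) -> Dict[str, Any]:
    for key in path:
        tree = tree[key]
    return tree


# Toy stand-in for a shard function: it only records how a leaf would be placed.
def tag(name: str) -> Callable[[Any], Any]:
    return lambda x: (name, x)


state = {"params": {"params": {"dense": {"kernel": [1.0, 2.0], "bias": [0.0]}}}}
fns = {"dense": {"kernel": tag("sharded"), "bias": tag("replicated")}}

sharded = apply_shard_fns(fns, select_depth(state, ["params", "params"]))
print(sharded)
# {'dense': {'kernel': ('sharded', [1.0, 2.0]), 'bias': ('replicated', [0.0])}}
```

Because the function tree and the parameter tree share one structure, a single recursive walk is enough to pair every parameter with its function.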

Loading with AutoEasyDeLModelForCausalLM (from PyTorch)

from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax


model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
)
# The model and parameters are ready to use.
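Loading from PyTorch means converting the checkpoint's tensors into the layout Flax expects. One well-known difference: `torch.nn.Linear` stores its weight as `(out_features, in_features)`, while a Flax `nn.Dense` kernel is `(in_features, out_features)`, so the matrix must be transposed. The NumPy sketch below shows only that step, using a hypothetical `torch_linear_to_flax` helper; EasyDeL's actual conversion covers many more cases and is not reproduced here.

```python
import numpy as np


# Hypothetical helper: convert a torch.nn.Linear-style (weight, bias) pair into a
# Flax nn.Dense-style {"kernel", "bias"} dict. Only the layout change is shown.
def torch_linear_to_flax(weight: np.ndarray, bias: np.ndarray, dtype=np.float16) -> dict:
    # PyTorch: weight has shape (out_features, in_features); y = x @ W.T + b
    # Flax:    kernel has shape (in_features, out_features); y = x @ K + b
    return {"kernel": weight.T.astype(dtype), "bias": bias.astype(dtype)}


rng = np.random.default_rng(0)
w = rng.standard_normal((3, 4)).astype(np.float32)  # out_features=3, in_features=4
b = rng.standard_normal(3).astype(np.float32)
x = rng.standard_normal((2, 4)).astype(np.float32)  # batch of 2 inputs

flax_params = torch_linear_to_flax(w, b, dtype=np.float32)
torch_style = x @ w.T + b
flax_style = x @ flax_params["kernel"] + flax_params["bias"]
print(flax_params["kernel"].shape)  # (4, 3)
print(np.allclose(torch_style, flax_style))  # True
```

The default `dtype=np.float16` mirrors the `param_dtype=jnp.float16` setting used above; the usage example passes `np.float32` so the comparison is exact up to floating-point rounding.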

Loading with AutoEasyDeLModelForCausalLM (from EasyDeL)

from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax


model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
    from_torch=False
)
# The model and parameters are ready to use.
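All three examples pass `dtype`, `param_dtype`, and `precision`. Under Flax conventions (which EasyDeL presumably follows), `param_dtype` is the dtype the weights are stored in and `dtype` is the dtype computations run in. `lax.Precision("fastest")` is an alias for `lax.Precision.DEFAULT`, the fastest, lowest-precision matmul mode on the backend. The sketch below uses only core JAX calls, independent of EasyDeL, to show what these settings mean for a single matrix multiply.

```python
import jax
import jax.numpy as jnp
from jax import lax

# The string "fastest" is accepted as an alias for Precision.DEFAULT.
precision = lax.Precision("fastest")
print(precision == lax.Precision.DEFAULT)  # True

param_dtype = jnp.float16  # storage dtype for weights, as in the examples above
dtype = jnp.float16        # computation dtype

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_w, (4, 3)).astype(param_dtype)
x = jax.random.normal(key_x, (2, 4))  # float32 input

# Cast inputs and weights to the computation dtype, then multiply at the chosen precision.
y = jnp.dot(x.astype(dtype), kernel.astype(dtype), precision=precision)
print(y.dtype, y.shape)  # float16 (2, 3)
```

Storing and computing in float16 halves memory relative to float32; the trade-off is reduced numerical range and precision.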

Model size: 9.08B params · Tensor type: BF16 · Format: Safetensors