erfanzar committed on
Commit
8d139e1
·
verified ·
1 Parent(s): 445874a

Upload FlaxLlamaForCausalLM

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. README.md +79 -0
  3. config.json +100 -0
  4. easydel-model.parameters +3 -0
  5. generation_config.json +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ easydel-model.parameters filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - EasyDeL
4
+ - FlaxLlamaForCausalLM
5
+ - safetensors
6
+ - TPU
7
+ - GPU
8
+ - XLA
9
+ - Flax
10
+ ---
11
+ # EasyDeL/EasyDeL-Llama-3.1-8B-Instruct
12
+
13
+ [![EasyDeL](https://img.shields.io/badge/🤗_EasyDeL-0.0.80-blue.svg)](https://github.com/erfanzar/EasyDeL)
14
+ [![Model Type](https://img.shields.io/badge/Model_Type-FlaxLlamaForCausalLM-green.svg)](https://github.com/erfanzar/EasyDeL)
15
+
16
+ A model implemented using the EasyDeL framework, designed to deliver optimal performance for large-scale natural language processing tasks.
17
+
18
+ ## Overview
19
+
20
+ EasyDeL provides an efficient, highly optimized, and customizable machine learning model compatible with both GPU and TPU environments. Built with JAX, this model supports advanced features such as sharded model parallelism and customized kernels, making it suitable for distributed training and inference.
21
+
22
+ ## Features
23
+
24
+
25
+ - **Efficient Implementation**: Built with JAX/Flax for high-performance computation.
26
+ - **Multi-Device Support**: Optimized to run on TPU, GPU, and CPU environments, with support for sharding the model over 2^(1-1000+) devices.
27
+ - **Sharded Model Parallelism**: Supports model parallelism across multiple devices for scalability.
28
+ - **Customizable Precision**: Allows specification of floating-point precision for performance optimization.
29
+
30
+
31
+ ## Installation
32
+
33
+ To install EasyDeL, simply run:
34
+
35
+ ```bash
36
+ pip install easydel
37
+ ```
38
+
39
+ ## Usage
40
+
41
+ ### Loading the Pre-trained Model
42
+
43
+ To load a pre-trained version of the model with EasyDeL:
44
+
45
+ ```python
46
+ import easydel as ed
47
+ from jax import numpy as jnp, lax
48
+
49
+ max_length = None # can be set to use lower memory for caching
50
+
51
+ # Load model and parameters
52
+ model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
53
+ "EasyDeL/EasyDeL-Llama-3.1-8B-Instruct",
54
+ config_kwargs=dict(
55
+ use_scan_mlp=False,
56
+ attn_dtype=jnp.float16,
57
+ freq_max_position_embeddings=max_length,
58
+ mask_max_position_embeddings=max_length,
59
+ attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2
60
+ ),
61
+ dtype=jnp.float16,
62
+ param_dtype=jnp.float16,
63
+ precision=lax.Precision("fastest"),
64
+ auto_shard_params=True,
65
+ )
66
+ ```
67
+
68
+ ## Supported Tasks
69
+
70
+
71
+ [Need more information]
72
+
73
+
74
+ ## Limitations
75
+
76
+
77
+ - **Hardware Dependency**: Performance can vary significantly based on the hardware used.
78
+ - **JAX/Flax Setup Required**: The environment must support JAX/Flax for optimal use.
79
+ - **Experimental Features**: Some features (like custom kernel usage or ed-ops) may require additional configuration and tuning.
config.json ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_axis_name": "sp",
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "attn_mechanism": "vanilla",
9
+ "axis_dims": [
10
+ 1,
11
+ 1,
12
+ 1,
13
+ -1
14
+ ],
15
+ "axis_names": [
16
+ "dp",
17
+ "fsdp",
18
+ "tp",
19
+ "sp"
20
+ ],
21
+ "backend": null,
22
+ "bits": null,
23
+ "block_b": 1,
24
+ "block_k": 128,
25
+ "block_q": 128,
26
+ "bos_token_id": 128000,
27
+ "easy_method": "train",
28
+ "embd_pdrop": 0.0,
29
+ "eos_token_id": [
30
+ 128001,
31
+ 128008,
32
+ 128009
33
+ ],
34
+ "fcm_max_ratio": 0.0,
35
+ "fcm_min_ratio": 0.0,
36
+ "flash_attention_backward_pass_impl": "triton",
37
+ "freq_max_position_embeddings": 32768,
38
+ "gradient_checkpointing": "nothing_saveable",
39
+ "hardware_abstraction": false,
40
+ "head_dim": 128,
41
+ "hidden_act": "silu",
42
+ "hidden_size": 4096,
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 14336,
45
+ "mask_max_position_embeddings": 32768,
46
+ "max_position_embeddings": 131072,
47
+ "mlp_bias": false,
48
+ "model_type": "llama",
49
+ "num_attention_heads": 32,
50
+ "num_hidden_layers": 32,
51
+ "num_key_value_heads": 8,
52
+ "number_rep_kv": 1,
53
+ "pallas_k_block_size": null,
54
+ "pallas_m_block_size": null,
55
+ "pallas_n_block_size": null,
56
+ "partition_axis": [
57
+ [
58
+ "fsdp",
59
+ "dp"
60
+ ],
61
+ "sp",
62
+ "sp",
63
+ "tp",
64
+ "sp",
65
+ "tp",
66
+ null,
67
+ null,
68
+ null,
69
+ null,
70
+ "tp",
71
+ "sp",
72
+ null
73
+ ],
74
+ "platform": null,
75
+ "pretraining_tp": 1,
76
+ "quantize_kv_cache": false,
77
+ "resid_pdrop": 0.0,
78
+ "rms_norm_eps": 1e-05,
79
+ "rope_scaling": {
80
+ "factor": 8.0,
81
+ "high_freq_factor": 4.0,
82
+ "low_freq_factor": 1.0,
83
+ "original_max_position_embeddings": 8192,
84
+ "rope_type": "llama3"
85
+ },
86
+ "rope_theta": 500000.0,
87
+ "scan_attention_layers": false,
88
+ "scan_layers": true,
89
+ "scan_mlp_chunk_size": 1024,
90
+ "scan_ring_attention": true,
91
+ "shard_attention_computation": true,
92
+ "tie_word_embeddings": false,
93
+ "torch_dtype": "bfloat16",
94
+ "transformers_version": "4.46.2",
95
+ "use_cache": true,
96
+ "use_scan_mlp": false,
97
+ "use_sharded_kv_caching": true,
98
+ "use_sharding_constraint": false,
99
+ "vocab_size": 128256
100
+ }
easydel-model.parameters ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f7c3b1eb49bf5c2f8e9528f9cd8e35bdfc6fbcd8cda0a8722f99950d32ae76c
3
+ size 16060556584
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 128000,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 128001,
6
+ 128008,
7
+ 128009
8
+ ],
9
+ "temperature": 0.6,
10
+ "top_p": 0.9,
11
+ "transformers_version": "4.46.2"
12
+ }