ITCL commited on
Commit
59f6375
·
verified ·
1 Parent(s): 45ff9be

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +161 -0
README.md ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - bitnet
7
+ datasets:
8
+ - abideen/Cosmopedia-100k-pretrain
9
+ ---
10
+
11
+
12
+ # Bitnet-Nous-Llama3-225M 🚀
13
+
14
+ Este modelo es una variante optimizada del **Llama3** utilizando la arquitectura **BitNet**, lo que reduce los pesos a los valores `-1`, `0`, y `1` para mejorar la eficiencia en el cómputo sin perder precisión.
15
+
16
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/66b0ba742cf20f2528a916bd/vtbKlK5l6yuj5uyJkAEgg.png)
17
+
18
+ ## Modelo Base 🦙
19
+
20
+ - **Modelo Original**: [Meta-Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
21
+ - **Parámetros Reducidos**: 225M
22
+
23
+ ## Arquitectura 🔧
24
+
25
+ El modelo transforma las capas lineales de Llama3 en capas **BitLinear**, aprovechando las siguientes técnicas de cuantización:
26
+
27
+ - **Cuantización de activaciones**: Escala a ±127
28
+ - **Cuantización de pesos**: Escala a ±1
29
+
30
+ ### Especificaciones Técnicas 📋
31
+
32
+ - **Dimensiones**: 768
33
+ - **Capas**: 6
34
+ - **Contexto**: 256 tokens
35
+ - **Tamaño intermedio**: 1024
36
+ - **Número de cabezas de atención**: 6
37
+
38
+
39
+ ## Dataset 📚
40
+
41
+ El modelo fue entrenado usando el dataset [Cosmopedia-100k-pretrain](https://huggingface.co/datasets/abideen/Cosmopedia-100k-pretrain), que contiene una variedad de datos de texto.
42
+
43
+ ## Entrenamiento ⚙️
44
+
45
+ El modelo fue entrenado con la siguiente configuración:
46
+
47
+ - **Lote**: 16
48
+ - **Tasa de aprendizaje**: 1.5e-4
49
+ - **Épocas**: 2
50
+ - **Acumulación de gradientes**: 2 pasos
51
+ - **Decaimiento de pesos**: 0.01
52
+ - **Precisión Mixta**: FP16
53
+
54
+ ### Monitoreo 📊
55
+
56
+ El proceso de entrenamiento fue monitoreado usando **Weights & Biases**.
57
+
58
+ ## Uso del Modelo 💻
59
+
60
+ Para usar este modelo, puedes cargarlo desde Hugging Face con el siguiente código:
61
+ ```python
62
+ from transformers import AutoModelForCausalLM, AutoTokenizer
63
+ from transformers.models.llama.modeling_llama import *
64
+ import torch
65
+ from torch import nn
66
+ import torch.nn.functional as F
67
+ import coloredlogs
68
+ import logging
69
+
70
+ from utils.utils import count_parameters
71
+
72
+ coloredlogs.install(level='INFO', fmt='%(asctime)s - %(levelname)s - %(message)s', logger=logging.getLogger())
73
+ logger = logging.getLogger(__name__)
74
+
75
+
76
+
77
+
78
+ HF_TOKEN = "tuclaveaqui"
79
+ #model = "ejbejaranos/Bitnet-Llama3-from8BM-now2B"
80
+ model = "ejbejaranos/Bitnet-Nous-Llama3-225M" ## Working
81
+
82
+ # Load a pretrained BitNet model
83
+ tokenizer = AutoTokenizer.from_pretrained(model)
84
+
85
+ model = AutoModelForCausalLM.from_pretrained(
86
+ model,
87
+ token=HF_TOKEN
88
+ )
89
+
90
+
91
+ def count_parameters(model):
92
+ # Calculate the number of parameters in billions
93
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 10**9
94
+ print(f"Model size: {num_params:.3f}B parameters")
95
+ return int(num_params)
96
+
97
+
98
+
99
+ # Establece el pad_token_id
100
+ model.config.pad_token_id = tokenizer.eos_token_id
101
+
102
+ def activation_quant(x):
103
+ scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
104
+ y = (x * scale).round().clamp_(-128, 127)
105
+ y = y / scale
106
+ return y
107
+
108
+ def weight_quant(w):
109
+ scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
110
+ u = (w * scale).round().clamp_(-1, 1)
111
+ u = u / scale
112
+ return u
113
+
114
+ class BitLinear(nn.Linear):
115
+ def forward(self, x):
116
+ w = self.weight # a weight tensor with shape [d, k]
117
+ x = x.to(w.device)
118
+ RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
119
+ x_norm = RMSNorm(x)
120
+ x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
121
+ w_quant = w + (weight_quant(w) - w).detach()
122
+ y = F.linear(x_quant, w_quant)
123
+ return y
124
+
125
+ def convert_to_bitnet(model, copy_weights):
126
+ for name, module in model.named_modules():
127
+ if isinstance(module, LlamaSdpaAttention) or isinstance(module, LlamaMLP):
128
+ for child_name, child_module in module.named_children():
129
+ if isinstance(child_module, nn.Linear):
130
+ bitlinear = BitLinear(child_module.in_features, child_module.out_features, child_module.bias is not None).to(device="cuda:0")
131
+ if copy_weights:
132
+ bitlinear.weight = child_module.weight
133
+ if child_module.bias is not None:
134
+ bitlinear.bias = child_module.bias
135
+ setattr(module, child_name, bitlinear)
136
+ elif isinstance(module, LlamaDecoderLayer):
137
+ for child_name, child_module in module.named_children():
138
+ if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
139
+ setattr(module, child_name, nn.Identity().to(device="cuda:0"))
140
+
141
+ convert_to_bitnet(model, copy_weights=True)
142
+ model.to(device="cuda:0")
143
+
144
+
145
+ logger.info(f"🔢 Number of parameters in the model after extracting weights: {count_parameters(model)}")
146
+ logger.info(f"📏 Reduced model structure:\n{model}")
147
+
148
+
149
+
150
+
151
+
152
+ prompt = "What is Machine Learning?"
153
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
154
+ inputs['attention_mask'] = inputs['input_ids'] != model.config.pad_token_id
155
+
156
+ generate_ids = model.generate(inputs.input_ids, attention_mask=inputs['attention_mask'], max_length=250)
157
+ decoded_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
158
+
159
+ print(decoded_output[0]) # Print the generated response
160
+
161
+ ```