BitNEt 250 M trained on 7B tokens on PubMed + Clinical dataset
Inference code:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import *
# Load a pretrained BitNet model
model = "Siddharth63/Bitnet-250M"
tokenizer = AutoTokenizer.from_pretrained(model)
model = AutoModelForCausalLM.from_pretrained(model)
def activation_quant(x):
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127)
y = y / scale
return y
def weight_quant(w):
scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
u = (w * scale).round().clamp_(-1, 1)
u = u / scale
return u
class BitLinear(nn.Linear):
def forward(self, x):
w = self.weight # a weight tensor with shape [d, k]
x = x.to(w.device)
RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
x_norm = RMSNorm(x)
# A trick for implementing Straight−Through−Estimator (STE) using detach()
x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()
y = F.linear(x_quant, w_quant)
return y
def convert_to_bitnet(model, copy_weights):
for name, module in model.named_modules():
# Replace linear layers with BitNet
if isinstance(module, LlamaSdpaAttention) or isinstance(module, LlamaMLP):
for child_name, child_module in module.named_children():
if isinstance(child_module, nn.Linear):
bitlinear = BitLinear(child_module.in_features, child_module.out_features, child_module.bias is not None).to(device="cuda:0")
if copy_weights:
bitlinear.weight = child_module.weight
if child_module.bias is not None:
bitlinear.bias = child_module.bias
setattr(module, child_name, bitlinear)
# Remove redundant input_layernorms
elif isinstance(module, LlamaDecoderLayer):
for child_name, child_module in module.named_children():
if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
setattr(module, child_name, nn.Identity().to(device="cuda:0"))
convert_to_bitnet(model, copy_weights=True)
model.to(device="cuda:0")
prompt = "Atherosclerosis is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=50)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- Downloads last month
- 215
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.