metadata
license: apache-2.0
pipeline_tag: text-generation
library_name: transformers

Grok-1 (PyTorch Version)

This repository contains the model code and weights for the PyTorch version of the Grok-1 open-weights model. A complete example of using the PyTorch version of Grok-1 is available in the ColossalAI GitHub repository. We also apply parallelism techniques from the ColossalAI framework (currently tensor parallelism) to accelerate inference.
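To illustrate the general idea behind tensor parallelism, here is a minimal single-process sketch of a column-parallel linear layer: the weight matrix is split along its output features, each simulated "rank" computes a slice of the output, and the slices are concatenated. This is an assumption-labeled illustration of the technique only, not ColossalAI's implementation or API; the function name `column_parallel_linear` is hypothetical, and in a real multi-GPU setup the concatenation would be an all-gather across devices.

```python
import torch


def column_parallel_linear(x: torch.Tensor, weight: torch.Tensor, num_shards: int) -> torch.Tensor:
    """Simulate a column-parallel linear layer (y = x @ weight.T) on one device.

    Hypothetical helper. `weight` has shape (out_features, in_features), as in
    torch.nn.Linear. Each simulated rank owns one slice of the output features
    and computes its partial output independently; concatenating the partial
    outputs plays the role of the all-gather a multi-GPU implementation performs.
    """
    shards = torch.chunk(weight, num_shards, dim=0)  # split output features across ranks
    partial_outputs = [x @ w.T for w in shards]      # one local matmul per rank
    return torch.cat(partial_outputs, dim=-1)        # stand-in for all-gather


# Usage: the sharded result matches the unsharded layer (up to float rounding).
x = torch.randn(2, 8)
weight = torch.randn(6, 8)
y = column_parallel_linear(x, weight, num_shards=3)
print(torch.allclose(y, x @ weight.T, atol=1e-5))
```

Splitting along output features means each rank needs the full input but only its own weight slice, which is what lets the weights of a large layer be spread over several GPUs.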

The original weights released by xAI are available on Hugging Face, and the original model is in the Grok open-release GitHub repository.

Conversion

We translated the original modeling code from JAX into PyTorch and converted the weights by mapping the tensor files to parameter keys, de-quantizing each tensor with its corresponding packed scales, and saving the result to a checkpoint file with torch APIs.
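The de-quantization and saving steps can be sketched as follows. This is a minimal sketch under stated assumptions, not the actual conversion script: it assumes each original tensor is stored as int8 values plus a scales tensor that broadcasts against it (for example, one scale per output channel), so de-quantization is an element-wise multiply. `KEY_MAP`, `dequantize`, and `convert` are hypothetical names, and the key strings in `KEY_MAP` are placeholders, since the real parameter keys are not listed here.

```python
import torch

# Hypothetical mapping from original (JAX) parameter paths to PyTorch state_dict keys.
# Placeholder names only; the real keys used in the conversion are not given here.
KEY_MAP = {
    "layer_0/linear/w": "layers.0.linear.weight",
}


def dequantize(weight_int8: torch.Tensor, scales: torch.Tensor,
               dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Recover a floating-point tensor from int8 values and their packed scales.

    Assumes `scales` broadcasts against `weight_int8`, so de-quantization is a
    single element-wise multiply in the target dtype.
    """
    return weight_int8.to(dtype) * scales.to(dtype)


def convert(quantized: dict, path: str) -> dict:
    """Map keys, de-quantize every (weight, scales) pair, and save with torch.save."""
    state_dict = {KEY_MAP[key]: dequantize(w, s) for key, (w, s) in quantized.items()}
    torch.save(state_dict, path)
    return state_dict
```

For example, `dequantize(torch.tensor([[1, -2], [3, 4]], dtype=torch.int8), torch.tensor([[0.5, 2.0]]))` scales the first column by 0.5 and the second by 2.0.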

The PyTorch-version model should be used with the original tokenizer (i.e. tokenizer.model from the Grok-1 GitHub repository).

Usage

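```python
import torch
from transformers import AutoModelForCausalLM
from sentencepiece import SentencePieceProcessor

# Load the model in bfloat16 and let transformers place it across the available GPUs.
torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "hpcaitech/grok-1",
    trust_remote_code=True,  # the modeling code is shipped with the checkpoint
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# tokenizer.model is the original SentencePiece tokenizer from the Grok-1 GitHub repository.
sp = SentencePieceProcessor(model_file="tokenizer.model")

text = "Replace this with your text"
input_ids = sp.encode(text)
input_ids = torch.tensor([input_ids]).cuda()  # batch containing a single sequence
attention_mask = torch.ones_like(input_ids)
generate_kwargs = {}  # Add any additional args if you want
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    **generate_kwargs,
}
outputs = model.generate(**inputs)
# Decode the generated token ids (prompt included) back to text.
print(sp.decode(outputs[0].tolist()))
```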
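The `generate_kwargs` placeholder above can carry standard `generate` options. The sketch below wraps the input-building steps in a hypothetical helper, `build_generate_inputs`, and passes a few illustrative sampling arguments (`max_new_tokens`, `do_sample`, `temperature`, `top_p`). It assumes the remote Grok-1 modeling code uses the standard transformers `generate`; the values are arbitrary examples, not settings recommended for Grok-1.

```python
import torch


def build_generate_inputs(token_ids, device="cpu", **generate_kwargs):
    """Hypothetical helper: batch one tokenized prompt and merge generation options."""
    input_ids = torch.tensor([token_ids], device=device)  # shape (1, seq_len)
    attention_mask = torch.ones_like(input_ids)           # attend to every prompt token
    return {"input_ids": input_ids, "attention_mask": attention_mask, **generate_kwargs}


# Illustrative sampling options (arbitrary values, not tuned for Grok-1).
inputs = build_generate_inputs(
    [1, 2, 3], max_new_tokens=64, do_sample=True, temperature=0.7, top_p=0.95
)
print(inputs["input_ids"].shape, inputs["max_new_tokens"])
```

In the full script, `token_ids` would be `sp.encode(text)`, `device` would be `"cuda"`, and the result would be passed as `model.generate(**inputs)`.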

You can also use Xenova/grok-1-tokenizer, a transformers-compatible version of the tokenizer:

```python
from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained('Xenova/grok-1-tokenizer')
inputs = tokenizer('hello world')  # BatchEncoding with input_ids and attention_mask
```
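A sketch of using this tokenizer end to end with the model follows. It assumes `model` is the Grok-1 model loaded in the Usage example and `tokenizer` is the `LlamaTokenizerFast` above; `generate_text` is a hypothetical helper, and decoding only the tokens after the prompt is a design choice of this sketch, not something the original example does.

```python
import torch


def generate_text(model, tokenizer, prompt: str, **generate_kwargs) -> str:
    """Hypothetical helper: tokenize, generate, and decode only the new tokens.

    Assumes `model` exposes `.device` and `.generate(...)` like a transformers
    model, and `tokenizer` is a transformers tokenizer such as LlamaTokenizerFast.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Move every input tensor to the device holding the model's first parameters.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model.generate(**encoded, **generate_kwargs)
    prompt_length = encoded["input_ids"].shape[1]
    new_tokens = outputs[0, prompt_length:]  # drop the echoed prompt tokens
    return tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True)
```

With the model loaded, a call might look like `generate_text(model, tokenizer, "hello world", max_new_tokens=32)`.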

Note: A multi-GPU machine is required to run the example code (currently, a machine with 8 × 80 GB GPUs).
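A rough back-of-the-envelope estimate shows why so much GPU memory is needed. Assuming Grok-1's published size of 314B parameters and 2 bytes per parameter for the de-quantized bfloat16 weights, the weights alone take about 628 GB, or about 78.5 GB per GPU when split evenly over 8 GPUs. This sketch ignores activations, the KV cache, and framework overhead, so it is a lower bound rather than an official sizing; `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate weight storage in decimal gigabytes (1 GB = 1e9 bytes).

    bytes_per_param=2 corresponds to bfloat16. Activations, the KV cache,
    and framework overhead are deliberately ignored.
    """
    return num_params * bytes_per_param / 1e9


GROK1_PARAMS = 314e9  # Grok-1's published parameter count
total_gb = weight_memory_gb(GROK1_PARAMS)  # 628.0 GB of bf16 weights
per_gpu_gb = total_gb / 8                  # 78.5 GB per GPU if split evenly over 8 GPUs
print(total_gb, per_gpu_gb)
```

At roughly 78.5 GB of weights per GPU, very little headroom remains on an 80 GB card, which is consistent with the 8-GPU requirement.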