import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import torch.distributed
import transformers
import wandb
from transformers import Trainer
from datasets import load_dataset

IGNORE_INDEX = -100
EOT_TOKEN = "<|EOT|>"


def build_instruction_prompt(instruction: str):
    return '''
You are an AI programming assistant specializing in Solana blockchain development, utilizing the DeepSeek Coder model. You excel at writing, reviewing, and explaining Solana smart contracts, Rust programs, and web3 applications. For questions not related to Solana development or blockchain programming, you will politely decline to answer.
### Instruction:
{}
### Response:
'''.format(instruction.strip()).lstrip()


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-coder-6.7b-instruct")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,  # Increased for longer Solana code contexts
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Mask the prompt portion with IGNORE_INDEX so the loss is computed only on response tokens.
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_loss = float('inf')

    def compute_loss(self, model, inputs, return_outputs=False):
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True)

        # Log metrics to wandb (only if a run was initialized on this process).
        if wandb.run is not None and self.state.global_step % self.args.logging_steps == 0:
            wandb.log({
                "loss": loss.item(),
                "learning_rate": self.lr_scheduler.get_last_lr()[0],
                "epoch": self.state.epoch,
                "global_step": self.state.global_step
            })

        # Save best model. Note: this writes a full checkpoint every time the
        # per-step loss improves, which is expensive for a 6.7B-parameter model.
        if loss.item() < self.best_loss:
            self.best_loss = loss.item()
            self.save_model("best_model")

        return (loss, outputs) if return_outputs else loss


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Initialize wandb on the main process (local_rank is -1 for single-process runs).
    if training_args.local_rank in (-1, 0):
        wandb.init(
            project="solana-deepseek-coder",
            config={
                "model_name": model_args.model_name_or_path,
                "learning_rate": training_args.learning_rate,
                "batch_size": training_args.per_device_train_batch_size,
                "max_steps": training_args.max_steps,
                "warmup_steps": training_args.warmup_steps
            }
        )

    print('=' * 100)
    print(training_args)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True
    )
    print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
    print("BOS Token:", tokenizer.bos_token, tokenizer.bos_token_id)
    print("EOS Token:", tokenizer.eos_token, tokenizer.eos_token_id)
    if training_args.local_rank == 0:
        print("Load tokenizer from {} over.".format(model_args.model_name_or_path))

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto"  # Enable automatic device mapping
    )
    if training_args.local_rank == 0:
        print("Load model from {} over.".format(model_args.model_name_or_path))

    raw_train_datasets = load_dataset(
        'json',
        data_files=data_args.data_path,
        split="train",
        cache_dir=training_args.cache_dir
    )

    # In distributed runs, non-zero ranks wait here so rank 0 builds the tokenized cache first.
    if training_args.local_rank > 0:
        torch.distributed.barrier()

    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,
        desc="Running Encoding",
        fn_kwargs={"tokenizer": tokenizer}
    )

    if training_args.local_rank == 0:
        torch.distributed.barrier()

    if training_args.local_rank == 0:
        print("Training dataset samples:", len(train_dataset))
        for index in random.sample(range(len(train_dataset)), 3):
            print(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
            print(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

    trainer = CustomTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    if training_args.local_rank in (-1, 0):
        wandb.finish()


if __name__ == "__main__":
    train()
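# ---------------------------------------------------------------------------
# Usage notes (illustrative only; the script name, file paths, and hyperparameter
# values below are placeholders, not part of this repository).
#
# The file passed via --data_path is loaded with load_dataset('json', ...) and is
# expected to contain one record per training example with the "instruction" and
# "output" fields consumed by train_tokenize_function, e.g.:
#
#   {"instruction": "Write an Anchor instruction that initializes a counter account.",
#    "output": "use anchor_lang::prelude::*;\n..."}
#
# A minimal single-process launch might look like:
#
#   python finetune_solana.py \
#       --model_name_or_path deepseek-ai/deepseek-coder-6.7b-instruct \
#       --data_path ./solana_instructions.jsonl \
#       --output_dir ./output \
#       --num_train_epochs 2 \
#       --per_device_train_batch_size 4 \
#       --gradient_accumulation_steps 4 \
#       --learning_rate 2e-5 \
#       --warmup_steps 100 \
#       --logging_steps 10 \
#       --model_max_length 2048 \
#       --bf16 True
# ---------------------------------------------------------------------------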