cassuto's picture
Update README.md
7d6bf6f verified
|
raw
history blame
5.74 kB
metadata
datasets:
  - facebook/anli
metrics:
  - accuracy
base_model:
  - meta-llama/Llama-3.1-8B-Instruct
pipeline_tag: sentence-similarity
library_name: peft
tags:
  - NLI

Model Card for Llama-3.1-ANLI-R1-R2-R3-8B-Instruct

The Meta Llama-3.1-8B-Instruct model fine-tuned on the Adversarial Natural Language Inference (ANLI) Benchmark.

Evaluation Results

Accuracy:

ANLI-R1 ANLI-R2 ANLI-R3 Avg.
77.2 62.8 61.2 67.1

Usage

NLI use-case:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
from peft import PeftModel
import torch

# Load the base checkpoint and its tokenizer, then attach the fine-tuned
# LoRA adapter on top of the base model.
base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    pad_token_id=tokenizer.eos_token_id,
    device_map='auto',
)

lora_model = PeftModel.from_pretrained(
    model,
    'cassuto/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct',
)

# A label's position in this list is its integer label id.
label_str = ['entailment', 'neutral', 'contradiction']

def eval(premise: str, hypothesis: str, device='cuda'):
    """Predict the NLI label for a premise/hypothesis pair.

    Builds a system/user/assistant prompt around the two sentences, lets the
    module-level ``lora_model`` generate up to 10 new tokens, prints the full
    decoded sequence, and maps the newly generated text to a label index.

    NOTE: this name shadows the builtin ``eval``; it is kept unchanged so
    existing callers of this snippet keep working.

    Args:
        premise: The premise sentence.
        hypothesis: The hypothesis sentence to judge against the premise.
        device: Torch device the input tensors are moved to before generation.

    Returns:
        The index into ``label_str`` of the first label (in ``label_str``
        order) whose text occurs in the generated continuation.

    Raises:
        AssertionError: If the continuation contains none of the labels.
    """
    prompt = ("<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral." +
              "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
              "<Premise>: " + premise + "\n\n<Hypothesis>: " + hypothesis + "\n\n" +
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
    encoded = tokenizer(prompt)

    with torch.no_grad():
        # Add a leading batch dimension of 1 to the flat token lists.
        input_ids = torch.tensor(encoded['input_ids']).unsqueeze(0).to(device)
        attention_mask = torch.tensor(encoded['attention_mask']).unsqueeze(0).to(device)
        out = lora_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=10,
        )

    print(tokenizer.decode(out[0]))
    # Drop the prompt tokens so only the newly generated text is searched.
    completion = tokenizer.decode(out[0][input_ids.shape[-1]:])
    for lbl, name in enumerate(label_str):
        if name in completion:
            return lbl

    # Raised explicitly (rather than ``assert False``) so the check is not
    # stripped when Python runs with -O, which would silently return None.
    raise AssertionError('Invalid model output: ' + completion)

# Example call: prints the label index that eval() returns for this pair.
print(eval("A man is playing a guitar.", "A woman is reading a book."))

Training Details

  • Dataset: facebook/anli
  • Hardware: 1 × NVIDIA H20 (96 GB)

Fine Tuning Hyperparameters

  • Training regime: fp16
  • LoRA rank: 64
  • LoRA alpha: 16
  • LoRA dropout: 0.1
  • Learning rate: 0.0001
  • Training Batch Size: 4
  • Epochs: 3 per ANLI round
  • Context length: 2048

Fine Tuning Code

from datasets import load_dataset
import numpy as np

# Load the ANLI dataset; its train_r{n}/dev_r{n} splits are indexed per
# round in the training loop.
dataset = load_dataset("anli")

# Base checkpoint that the LoRA adapter is trained on top of.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
def out_ckp(r):
    """Return the trainer checkpoint directory used for round ``r``."""
    project_dir = "/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct"
    return "{}/checkpoints-r{}".format(project_dir, r)
def out_lora_model_fn(r):
    """Return the directory where the LoRA model of round ``r`` is saved."""
    project_dir = '/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct'
    return '{}/lora-r{}'.format(project_dir, r)

from transformers import AutoModelForCausalLM, AutoTokenizer

from transformers import AutoTokenizer, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer
import torch
from collections.abc import Mapping

# A label's position in this list is its integer label id in the dataset rows.
label_str = ['entailment', 'neutral', 'contradiction']


def preprocess_function(examples):
    """Render a batch of examples into complete training prompts.

    Args:
        examples: A batch mapping with parallel "premise", "hypothesis" and
            "label" sequences; each label is an index into ``label_str``.

    Returns:
        A dict with a single key, "text", holding one prompt string per
        example: system instruction, user turn with premise and hypothesis,
        then the assistant turn containing the gold label.
    """
    system_and_user_header = (
        "<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral."
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    )
    assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    texts = []
    rows = zip(examples["premise"], examples["hypothesis"], examples['label'])
    for premise, hypothesis, label in rows:
        user_turn = "<Premise>: " + premise + "\n\n<Hypothesis>: " + hypothesis + "\n\n"
        # FIXME remove \n
        answer = label_str[label] + "<|eot_id|>\n"
        texts.append(system_and_user_header + user_turn + assistant_header + answer)

    return {'text': texts}

tokenizer = AutoTokenizer.from_pretrained(model_name)

# pad_token_id has to be the integer token id (eos_token_id), not the token
# string (eos_token); this matches how the inference snippet loads the model.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
    device_map='auto',
)
# Turn off the config's use_cache flag for training.
model.config.use_cache = False
model.config.pretraining_tp = 1

# Pad on the right and reuse the EOS token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Train sequentially on ANLI rounds 1, 2 and 3. Every round after the first
# starts from the LoRA model saved at the end of the previous round.
for r in range(1,4):
    print('Round ', r)

    # Pick this round's splits and add the 'text' column produced by
    # preprocess_function (the column SFTConfig.dataset_text_field names below).
    train_data = dataset[f'train_r{r}']
    val_data = dataset[f'dev_r{r}']
    train_data = train_data.map(preprocess_function, batched=True,num_proc=8)
    # NOTE(review): val_data is prepared here but is not passed to SFTTrainer
    # below — confirm whether evaluation on the dev split was intended.
    val_data = val_data.map(preprocess_function, batched=True,num_proc=8)

    # Per-round trainer settings; checkpoints and logs go to round-specific
    # directories.
    training_args = SFTConfig(
        fp16=True,
        output_dir=out_ckp(r),
        learning_rate=1e-4,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        logging_steps=10,
        weight_decay=0,
        logging_dir=f"./logs-r{r}",
        save_strategy="epoch",
        save_total_limit=1,
        max_seq_length=2048,
        packing=False,
        dataset_text_field="text"
    )

    if r==1:
        # Round 1: wrap the base model with a newly configured LoRA adapter.
        peft_config = LoraConfig(
            r=64,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type='CAUSAL_LM'
        )
        lora_model = get_peft_model(model, peft_config)
    else:
        # Later rounds: load the LoRA model saved by round r-1 and keep it
        # trainable so fine-tuning continues from it.
        lora_model = PeftModel.from_pretrained(model, out_lora_model_fn(r-1),
            is_trainable=True)

    trainer = SFTTrainer(
        model=lora_model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
    )

    trainer.train()
    # Save this round's LoRA model to the path the next round loads from.
    print(f'saving to "{out_lora_model_fn(r)}"')
    lora_model.save_pretrained(out_lora_model_fn(r))
  • PEFT 0.13.2