---
datasets:
- facebook/anli
metrics:
- accuracy
base_model:
- meta-llama/Llama-3.1-8B-Instruct
pipeline_tag: sentence-similarity
library_name: peft
tags:
- NLI
- Textual Entailment
- Llama
---

# Description

The Meta Llama-3.1-8B-Instruct model fine-tuned with LoRA on the Adversarial Natural Language Inference (ANLI) training sets (rounds R1-R3) for textual entailment.

**Evaluation Results**

Accuracy on the ANLI test sets (%):

| ANLI-R1 | ANLI-R2 | ANLI-R3 | Avg. |
| ------- | ------- | ------- | ---- |
| 77.2    | 62.8    | 61.2    | 67.1 |

## Usage

NLI use case:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                             pad_token_id=tokenizer.eos_token_id,
                                             device_map='auto')

# Load the LoRA adapter on top of the base model
lora_model = PeftModel.from_pretrained(model, 'cassuto/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct')

# Label order matches facebook/anli: 0 = entailment, 1 = neutral, 2 = contradiction
label_str = ['entailment', 'neutral', 'contradiction']


def eval(premise: str, hypothesis: str, device='cuda'):
    # Build the same Llama-3.1 chat-style prompt that was used during fine-tuning
    prompt = ("<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral." +
              "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
              "<Premise>: " + premise + "\n\n<Hypothesis>: " + hypothesis + "\n\n" +
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
    tk = tokenizer(prompt)

    with torch.no_grad():
        input_ids = torch.tensor(tk['input_ids']).unsqueeze(0).to(device)
        out = lora_model.generate(input_ids=input_ids,
                                  attention_mask=torch.tensor(tk['attention_mask']).unsqueeze(0).to(device),
                                  max_new_tokens=10)
        print(tokenizer.decode(out[0]))
        # Decode only the newly generated tokens and map them to a label index
        s = tokenizer.decode(out[0][input_ids.shape[-1]:])
        for lbl, l in enumerate(label_str):
            if s.find(l) > -1:
                return lbl
        assert False, 'Invalid model output: ' + s


print(eval("A man is playing a guitar.", "A woman is reading a book."))
```
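
To check the reported numbers, the `eval()` helper above can be looped over the ANLI test splits. The following is a minimal sketch, not part of the original card: it assumes the `facebook/anli` split names (`test_r1` to `test_r3`) and label encoding (0 = entailment, 1 = neutral, 2 = contradiction), runs one example at a time (so it is slow), and will also print each decoded output via the helper.

```python
# Illustrative sketch: score eval() over the ANLI test splits one example at a time.
# facebook/anli labels are 0 = entailment, 1 = neutral, 2 = contradiction, matching label_str.
from datasets import load_dataset

anli = load_dataset("facebook/anli")

for r in range(1, 4):
    split = anli[f"test_r{r}"]
    correct = sum(eval(ex["premise"], ex["hypothesis"]) == ex["label"] for ex in split)
    print(f"ANLI-R{r} accuracy: {correct / len(split):.1%}")
```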

## Training Details

- **Dataset:** facebook/anli
- **Hardware:** NVIDIA H20 (96 GB) x 1

#### Fine-Tuning Hyperparameters

- **Training regime:** fp16
- **LoRA rank:** 64
- **LoRA alpha:** 16
- **LoRA dropout:** 0.1
- **Learning rate:** 0.0001
- **Training batch size:** 4
- **Epochs:** 3 (per ANLI round)
- **Context length:** 2048
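
As a rough sanity check on what these LoRA settings imply, PEFT can report the trainable-parameter budget of the wrapped model. This is a sketch only, not part of the original training script; it uses the same `LoraConfig` values listed above with PEFT's default target modules for Llama.

```python
# Sketch only: inspect how many trainable parameters rank 64 / alpha 16 adds to the 8B base.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", device_map="auto")
peft_config = LoraConfig(r=64, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
get_peft_model(base, peft_config).print_trainable_parameters()
# Prints trainable vs. total parameter counts for the adapter-wrapped model.
```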

#### Fine-Tuning Code

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("facebook/anli")

model_name = "meta-llama/Llama-3.1-8B-Instruct"

def out_ckp(r):
    return f"/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/checkpoints-r{r}"

def out_lora_model_fn(r):
    return f"/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/lora-r{r}"

# Label order matches facebook/anli: 0 = entailment, 1 = neutral, 2 = contradiction
label_str = ['entailment', 'neutral', 'contradiction']

def preprocess_function(examples):
    # Build one chat-formatted training string per (premise, hypothesis, label) triple
    inputs = ["<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral." +
              "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
              "<Premise>: " + p + "\n\n<Hypothesis>: " + h + "\n\n" +
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + label_str[lbl] + "<|eot_id|>\n"  # FIXME remove \n
              for p, h, lbl in zip(examples["premise"], examples["hypothesis"], examples['label'])]

    model_inputs = {}
    model_inputs['text'] = inputs
    return model_inputs

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             pad_token_id=tokenizer.eos_token_id,
                                             device_map='auto')
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Train sequentially on the three ANLI rounds, carrying the LoRA adapter forward each round
for r in range(1, 4):
    print('Round ', r)

    train_data = dataset[f'train_r{r}']
    val_data = dataset[f'dev_r{r}']
    train_data = train_data.map(preprocess_function, batched=True, num_proc=8)
    val_data = val_data.map(preprocess_function, batched=True, num_proc=8)

    training_args = SFTConfig(
        fp16=True,
        output_dir=out_ckp(r),
        learning_rate=1e-4,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        logging_steps=10,
        weight_decay=0,
        logging_dir=f"./logs-r{r}",
        save_strategy="epoch",
        save_total_limit=1,
        max_seq_length=2048,
        packing=False,
        dataset_text_field="text"
    )

    if r == 1:
        # Create the LoRA adapter on the first round
        peft_config = LoraConfig(
            r=64,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type='CAUSAL_LM'
        )
        lora_model = get_peft_model(model, peft_config)
    else:
        # Load the adapter trained in the previous round and keep training it
        lora_model = PeftModel.from_pretrained(model, out_lora_model_fn(r - 1),
                                               is_trainable=True)

    trainer = SFTTrainer(
        model=lora_model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
    )

    trainer.train()
    print(f'saving to "{out_lora_model_fn(r)}"')
    lora_model.save_pretrained(out_lora_model_fn(r))
```
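
After the third round, the final adapter can optionally be merged into the base weights so that inference no longer requires PEFT at runtime. The snippet below is a hedged sketch rather than part of the original pipeline; the output path is hypothetical.

```python
# Optional sketch (not part of the original script): fold the final LoRA adapter
# into the base model weights and save a standalone checkpoint.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto")
lora = PeftModel.from_pretrained(base, out_lora_model_fn(3))  # adapter saved after round 3
merged = lora.merge_and_unload()                              # bakes the LoRA deltas into the weights
merged.save_pretrained("/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/merged")     # hypothetical path
tokenizer.save_pretrained("/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/merged")  # hypothetical path
```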

#### Framework Versions

- PEFT 0.13.2