---
datasets:
- facebook/anli
metrics:
- accuracy
base_model:
- meta-llama/Llama-3.1-8B-Instruct
pipeline_tag: sentence-similarity
library_name: peft
tags:
- NLI
- Textual Entailment
- Llama
---
# Description
A LoRA adapter for Meta Llama-3.1-8B-Instruct, fine-tuned on the Adversarial Natural Language Inference (ANLI) training sets (rounds R1-R3) for textual entailment.
**Evaluation Results**

Accuracy on the ANLI test sets (%):

| ANLI-R1 | ANLI-R2 | ANLI-R3 | Avg. |
| ------- | ------- | ------- | ---- |
| 77.2    | 62.8    | 61.2    | 67.1 |
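
The label convention used throughout this card is the one defined by the ANLI dataset itself (0 = entailment, 1 = neutral, 2 = contradiction), and each round ships its own train/dev/test split. A minimal sketch (not part of the original card) of inspecting those splits with the `datasets` library:

```python
from datasets import load_dataset

# Load all ANLI rounds; each round r provides train_r{r}, dev_r{r} and test_r{r} splits
anli = load_dataset("facebook/anli")
print(anli)  # shows split names and sizes

# Label ids follow the ANLI convention: 0 = entailment, 1 = neutral, 2 = contradiction
example = anli["test_r1"][0]
print(example["premise"], example["hypothesis"], example["label"])
```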
## Usage
NLI use-case:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                             pad_token_id=tokenizer.eos_token_id,
                                             device_map='auto')
# Load the LoRA adapter on top of the base model
lora_model = PeftModel.from_pretrained(model, 'cassuto/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct')

# Label order matches the ANLI convention: 0 = entailment, 1 = neutral, 2 = contradiction
label_str = ['entailment', 'neutral', 'contradiction']

def eval(premise: str, hypothesis: str, device='cuda'):
    # Build the prompt in the same Llama-3.1 chat format used during fine-tuning
    prompt = ("<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral." +
              "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
              "<Premise>: " + premise + "\n\n<Hypothesis>: " + hypothesis + "\n\n" +
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
    tk = tokenizer(prompt)
    with torch.no_grad():
        input_ids = torch.tensor(tk['input_ids']).unsqueeze(0).to(device)
        out = lora_model.generate(input_ids=input_ids,
                                  attention_mask=torch.tensor(tk['attention_mask']).unsqueeze(0).to(device),
                                  max_new_tokens=10)
    # Print the full decoded sequence (prompt + generated label) for inspection
    print(tokenizer.decode(out[0]))
    # Keep only the newly generated tokens and map them to a label index
    s = tokenizer.decode(out[0][input_ids.shape[-1]:])
    for lbl, l in enumerate(label_str):
        if s.find(l) > -1:
            return lbl
    assert False, 'Invalid model output: ' + s

print(eval("A man is playing a guitar.", "A woman is reading a book."))
```
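
The accuracies reported above can be reproduced with a loop over the ANLI test splits. This is a minimal sketch, not the original evaluation script; it reuses the `eval()` helper and label convention defined in the snippet above:

```python
from datasets import load_dataset

anli = load_dataset("facebook/anli")

for r in (1, 2, 3):
    test = anli[f"test_r{r}"]
    correct = 0
    for ex in test:
        # eval() returns a label index; compare against the gold label id
        pred = eval(ex["premise"], ex["hypothesis"])
        correct += int(pred == ex["label"])
    print(f"ANLI-R{r} accuracy: {100.0 * correct / len(test):.1f}%")
```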
## Training Details
- **Dataset:** facebook/anli
- **Hardware:** 1x NVIDIA H20 (96 GB)
#### Fine Tuning Hyperparameters
- **Training regime:** fp16 mixed precision
- **LoRA rank:** 64
- **LoRA alpha:** 16
- **LoRA dropout:** 0.1
- **Learning rate:** 0.0001
- **Training Batch Size:** 4
- **Epochs:** 3
- **Context length:** 2048
#### Fine Tuning Code
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer
import torch

dataset = load_dataset("facebook/anli")
model_name = "meta-llama/Llama-3.1-8B-Instruct"

def out_ckp(r):
    return f"/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/checkpoints-r{r}"

def out_lora_model_fn(r):
    return f'/path/to/project/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct/lora-r{r}'

# Label order matches the ANLI convention: 0 = entailment, 1 = neutral, 2 = contradiction
label_str = ['entailment', 'neutral', 'contradiction']

def preprocess_function(examples):
    # Format each example as a Llama-3.1 chat prompt with the gold label as the assistant turn
    inputs = ["<|start_header_id|>system<|end_header_id|>\n\nBased on the following premise, determine if the hypothesis is entailment, contradiction, or neutral." +
              "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
              "<Premise>: " + p + "\n\n<Hypothesis>: " + h + "\n\n" +
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + label_str[lbl] + "<|eot_id|>\n"  # FIXME remove \n
              for p, h, lbl in zip(examples["premise"], examples["hypothesis"], examples['label'])]
    return {'text': inputs}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             pad_token_id=tokenizer.eos_token_id,
                                             device_map='auto')
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Train sequentially on the three ANLI rounds, reusing the adapter from the previous round
for r in range(1, 4):
    print('Round ', r)
    train_data = dataset[f'train_r{r}'].map(preprocess_function, batched=True, num_proc=8)
    val_data = dataset[f'dev_r{r}'].map(preprocess_function, batched=True, num_proc=8)

    training_args = SFTConfig(
        fp16=True,
        output_dir=out_ckp(r),
        learning_rate=1e-4,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        logging_steps=10,
        weight_decay=0,
        logging_dir=f"./logs-r{r}",
        save_strategy="epoch",
        save_total_limit=1,
        max_seq_length=2048,
        packing=False,
        dataset_text_field="text"
    )

    if r == 1:
        # Create the LoRA adapter for the first round
        peft_config = LoraConfig(
            r=64,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type='CAUSAL_LM'
        )
        lora_model = get_peft_model(model, peft_config)
    else:
        # Load the adapter trained in the previous round and continue training it
        lora_model = PeftModel.from_pretrained(model, out_lora_model_fn(r - 1),
                                               is_trainable=True)

    trainer = SFTTrainer(
        model=lora_model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
    )
    trainer.train()

    print(f'saving to "{out_lora_model_fn(r)}"')
    lora_model.save_pretrained(out_lora_model_fn(r))
```
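
For deployment without the `peft` dependency, the trained adapter can be merged into the base weights. This is an optional sketch, not part of the original training pipeline; the output path is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map='auto')

# Load the trained adapter and fold its weights into the base model
lora_model = PeftModel.from_pretrained(model, 'cassuto/Llama-3.1-ANLI-R1-R2-R3-8B-Instruct')
merged = lora_model.merge_and_unload()

# Save a standalone checkpoint that no longer requires peft at inference time
merged.save_pretrained('/path/to/project/merged-model')     # placeholder path
tokenizer.save_pretrained('/path/to/project/merged-model')  # placeholder path
```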
#### Framework Versions
- PEFT 0.13.2