updated adfe
train_llama.py  +9 -4
@@ -24,7 +24,7 @@ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 # Load model with FlashAttention 2
 model = LlamaForCausalLM.from_pretrained(
     MODEL_ID,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.bfloat16,
     device_map="auto",
     quantization_config=quantization_config,
     attn_implementation="flash_attention_2"
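Only fragments of the model-loading call appear in this hunk, so here is a minimal sketch of how the surrounding setup plausibly fits together, assuming standard transformers imports. MODEL_ID and the tokenizer class are assumptions; neither is shown in the diff.

```python
# Sketch only: assembles the lines visible in the hunk with the imports they need.
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, LlamaForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-hf"  # placeholder; the real value is set earlier in the script, not shown here

# 8-bit quantization config (taken from the hunk header)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Tokenizer is referenced later in the diff; AutoTokenizer is an assumption
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load model with FlashAttention 2, as in the diff
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)
```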
@@ -43,7 +43,7 @@ model.print_trainable_parameters()
 dataset = datasets.load_dataset("json", data_files="final_combined_fraud_data.json", field="training_pairs")
 print("First example from dataset:", dataset["train"][0])
 
-# Tokenization
+# Tokenization with tensors
 def tokenize_data(example):
     formatted_text = f"{example['input']} {example['output']}"
     inputs = tokenizer(formatted_text, truncation=True, max_length=2048, return_tensors="pt")
@@ -51,10 +51,15 @@ def tokenize_data(example):
     labels = inputs["input_ids"].clone().squeeze(0)
     input_len = len(tokenizer(example['input'])["input_ids"])
     labels[:input_len] = -100
-
+    attention_mask = inputs["attention_mask"].squeeze(0)
+    return {
+        "input_ids": input_ids,
+        "labels": labels,
+        "attention_mask": attention_mask
+    }
 
 tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
-print("First tokenized example:", {k: (type(v), v.shape
+print("First tokenized example:", {k: (type(v), v.shape) for k, v in tokenized_dataset[0].items()})
 
 # Data collator
 def custom_data_collator(features):
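Because the change is split across two hunks, the updated tokenize_data is easier to read assembled. The line that defines input_ids sits just above the second hunk and is not shown, so its definition below is an assumption inferred from the returned dict.

```python
# tokenize_data as it reads after this change (assembled from the diff).
def tokenize_data(example):
    formatted_text = f"{example['input']} {example['output']}"
    inputs = tokenizer(formatted_text, truncation=True, max_length=2048, return_tensors="pt")
    input_ids = inputs["input_ids"].squeeze(0)  # assumed: this line falls just outside the hunk
    labels = inputs["input_ids"].clone().squeeze(0)
    input_len = len(tokenizer(example['input'])["input_ids"])
    labels[:input_len] = -100  # mask the prompt tokens so loss is computed only on the output
    attention_mask = inputs["attention_mask"].squeeze(0)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
    }
```

Note that datasets' map() stores returned values as plain Python lists, so the shape print after mapping presumably relies on the dataset format being set to torch somewhere outside the diff (for example via tokenized_dataset.set_format("torch")).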
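The diff stops at the custom_data_collator signature, so its body is not shown. Since tokenize_data now returns per-example input_ids, labels, and attention_mask of varying lengths, a collator for this setup typically pads them into batch tensors. The sketch below is one possible shape, not the author's implementation, and it assumes tokenizer.pad_token_id is set (Llama tokenizers often need tokenizer.pad_token = tokenizer.eos_token).

```python
# Hypothetical collator sketch: pad variable-length examples into batch tensors,
# using the tokenizer's pad id for input_ids, -100 for labels (ignored by the loss),
# and 0 for the attention mask.
import torch
from torch.nn.utils.rnn import pad_sequence

def custom_data_collator(features):
    input_ids = pad_sequence(
        [torch.as_tensor(f["input_ids"]) for f in features],
        batch_first=True, padding_value=tokenizer.pad_token_id,
    )
    labels = pad_sequence(
        [torch.as_tensor(f["labels"]) for f in features],
        batch_first=True, padding_value=-100,
    )
    attention_mask = pad_sequence(
        [torch.as_tensor(f["attention_mask"]) for f in features],
        batch_first=True, padding_value=0,
    )
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
```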