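"""Fine-tune JobBERT for knowledge-tag extraction and publish the result.

Run as a script (e.g. `python train.py`; the filename is assumed here). The script expects
./config.yaml, a data/tags-<DD-MM-YYYY>.jsonl file for today's date, and W&B / Hugging Face
credentials available in the environment (a .env file is loaded before the training loop).
"""
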
from transformers import AutoTokenizer, BertForTokenClassification
import torch
import wandb
import os
import yaml
from datetime import datetime

def train(json_path: str):

    ### Model & tokenizer loading

    tokenizer = AutoTokenizer.from_pretrained("jjzha/jobbert_knowledge_extraction")
    model = BertForTokenClassification.from_pretrained("Robzy/jobbert_knowledge_extraction")

    with open("./config.yaml", "r") as file:
        config = yaml.safe_load(file)

    num_epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    lr = config['training']['learning_rate']
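
    # The config file is expected to look roughly like this (values are illustrative,
    # not taken from the repository):
    #
    #   training:
    #     epochs: 3
    #     batch_size: 16
    #     learning_rate: 5.0e-5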

    current_time = datetime.now()

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="in-demand",
        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "architecture": "BERT",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "notes": "Datetime: " + current_time.strftime("%m/%d/%Y, %H:%M:%S"),
        },
    )

    ### Data loading and preprocessing

    from torch.utils.data import DataLoader
    import torch.nn as nn
    from typing import List
    from datasets import load_dataset

    # dataset = load_dataset("json", data_files="data/test-short.json")
    dataset = load_dataset("json", data_files=json_path)
    # Note: collate_fn below re-derives input_ids from the raw tokens, so this mapped
    # column is kept only for convenience.
    dataset = dataset.map(
        lambda x: {"input_ids": torch.tensor(tokenizer.convert_tokens_to_ids(x["tokens"]))}
    )
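
    # Each JSONL record is expected to carry pre-tokenized text plus one knowledge tag per token,
    # along the lines of (illustrative example, not taken from the dataset):
    #   {"tokens": ["Experience", "with", "Python"], "tags_knowledge": ["O", "O", "B"]}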

    def pad(list_of_lists, pad_value=0):
        # Pad shorter lists with the specified value and build the matching attention masks
        max_len = max(len(lst) for lst in list_of_lists)
        padded_lists = [lst + [pad_value] * (max_len - len(lst)) for lst in list_of_lists]
        attention_masks = [[1] * len(lst) + [0] * (max_len - len(lst)) for lst in list_of_lists]
        return torch.tensor(padded_lists), torch.tensor(attention_masks)

    def collate_fn(batch: List[dict]):
        # `label2id` is defined further down in train(); it is resolved lazily when the
        # DataLoader first calls this function, which happens after that definition.
        input_ids, attention_mask = pad([tokenizer.convert_tokens_to_ids(b['tokens']) for b in batch])
        tags_knowledge, _ = pad([[label2id[tag] for tag in b['tags_knowledge']] for b in batch])
        return {"input_ids": input_ids, "tags_knowledge": tags_knowledge, "attention_mask": attention_mask}

    ### Training settings

    train_dataloader = DataLoader(dataset['train'], batch_size=batch_size, collate_fn=collate_fn)

    from tqdm.auto import tqdm
    from torch.optim import AdamW
    from transformers import get_scheduler

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)  # keep the model on the same device as the batches
    model.train()

    IGNORE_INDEX = -100
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    id2label = model.config.id2label
    label2id = model.config.label2id

    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    ### Training

    from dotenv import load_dotenv
    load_dotenv(".env")

    import logging
    logging.basicConfig(level=logging.INFO)  # without this, the info-level logs below are silently dropped
    logging.info("Initiating training")

    progress_bar = tqdm(range(num_epochs), desc="Epochs")

    for epoch in range(num_epochs):
        logging.info(f"Epoch #{epoch}")
        batch_count = 1
        for batch in train_dataloader:
            logging.info(f"Batch #{batch_count} / {len(train_dataloader)}")
            tokens = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tags_knowledge = batch['tags_knowledge'].to(device)
            outputs = model(tokens, attention_mask=attention_mask)

            # Flatten logits and labels; padded positions get class index -100 so the loss ignores them
            pred = outputs.logits.reshape(-1, model.config.num_labels)
            label = torch.where(attention_mask == 0, torch.tensor(IGNORE_INDEX).to(device), tags_knowledge).reshape(-1)

            # Compute accuracy over non-padded positions only
            _, predicted_labels = torch.max(pred, dim=1)
            non_pad_elements = label != IGNORE_INDEX
            correct_predictions = (predicted_labels[non_pad_elements] == label[non_pad_elements]).sum().item()
            total_predictions = non_pad_elements.sum().item()
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            wandb.log({"epoch": epoch, "accuracy": accuracy, "loss": loss.item()})
            batch_count += 1
        progress_bar.update(1)

    print("Training complete")

    ### Pushing model

    # Hugging Face
    model.push_to_hub("Robzy/jobbert_knowledge_extraction")

    # W&B
    artifact = wandb.Artifact(name="jobbert-knowledge-extraction", type="BERT")
    state_dict = model.state_dict()
    with artifact.new_file('model.pth', mode='wb') as f:
        torch.save(state_dict, f)

    # Log the artifact to W&B
    wandb.log_artifact(artifact)
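
    # To reload this checkpoint later, something along these lines should work
    # (sketch only; the artifact version tag is assumed):
    #   artifact = run.use_artifact("jobbert-knowledge-extraction:latest")
    #   path = artifact.download()
    #   model.load_state_dict(torch.load(os.path.join(path, "model.pth")))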


def train_today():
    date = datetime.today().strftime('%d-%m-%Y')
    # date = "04-01-2025"
    json_path = os.path.join(os.getcwd(), f'data/tags-{date}.jsonl')
    print(f"Training on {json_path}")
    train(json_path=json_path)


if __name__ == "__main__":
    train_today()