from transformers import AutoTokenizer, BertForTokenClassification, TrainingArguments, Trainer
import torch
from tabulate import tabulate
import wandb
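# Load the tokenizer from the jjzha base checkpoint and the token-classification weights from the Robzy checkpoint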
tokenizer = AutoTokenizer.from_pretrained("jjzha/jobbert_knowledge_extraction")
model = BertForTokenClassification.from_pretrained("Robzy/jobbert_knowledge_extraction")
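# W&B artifact that will hold the trained weights (filled and logged at the end of the script)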
artifact = wandb.Artifact(name="jobbert-knowledge-extraction", type="BERT")
text = 'Experience with Unreal and/or Unity and/or native IOS/Android 3D development and/or Web based 3D engines '
# Tokenize
inputs = tokenizer(
    text, add_special_tokens=False, return_tensors="pt"
)
# Inference
# with torch.no_grad():
#     output = model(**inputs)
# # Post-process
# predicted_token_class_ids = output.logits.argmax(-1)
# predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
# tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())
# # Display
# table = zip(tokens, predicted_tokens_classes)
# print(tabulate(table, headers=["Token", "Predicted Class"], tablefmt="pretty"))
# Training
from datasets import load_dataset
dataset = load_dataset("json", data_files="data/test-short.json")
# Convert tokens to ids before training (the collate_fn below also performs this conversion per batch)
data = [torch.tensor([tokenizer.convert_tokens_to_ids(t) for t in l]) for l in dataset['train']['tokens']]
dataset = dataset.map(
    lambda x: {"input_ids": torch.tensor(tokenizer.convert_tokens_to_ids(x["tokens"]))}
)
# Data preprocessing
from torch.utils.data import DataLoader
import torch.nn as nn
from transformers import DataCollatorForTokenClassification
from typing import List, Tuple
def pad(list_of_lists, pad_value=0):
    """Pad each list to the batch max length and build the matching attention masks."""
    max_len = max(len(lst) for lst in list_of_lists)
    # Pad shorter lists with the specified value
    padded_lists = [lst + [pad_value] * (max_len - len(lst)) for lst in list_of_lists]
    attention_masks = [[1] * len(lst) + [0] * (max_len - len(lst)) for lst in list_of_lists]
    return torch.tensor(padded_lists), torch.tensor(attention_masks)

def collate_fn(batch: List[dict]):
    # label2id is defined further down; that is fine because collate_fn only runs once the dataloaders are iterated
    input_ids, attention_mask = pad([tokenizer.convert_tokens_to_ids(example['tokens']) for example in batch])
    tags_knowledge, _ = pad([[label2id[tag] for tag in example['tags_knowledge']] for example in batch])
    return {"input_ids": input_ids, "tags_knowledge": tags_knowledge, "attention_mask": attention_mask}
# Training settings
batch_size = 32
train_dataloader = DataLoader(dataset['train'], shuffle=True, batch_size=batch_size, collate_fn=collate_fn)
eval_dataloader = DataLoader(dataset['train'], batch_size=batch_size, collate_fn=collate_fn)
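# Note: the evaluation dataloader reuses the training split and is not consumed in the training loop below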
from tqdm.auto import tqdm
from torch.optim import AdamW
from transformers import get_scheduler
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)  # batches are moved to this device below, so the model must be moved as well
model.train()
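# Padded positions are labelled with IGNORE_INDEX (-100) so CrossEntropyLoss skips them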
IGNORE_INDEX = -100
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
id2label = model.config.id2label
label2id = model.config.label2id
lr = 5e-5
optimizer = AdamW(model.parameters(), lr=lr)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
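# pad() fills input ids with 0, which matches BERT's [PAD] id; record it on the model config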
model.config.pad_token_id = 0
## Training
from dotenv import load_dotenv
import os
load_dotenv(".env")
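# WANDB_API_KEY is read from the .env file loaded above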
from datetime import datetime
current_time = datetime.now()
wandb.login(key=os.getenv('WANDB_API_KEY'))
run = wandb.init(
    # set the wandb project where this run will be logged
    project="in-demand",
    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "BERT",
        "epochs": num_epochs,
        "batch_size": batch_size,
        "notes": "Datetime: " + current_time.strftime("%m/%d/%Y, %H:%M:%S"),
    },
)
import logging

logging.basicConfig(level=logging.INFO)  # default level is WARNING, so configure INFO to make the messages below visible
logging.info("Initiating training")

progress_bar = tqdm(range(num_training_steps), desc="Training")  # advanced once per batch
for epoch in range(num_epochs):
    logging.info(f"Epoch #{epoch}")
    print(f"Epoch #{epoch}")

    batch_count = 0
    for batch in train_dataloader:
        logging.info(f"Batch #{batch_count} / {len(train_dataloader)}")
        print(f"Batch #{batch_count} / {len(train_dataloader)}")

        tokens = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tags_knowledge = batch['tags_knowledge'].to(device)

        # Forward pass
        outputs = model(tokens, attention_mask=attention_mask)
        pred = outputs.logits.reshape(-1, model.config.num_labels)  # logits flattened to (batch * seq_len, num_labels)
        label = torch.where(attention_mask==0, torch.tensor(IGNORE_INDEX).to(device), tags_knowledge).reshape(-1)  # labels, padding set to class idx -100

        # Compute accuracy ignoring padding idx
        _, predicted_labels = torch.max(pred, dim=1)
        non_pad_elements = label != IGNORE_INDEX
        correct_predictions = (predicted_labels[non_pad_elements] == label[non_pad_elements]).sum().item()
        total_predictions = non_pad_elements.sum().item()
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

        # Backward pass and optimizer step
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        wandb.log({"epoch": epoch, "accuracy": accuracy, "loss": loss.item()})
        batch_count += 1
        progress_bar.update(1)
# Push the fine-tuned weights to the Hugging Face Hub
model.push_to_hub("Robzy/jobbert_knowledge_extraction")

# Add the state_dict to the artifact
state_dict = model.state_dict()
with artifact.new_file('model.pth', mode='wb') as f:
    torch.save(state_dict, f)

# Log the artifact to W&B
wandb.log_artifact(artifact)