import json
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    HfArgumentParser,
    EarlyStoppingCallback,
)
from utils import compute_metrics, label_tokenized, chunk_dataset, LABEL2ID, ID2LABEL

logger = logging.getLogger(__name__)
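

# Command-line arguments selecting the model checkpoint and early-stopping behaviour.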
@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    early_stopping_patience: int = field(
        default=1,
    )
    early_stopping_threshold: float = field(
        default=1e-3,
    )
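

# Command-line arguments selecting the datasets and the maximum sequence length.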
@dataclass
class DataTrainingArguments:
    train_dataset: str = field(
        default="bigcode/pseudo-labeled-python-data-pii-detection-filtered",
        metadata={"help": "The train dataset"},
    )
    dev_dataset: str = field(
        default="bigcode/pii-for-code-v2",
        metadata={"help": "The validation dataset"},
    )
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": (
                "The maximum input sequence length after tokenization. Sequences longer "
                "than this will be chunked into pieces of this length."
            )
        },
    )

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
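
    # Load the pretrained checkpoint with a token-classification head, plus its tokenizer.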
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        num_labels=len(ID2LABEL),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        add_prefix_space=True,
    )
    model.config.id2label = {str(i): label for i, label in enumerate(ID2LABEL)}
    model.config.label2id = LABEL2ID
    tokenizer.model_max_length = data_args.max_seq_length
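
    # Load the pseudo-labeled training set and the validation set from the Hugging Face Hub.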
    train_dataset = load_dataset(data_args.train_dataset, use_auth_token=True)['train']
    dev_dataset = load_dataset(data_args.dev_dataset, use_auth_token=True)['train']
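
    # Tokenize the raw file content with offset mappings (no special tokens) and
    # project the PII annotations onto token-level labels via label_tokenized.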
    def tokenize_and_label(entry, tokenizer=tokenizer):
        inputs = tokenizer.encode_plus(entry['content'], return_offsets_mapping=True, add_special_tokens=False)
        entry.update(inputs)
        return label_tokenized(entry)
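
    # The 'pii' column is stored as a JSON string; decode it before labeling.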
    dev_dataset = dev_dataset.map(lambda x: dict(pii=json.loads(x['pii'])))
    dev_dataset = dev_dataset.map(tokenize_and_label)
    train_dataset = train_dataset.map(lambda x: dict(pii=json.loads(x['pii'])))
    train_dataset = train_dataset.map(tokenize_and_label, num_proc=8)
    train_dataset = chunk_dataset(train_dataset, tokenizer)
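
    # Bundle the chunked train/validation splits into a DatasetDict for the Trainer.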
    ner_dataset = DatasetDict(
        train=train_dataset,
        validation=chunk_dataset(dev_dataset, tokenizer),
    )
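
    # Note: EarlyStoppingCallback requires load_best_model_at_end=True and
    # metric_for_best_model to be set in the TrainingArguments.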
    trainer = Trainer(
        model,
        training_args,
        train_dataset=ner_dataset["train"],
        eval_dataset=ner_dataset["validation"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=model_args.early_stopping_patience,
                early_stopping_threshold=model_args.early_stopping_threshold,
            )
        ],
    )
    trainer.train()


if __name__ == "__main__":
    main()