import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class GroundingDataset(Dataset):
    """Pairs each question with its answer and supporting context for
    binary grounded/ungrounded classification."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Encode the question as the first segment and the answer plus its
        # context as the second, joined by an explicit [SEP] token.
        encoding = self.tokenizer(
            item["question"],
            text_pair=item["answer"] + " [SEP] " + item["context"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze() drops the batch dimension added by return_tensors="pt".
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(item["label"]),
        }


class GroundingTrainer:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=2
        )

    def train(self, dataset):
        # Hold out 20% of the examples for per-epoch evaluation.
        train_data, val_data = train_test_split(dataset, test_size=0.2)
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir="./results",
                num_train_epochs=3,
                per_device_train_batch_size=8,
                evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
                logging_dir="./logs",
            ),
            train_dataset=GroundingDataset(train_data, self.tokenizer),
            eval_dataset=GroundingDataset(val_data, self.tokenizer),
        )
        trainer.train()
        self.model.save_pretrained("./grounding_detector")
        self.tokenizer.save_pretrained("./grounding_detector")
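
A minimal usage sketch follows. The example records, the label convention (1 = answer grounded in the context, 0 = not), and the inference call are assumptions for illustration; the listing above only fixes the expected dict keys, not the data itself.

# Usage sketch. The toy examples and the 1 = grounded / 0 = ungrounded
# convention are assumptions, not part of the original listing.
examples = [
    {
        "question": "When was the company founded?",
        "context": "Acme Corp was founded in 1947 in Ohio.",
        "answer": "It was founded in 1947.",
        "label": 1,  # supported by the context
    },
    {
        "question": "When was the company founded?",
        "context": "Acme Corp was founded in 1947 in Ohio.",
        "answer": "It was founded in 1989.",
        "label": 0,  # contradicts the context
    },
]

trainer = GroundingTrainer()
trainer.train(examples)  # real training needs far more than two examples

# Inference: load the saved checkpoint with a standard text-classification
# pipeline, pairing inputs the same way GroundingDataset does in training.
from transformers import pipeline

detector = pipeline(
    "text-classification",
    model="./grounding_detector",
    tokenizer="./grounding_detector",
)
result = detector({
    "text": "When was the company founded?",
    "text_pair": "It was founded in 1947. [SEP] Acme Corp was founded in 1947 in Ohio.",
})
print(result)  # e.g. {'label': 'LABEL_1', 'score': ...}

Note that the inference input must mirror the training-time encoding (question first, answer plus context second, with the same [SEP] join); if that encoding changes, the pipeline call has to change with it.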