import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from typing import Dict, List
import pandas as pd
class FakeNewsDataset(Dataset):
    """PyTorch dataset that tokenizes news texts on the fly for BERT."""

    def __init__(self,
                 texts: List[str],
                 labels: List[int],
                 tokenizer: BertTokenizer,
                 max_length: int = 512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Tokenize a single example, padding/truncating to a fixed length
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # The tokenizer returns tensors of shape (1, max_length); flatten
        # them so the DataLoader can stack items into (batch, max_length).
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
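
# Sanity-check sketch (illustrative, not part of the original file): each
# item yields fixed-length tensors ready for batching. The sample text and
# label below are hypothetical placeholders.
# >>> tok = BertTokenizer.from_pretrained('bert-base-uncased')
# >>> ds = FakeNewsDataset(['some headline'], [0], tok, max_length=32)
# >>> ds[0]['input_ids'].shape
# torch.Size([32])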
def create_data_loaders(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    tokenizer: BertTokenizer,
    batch_size: int = 32,
    max_length: int = 512,
    train_size: float = 0.8,
    val_size: float = 0.1,
    random_state: int = 42
) -> Dict[str, torch.utils.data.DataLoader]:
    """Create train, validation, and test data loaders."""
    # Split data: sample the train fraction, then carve the validation set
    # out of the remainder. val_size is a fraction of the full frame, so it
    # is rescaled relative to what is left after the train split.
    train_df = df.sample(frac=train_size, random_state=random_state)
    remaining_df = df.drop(train_df.index)
    val_df = remaining_df.sample(frac=val_size / (1 - train_size),
                                 random_state=random_state)
    test_df = remaining_df.drop(val_df.index)

    # Create datasets
    train_dataset = FakeNewsDataset(
        texts=train_df[text_column].tolist(),
        labels=train_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )
    val_dataset = FakeNewsDataset(
        texts=val_df[text_column].tolist(),
        labels=val_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )
    test_dataset = FakeNewsDataset(
        texts=test_df[text_column].tolist(),
        labels=test_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )

    # Create data loaders; only the training set is shuffled
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
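

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wiring the pieces together. The CSV file name and
# the 'text'/'label' column names are hypothetical assumptions; substitute
# whatever your data actually uses.
if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    df = pd.read_csv('news.csv')  # hypothetical input file
    loaders = create_data_loaders(
        df,
        text_column='text',
        label_column='label',
        tokenizer=tokenizer,
        batch_size=16
    )
    batch = next(iter(loaders['train']))
    print(batch['input_ids'].shape)  # (batch_size, max_length)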