import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from typing import Dict, List
import pandas as pd
class FakeNewsDataset(Dataset):
    """Wraps parallel lists of texts and integer labels as a PyTorch Dataset,
    tokenizing each example with a BERT tokenizer on access."""

    def __init__(self,
                 texts: List[str],
                 labels: List[int],
                 tokenizer: BertTokenizer,
                 max_length: int = 512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Tokenize on the fly, padding/truncating to a fixed max_length so
        # every example in a batch has the same shape.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # The tokenizer returns tensors of shape (1, max_length); flatten to
        # (max_length,) so the DataLoader can stack items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
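

# A minimal usage sketch for the dataset above. The '_demo_dataset' helper,
# the 'bert-base-uncased' checkpoint, and the sample text/label are
# illustrative assumptions, not part of the original code.
def _demo_dataset() -> None:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    dataset = FakeNewsDataset(
        texts=['An example headline to encode.'],
        labels=[0],
        tokenizer=tokenizer,
        max_length=128
    )
    sample = dataset[0]
    # Both tensors are 1-D with shape (128,); the label is a scalar LongTensor.
    print(sample['input_ids'].shape, sample['attention_mask'].shape, sample['labels'])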
def create_data_loaders(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    tokenizer: BertTokenizer,
    batch_size: int = 32,
    max_length: int = 512,
    train_size: float = 0.8,
    val_size: float = 0.1,
    random_state: int = 42
) -> Dict[str, DataLoader]:
    """Split a DataFrame into train/validation/test sets (by default
    80% / 10% / 10%) and return a DataLoader for each."""
    # Split by sampling. drop() needs a unique index, so reset it first.
    # val_size is a fraction of the whole DataFrame, while the second
    # sample() draws from the remainder, hence the rescaling below.
    df = df.reset_index(drop=True)
    train_df = df.sample(frac=train_size, random_state=random_state)
    remaining_df = df.drop(train_df.index)
    val_df = remaining_df.sample(frac=val_size / (1 - train_size),
                                 random_state=random_state)
    test_df = remaining_df.drop(val_df.index)
    # Create one dataset per split.
    train_dataset = FakeNewsDataset(
        texts=train_df[text_column].tolist(),
        labels=train_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )
    val_dataset = FakeNewsDataset(
        texts=val_df[text_column].tolist(),
        labels=val_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )
    test_dataset = FakeNewsDataset(
        texts=test_df[text_column].tolist(),
        labels=test_df[label_column].tolist(),
        tokenizer=tokenizer,
        max_length=max_length
    )
    # Only the training loader shuffles; evaluation order should be stable.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )
    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
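

# A minimal end-to-end sketch. The synthetic DataFrame, the 'text'/'label'
# column names, and 'bert-base-uncased' are illustrative assumptions.
if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    demo_df = pd.DataFrame({
        'text': [f'Sample article {i}' for i in range(100)],
        'label': [i % 2 for i in range(100)]
    })
    loaders = create_data_loaders(
        demo_df,
        text_column='text',
        label_column='label',
        tokenizer=tokenizer,
        batch_size=8,
        max_length=64
    )
    batch = next(iter(loaders['train']))
    # 100 rows split 80/10/10; a training batch stacks to (8, 64).
    print(batch['input_ids'].shape, batch['labels'].shape)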