File size: 3,065 Bytes
469c254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from typing import Dict, List, Union
import pandas as pd
import numpy as np

class FakeNewsDataset(Dataset):
    """Map-style dataset yielding tokenized text plus an integer label.

    Texts are stored raw and tokenized lazily, one item at a time, in
    ``__getitem__``.
    """

    def __init__(self,
                 texts: List[str],
                 labels: List[int],
                 tokenizer: "BertTokenizer",
                 max_length: int = 512):
        """Store the raw data and tokenization settings.

        Args:
            texts: Raw input texts; each item is passed through ``str()``
                before tokenization.
            labels: Integer class labels, aligned index-by-index with
                ``texts``.
            tokenizer: Tokenizer invoked in ``__getitem__`` with Hugging
                Face-style keyword arguments (padding, truncation, tensors).
            max_length: Length every example is padded or truncated to.

        Raises:
            ValueError: If ``texts`` and ``labels`` differ in length.
        """
        # Fail fast: a mismatch would otherwise surface much later, either
        # as an IndexError on some indices or as silently ignored labels
        # (``__len__`` only looks at ``texts``).
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize example ``idx`` and return model-ready tensors.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` (flattened to
            1-D) and ``'labels'`` (a scalar ``torch.long`` tensor).
        """
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # flatten() removes the leading batch dimension that return_tensors='pt'
        # adds for a single text, so the DataLoader can stack items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def create_data_loaders(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    tokenizer: BertTokenizer,
    batch_size: int = 32,
    max_length: int = 512,
    train_size: float = 0.8,
    val_size: float = 0.1,
    random_state: int = 42,
    num_workers: int = 4,
) -> Dict[str, torch.utils.data.DataLoader]:
    """Create train, validation, and test data loaders.

    Rows are split by random sampling: ``train_size`` of all rows form the
    training split, ``val_size`` of all rows form the validation split, and
    whatever remains forms the test split.

    Args:
        df: Source data; must contain ``text_column`` and ``label_column``.
        text_column: Column holding the raw text.
        label_column: Column holding the class labels.
        tokenizer: Tokenizer handed to each ``FakeNewsDataset``.
        batch_size: Batch size for every loader.
        max_length: Token length each example is padded/truncated to.
        train_size: Fraction of all rows used for training, in (0, 1].
        val_size: Fraction of all rows used for validation, >= 0;
            ``train_size + val_size`` must not exceed 1.
        random_state: Seed used for both sampling steps.
        num_workers: Worker processes per DataLoader (default 4).

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'``; only the
        training loader shuffles.

    Raises:
        ValueError: If the split fractions are out of range.
    """
    if not 0.0 < train_size <= 1.0:
        raise ValueError(f"train_size must be in (0, 1], got {train_size}")
    if val_size < 0.0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    # Small tolerance so sums like 0.7 + 0.3 are not rejected due to rounding.
    if train_size + val_size > 1.0 + 1e-9:
        raise ValueError(
            f"train_size + val_size must not exceed 1, "
            f"got {train_size} + {val_size}"
        )

    # drop() removes rows by index label, so duplicate labels in the caller's
    # index would also discard rows that were never sampled. A fresh
    # positional index avoids that; sample() picks by position, so splits are
    # unchanged for frames that already had a unique index.
    df = df.reset_index(drop=True)

    # Split data
    train_df = df.sample(frac=train_size, random_state=random_state)
    remaining_df = df.drop(train_df.index)

    # val_size is a fraction of the whole frame; rescale it to the remainder.
    # Clamp at 1.0: float rounding (e.g. 0.2 / (1 - 0.8) == 1.0000000000000002)
    # would otherwise make DataFrame.sample refuse to sample without replacement.
    remaining_frac = 1.0 - train_size
    val_frac = min(1.0, val_size / remaining_frac) if remaining_frac > 0 else 0.0
    val_df = remaining_df.sample(frac=val_frac, random_state=random_state)
    test_df = remaining_df.drop(val_df.index)

    def _make_loader(split_df: pd.DataFrame,
                     shuffle: bool) -> torch.utils.data.DataLoader:
        # Wrap one split in a FakeNewsDataset and a DataLoader with shared settings.
        dataset = FakeNewsDataset(
            texts=split_df[text_column].tolist(),
            labels=split_df[label_column].tolist(),
            tokenizer=tokenizer,
            max_length=max_length
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        )

    return {
        'train': _make_loader(train_df, shuffle=True),
        'val': _make_loader(val_df, shuffle=False),
        'test': _make_loader(test_df, shuffle=False)
    }