File size: 5,943 Bytes
dc07399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np

import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
import torch
import os
import ast
from sklearn.utils import shuffle
import random
from spacy.lang.en import English
from .utils import sentencepiece

def make_dataset(csv_file, tokenizer, max_length=512, padding=None, random_state=1000, data_cut=None, sentence_piece=True):
    """Load a paragraph CSV and build stratified train/validation datasets.

    The CSV is read with pandas and its columns are renamed, in order, to
    ``paragraph``, ``category``, ``position`` and ``portion`` (so the file
    must have exactly four columns).  ``category`` names are mapped to
    integer class ids, one-hot encoded, and the rows are split 80/20 with
    stratification on the label.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        tokenizer: Object exposing ``batch_encode_plus`` (used directly when
            ``sentence_piece`` is false, otherwise handed to the project
            helper ``sentencepiece``).
        max_length: Maximum token length forwarded to the tokenizer/helper.
        padding: Currently unused; kept for backward compatibility.
        random_state: Seed forwarded to ``train_test_split``.
        data_cut: If not ``None``, only the first ``data_cut`` rows are used.
        sentence_piece: When true, encode with the project helper
            ``sentencepiece`` and a spaCy sentencizer pipeline; otherwise
            call ``tokenizer.batch_encode_plus`` with ``max_length`` padding.

    Returns:
        ``(train_dataset, val_dataset)``.  Indexing either one yields a dict
        with every encoding key plus ``text``, ``label`` (int64 class id),
        ``label_onehot``, ``position`` and ``portion``.

    Raises:
        ValueError: If no rows remain after loading / applying ``data_cut``.
    """
    total_data = pd.read_csv(csv_file)
    total_data.columns = ['paragraph', 'category', 'position', 'portion']
    label_dict = {'Abstract':0, 'Introduction':1, 'Main':2, 'Methods':3, 'Summary':4, 'Captions':5}
    total_data['label'] = total_data.category.replace(label_dict)

    if data_cut is not None:
        total_data = total_data.iloc[:data_cut, :]

    total_text = total_data['paragraph'].to_list()
    total_label = total_data['label'].to_list()
    total_position = total_data['position'].to_list()
    total_portion = total_data['portion'].to_list()

    if not total_label:
        raise ValueError(f"no rows available to build a dataset from {csv_file!r}")

    # Guard against type errors: labels that are still strings (not one of
    # the known category names) are parsed as Python literals.
    if isinstance(total_label[0], str):
        total_label = [ast.literal_eval(raw) for raw in total_label]

    # Integer class ids are expanded to one-hot rows.  This check runs after
    # the literal parsing above on purpose, so parsed integers are covered.
    if isinstance(total_label[0], (int, np.integer)):
        total_label = np.eye(len(label_dict))[total_label].tolist()

    (
        train_text, val_text,
        train_labels, val_labels,
        train_position, val_position,
        train_portion, val_portion,
    ) = train_test_split(
        total_text, total_label, total_position, total_portion,
        test_size=0.2, random_state=random_state, stratify=total_label,
    )

    # Tokenize the texts.
    if not sentence_piece:
        train_encodings = tokenizer.batch_encode_plus(train_text, truncation=True, return_token_type_ids=True, max_length=max_length, add_special_tokens=True, return_attention_mask=True, padding='max_length')
        val_encodings = tokenizer.batch_encode_plus(val_text, truncation=True, return_token_type_ids=True, max_length=max_length, add_special_tokens=True, return_attention_mask=True, padding='max_length')
    else:
        nlp = English()
        nlp.add_pipe('sentencizer')
        train_encodings = sentencepiece(train_text, nlp, tokenizer, max_length=max_length)
        val_encodings = sentencepiece(val_text, nlp, tokenizer, max_length=max_length)

    # Convert the token encodings to tensors.
    train_encodings = {key: torch.tensor(val) for key, val in train_encodings.items()}
    val_encodings = {key: torch.tensor(val) for key, val in val_encodings.items()}

    # Convert labels and positions to tensors.  Both splits use the same
    # dtypes; previously the train class ids were int32 and the validation
    # ids int64.
    train_labels = _label_tensors(train_labels)
    val_labels = _label_tensors(val_labels)
    train_positions = _position_tensors(train_position, train_portion)
    val_positions = _position_tensors(val_position, val_portion)

    class CustomDataset(torch.utils.data.Dataset):
        """Map-style dataset over encodings, labels, raw texts and positions."""

        def __init__(self, encodings, labels, texts, positions):
            self.encodings = encodings
            self.labels = labels
            self.texts = texts
            self.positions = positions

        def __getitem__(self, idx):
            item = {key: val[idx] for key, val in self.encodings.items()}
            item['text'] = self.texts[idx]
            # scalar class id
            item['label'] = self.labels['label'][idx]
            # one-hot version
            item['label_onehot'] = self.labels['label_onehot'][idx]
            item['position'] = self.positions['position'][idx]
            item['portion'] = self.positions['portion'][idx]
            return item

        def __len__(self):
            return len(self.labels['label_onehot'])

    # Wrap each split in the dataset format used for training.
    train_dataset = CustomDataset(train_encodings, train_labels, train_text, train_positions)
    val_dataset = CustomDataset(val_encodings, val_labels, val_text, val_positions)

    return train_dataset, val_dataset


def _label_tensors(onehot_rows):
    """Return one-hot (float) and class-index (int64) tensors for one-hot rows."""
    return {
        'label_onehot': torch.tensor(onehot_rows, dtype=torch.float),
        'label': torch.tensor([row.index(1) for row in onehot_rows], dtype=torch.long),
    }


def _position_tensors(position, portion):
    """Return float tensors for the ``position`` and ``portion`` columns."""
    return {
        'position': torch.tensor(position, dtype=torch.float),
        'portion': torch.tensor(portion, dtype=torch.float),
    }


def make_extract_dataset(paragraphs, positions, tokenizer, max_length):
    """Build a label-free dataset from paragraphs and their positions.

    The paragraphs are encoded in one ``tokenizer.batch_encode_plus`` call
    (truncated and padded to ``max_length``, returned as PyTorch tensors) and
    ``positions`` is converted to a float tensor.

    Args:
        paragraphs: Texts passed to ``tokenizer.batch_encode_plus``.
        positions: Values convertible to a float tensor, one per paragraph.
        tokenizer: Object exposing ``batch_encode_plus``.
        max_length: Truncation / padding length forwarded to the tokenizer.

    Returns:
        A dataset whose items hold every encoding key plus ``position`` and
        whose length is the number of ``input_ids`` rows.
    """

    class CustomDataset(torch.utils.data.Dataset):
        """Map-style dataset pairing token encodings with positions."""

        def __init__(self, encodings, positions):
            self.encodings = encodings
            self.positions = positions

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            sample = {}
            for field, column in self.encodings.items():
                sample[field] = column[idx]
            sample['position'] = self.positions['position'][idx]
            return sample

    # Tokenize first, then convert positions (same order as before).
    tokenized = tokenizer.batch_encode_plus(
        paragraphs,
        truncation=True,
        return_token_type_ids=True,
        max_length=max_length,
        add_special_tokens=True,
        return_attention_mask=True,
        padding='max_length',
        return_tensors='pt',
    )
    position_tensors = {'position': torch.tensor(positions, dtype=torch.float)}

    return CustomDataset(tokenized, position_tensors)