File size: 1,271 Bytes
cb7cafb
a637525
cb7cafb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from src.textsummarizer.logging import logger
from transformers import AutoTokenizer
from datasets import load_dataset, load_from_disk
from textsummarizer.entity.config_entity import DataTransformationConfig

class DataTransformation:
    """Tokenize a dialogue/summary dataset and write the encoded copy to disk.

    The tokenizer named by ``config.tokenizer_name`` is loaded once in the
    constructor and reused for every batch.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, config: "DataTransformationConfig"):
        """Keep *config* and load the tokenizer it names.

        Args:
            config: Settings object. This class reads its ``tokenizer_name``,
                ``data_path`` and ``root_dir`` attributes.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)

    def convert_examples_to_features(
        self,
        example_batch,
        *,
        max_input_length=1024,
        max_target_length=128,
    ):
        """Encode one batch of dialogues and their reference summaries.

        Args:
            example_batch: Mapping with a ``'dialogue'`` entry (model inputs)
                and a ``'summary'`` entry (targets), as supplied by a batched
                ``map`` call.
            max_input_length: Truncation limit passed as ``max_length`` when
                tokenizing the dialogues.
            max_target_length: Truncation limit passed as ``max_length`` when
                tokenizing the summaries.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` taken from the
            dialogue encoding, and ``'labels'`` holding the ``'input_ids'`` of
            the summary encoding.
        """
        input_encoding = self.tokenizer(
            example_batch['dialogue'],
            max_length=max_input_length,
            truncation=True,
        )

        # ``text_target`` is the supported way to tokenize labels; it replaces
        # the deprecated ``as_target_tokenizer()`` context manager.
        target_encodings = self.tokenizer(
            text_target=example_batch['summary'],
            max_length=max_target_length,
            truncation=True,
        )

        return {
            'input_ids': input_encoding['input_ids'],
            'attention_mask': input_encoding['attention_mask'],
            'labels': target_encodings['input_ids'],
        }

    def convert(self):
        """Load the dataset, tokenize it in batches, and save the result.

        Reads from ``config.data_path`` and writes the mapped dataset to the
        ``samsum_dataset`` directory under ``config.root_dir``.
        """
        dataset_samsum = load_from_disk(self.config.data_path)
        dataset_samsum_pt = dataset_samsum.map(
            self.convert_examples_to_features,
            batched=True,
        )
        dataset_samsum_pt.save_to_disk(
            os.path.join(self.config.root_dir, "samsum_dataset")
        )