File size: 3,941 Bytes
34b369f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Configuration manager that loads the project configuration and parameter files and builds the per-stage configuration entities.
"""

from pathlib import Path

from box import ConfigBox

from src.TextSummarizer.constants import file_path
from src.TextSummarizer.entity import entities
from src.TextSummarizer.utils.general import create_directories, read_yaml


# Create a config manager.
class ConfigManager:
    """
    Load the project configuration/parameter files and build the
    per-stage configuration entities.

    On construction, the YAML files at ``file_path.CONFIG_FILE_PATH`` and
    ``file_path.PARAMS_FILE_PATH`` are read (via ``read_yaml``) and the
    ``artifacts_root`` directory named in the config is created. Each
    ``get_*_config`` method reads one section of the loaded config (and,
    for model training, the ``TrainingArguments`` section of the params)
    and returns the corresponding entity from ``entities``.
    """

    def __init__(self) -> None:
        """
        Read the config and params YAML files and create the artifacts root.
        """
        self.config: ConfigBox = read_yaml(Path(file_path.CONFIG_FILE_PATH))
        self.params: ConfigBox = read_yaml(Path(file_path.PARAMS_FILE_PATH))

        create_directories(path_to_directories=[self.config.artifacts_root])

    def get_data_ingestion_config(self) -> entities.DataIngestionConfig:
        """
        Build the config needed to download the dataset.

        Reads the ``data_ingestion`` section of the config. Unlike the other
        getters, no directory is created here.

        Returns:
            entities.DataIngestionConfig: dataset name and the directory the
            Arrow dataset is stored in.
        """
        config: ConfigBox = self.config.data_ingestion

        return entities.DataIngestionConfig(
            dataset_name=config.dataset_name,
            arrow_dataset_dir=config.arrow_dataset_dir,
        )

    def get_data_validation_config(self) -> entities.DataValidationConfig:
        """
        Build the config needed to validate the downloaded data files.

        Reads the ``data_validation`` section of the config and creates its
        ``root_dir``.

        Returns:
            entities.DataValidationConfig: root directory, status file path
            and the list of folders that must be present.
        """
        config: ConfigBox = self.config.data_validation

        create_directories([config.root_dir])

        return entities.DataValidationConfig(
            root_dir=config.root_dir,
            status_file=config.status_file,
            all_required_folders=config.all_required_folders,
        )

    def get_data_transformation_config(self) -> entities.DataTransformationConfig:
        """
        Build the data transformation config.

        Reads the ``data_transformation`` section of the config and creates
        its ``root_dir``.

        Returns:
            entities.DataTransformationConfig: root directory, input data
            path and the tokenizer name.
        """
        config: ConfigBox = self.config.data_transformation

        create_directories([config.root_dir])

        return entities.DataTransformationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            tokenizer_name=config.tokenizer_name,
        )

    def get_model_trainer_config(self) -> entities.ModelTrainerConfig:
        """
        Build the config needed to train the model.

        Combines paths from the ``model_trainer`` section of the config with
        the training hyper-parameters from the ``TrainingArguments`` section
        of the params file, and creates the trainer ``root_dir``.

        Returns:
            entities.ModelTrainerConfig: paths, checkpoint name and training
            hyper-parameters.
        """
        config: ConfigBox = self.config.model_trainer
        params: ConfigBox = self.params.TrainingArguments

        create_directories([config.root_dir])

        return entities.ModelTrainerConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            model_path=config.model_path,
            tokenizer_path=config.tokenizer_path,
            model_ckpt=config.model_ckpt,
            num_train_epochs=params.num_train_epochs,
            warmup_steps=params.warmup_steps,
            per_device_train_batch_size=params.per_device_train_batch_size,
            weight_decay=params.weight_decay,
            logging_steps=params.logging_steps,
            evaluation_strategy=params.evaluation_strategy,
            # Read eval_steps from its own key; it was previously filled from
            # ``evaluation_strategy`` by mistake. NOTE: requires an
            # ``eval_steps`` entry under TrainingArguments in the params file.
            eval_steps=params.eval_steps,
            save_steps=params.save_steps,
            gradient_accumulation_steps=params.gradient_accumulation_steps,
        )

    def get_model_evaluation_config(self) -> entities.ModelEvaluationConfig:
        """
        Build the model evaluation config.

        Reads the ``model_evaluation`` section of the config and creates its
        ``root_dir``.

        Returns:
            entities.ModelEvaluationConfig: paths to data, model and
            tokenizer, the metric output file name and the hub model name.
        """
        config: ConfigBox = self.config.model_evaluation

        create_directories([config.root_dir])

        return entities.ModelEvaluationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            model_path=config.model_path,
            tokenizer_path=config.tokenizer_path,
            metric_file_name=config.metric_file_name,
            hub_model_name=config.hub_model_name,
        )