Vahe commited on
Commit
29793a4
·
1 Parent(s): f1defb2
artifacts/model/resnet18_fruits.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9aab7287020d56634eeb010928b0b5e85e85460df6a4737f254332a9733dd19
3
+ size 47009876
config/config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ artifacts_root: artifacts
2
+ data_root: artifacts/data
3
+ model_root: artifacts/model
params.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ N_EPOCHS: 100
2
+ N_CLASSES: 10
3
+ N_FREEZE_EPOCHS: 5
4
+ PATIENCE: 15
5
+ MODEL_NAME: 'resnet18_fruits'
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastai==2.7.12
2
+ numpy==1.23.5
3
+ opencv-python==4.9.0.80
4
+ Pillow==9.5.0
5
+ streamlit==1.24.0
src/fruit_classifier/__init__.py ADDED
File without changes
src/fruit_classifier/components/__init__.py ADDED
File without changes
src/fruit_classifier/components/model_trainer.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.vision.all import *
2
+ from fastai.vision.widgets import *
3
+
4
+ from fruit_classifier.config.configuration import ConfigurationManager
5
+
6
+
7
+ class ModelTrainer:
8
+
9
+ def __init__(self):
10
+
11
+ config_manager = ConfigurationManager()
12
+ self.config = config_manager.get_training_config()
13
+
14
+ def random_seed(seed_value, use_cuda):
15
+ np.random.seed(seed_value) # cpu vars
16
+ torch.manual_seed(seed_value) # cpu vars
17
+ random.seed(seed_value) # Python
18
+ if use_cuda:
19
+ torch.cuda.manual_seed(seed_value)
20
+ torch.cuda.manual_seed_all(seed_value) # gpu vars
21
+ torch.backends.cudnn.deterministic = True #needed
22
+ torch.backends.cudnn.benchmark = False
23
+
24
+ def create_dls(self):
25
+
26
+ path = self.config.training_data
27
+
28
+ fruit_db = DataBlock(
29
+ blocks=(ImageBlock, CategoryBlock),
30
+ get_items=get_image_files,
31
+ splitter=RandomSplitter(valid_pct=0.2),#, seed=42),
32
+ get_y=parent_label,
33
+ item_tfms=Resize(464),
34
+ batch_tfms=aug_transforms(size=224, min_scale=0.5)
35
+ )
36
+
37
+ dls = fruit_db.dataloaders(path, num_workers=0)
38
+
39
+ return dls
40
+
41
+ def train_model(self):
42
+
43
+ self.random_seed(seed_value=13, use_cuda=True)
44
+
45
+ dls = self.create_dls()
46
+
47
+ learn = vision_learner(
48
+ dls,
49
+ resnet18,
50
+ loss_func=LabelSmoothingCrossEntropy(),
51
+ metrics=accuracy,
52
+ cbs=[
53
+ MixUp,
54
+ SaveModelCallback(
55
+ monitor='accuracy',
56
+ fname=self.config.params_model_name
57
+ ),
58
+ EarlyStoppingCallback(
59
+ monitor='accuracy',
60
+ patience=self.config.params_patience
61
+ )
62
+ ]
63
+ )
64
+
65
+ learn.fine_tune(self.config.params_epochs, freeze_epochs=self.config.params_n_freeze_epochs)
66
+
67
+ model = vision_learner(dls, resnet18)
68
+ model.load(self.config.params_model_name)
69
+ model.export(f'{self.config.trained_model_path}/{self.config.params_model_name}.pkl')
src/fruit_classifier/config/__init__.py ADDED
File without changes
src/fruit_classifier/config/configuration.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fruit_classifier.constants import *
2
+ from fruit_classifier.utils.utils import read_yaml, create_directories
3
+ from fruit_classifier.entity.entity import DataConfig, TrainingConfig
4
+
5
+
6
+ class ConfigurationManager:
7
+
8
+ def __init__(
9
+ self,
10
+ config_filepath = CONFIG_FILE_PATH,
11
+ params_filepath = PARAMS_FILE_PATH):
12
+
13
+ self.config = read_yaml(config_filepath)
14
+ self.params = read_yaml(params_filepath)
15
+
16
+ create_directories([self.config['artifacts_root']])
17
+
18
+
19
+
20
+ def get_data_config(self) -> DataConfig:
21
+ config = self.config['data_root']
22
+
23
+ create_directories([config])
24
+
25
+ data_config = DataConfig(
26
+ root_dir=config['data_root']
27
+ )
28
+
29
+ return data_config
30
+
31
+
32
+ def get_training_config(self) -> TrainingConfig:
33
+ training = self.config['model_root']
34
+ params = self.params
35
+ training_data = self.config['data_root']
36
+ create_directories([
37
+ Path(training)
38
+ ])
39
+
40
+ training_config = TrainingConfig(
41
+ trained_model_path=Path(training),
42
+ training_data=Path(training_data),
43
+ params_epochs=params['N_EPOCHS'],
44
+ params_n_classes=params['N_CLASSES'],
45
+ params_n_freeze_epochs=params['N_FREEZE_EPOCHS'],
46
+ params_patience=params['PATIENCE'],
47
+ params_model_name=params['MODEL_NAME']
48
+ )
49
+
50
+ return training_config
src/fruit_classifier/constants/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ CONFIG_FILE_PATH = Path("config/config.yaml")
4
+ PARAMS_FILE_PATH = Path("params.yaml")
src/fruit_classifier/entity/__init__.py ADDED
File without changes
src/fruit_classifier/entity/entity.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class DataConfig:
7
+ root_dir: Path
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class TrainingConfig:
12
+ trained_model_path: Path
13
+ training_data: Path
14
+ params_epochs: int
15
+ params_n_classes: int
16
+ params_n_freeze_epochs: int
17
+ params_patience: int
18
+ params_model_name: str
src/fruit_classifier/exception/__init__.py ADDED
File without changes
src/fruit_classifier/exception/exception.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+
4
+ class CustomException(Exception):
5
+
6
+ def __init__(self, error_message, error_details:sys):
7
+ self.error_message = error_message
8
+ _, _, exc_tb = error_details.exc_info()
9
+ self.lineno = exc_tb.tb_lineno
10
+ self.file_name = exc_tb.tb_frame.f_code.co_filename
11
+
12
+ def __str__(self):
13
+ return "Error occured in python script name [{0}] line number [{1}] error message [{2}]".format(
14
+ self.file_name, self.lineno, str(self.error_message))
15
+
16
+ if __name__ == '__main__':
17
+ try:
18
+ 1 / 0
19
+ except Exception as e:
20
+ raise CustomException(e, sys)
src/fruit_classifier/logger/__init__.py ADDED
File without changes
src/fruit_classifier/logger/logger.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from datetime import datetime as dt
4
+
5
+ LOG_FILE = f"{dt.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
6
+
7
+ log_path = os.path.join(os.getcwd(), "logs")
8
+
9
+ os.makedirs(log_path, exist_ok=True)
10
+
11
+ LOG_FILEPATH = os.path.join(log_path, LOG_FILE)
12
+
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ filename=LOG_FILEPATH,
16
+ format="[%(asctime)s] %(lineno)d %(name)s - %(levelname)s - %(message)s"
17
+ )
18
+
19
+ if __name__ == '__main__':
20
+ logging.info("Log testing executed!!!")
src/fruit_classifier/pipeline/__init__.py ADDED
File without changes
src/fruit_classifier/pipeline/training_pipeline.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from fruit_classifier.components.model_trainer import ModelTrainer
2
+
3
+ model_trainer = ModelTrainer()
4
+ model_trainer.train_model()
src/fruit_classifier/utils/__init__.py ADDED
File without changes
src/fruit_classifier/utils/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import sys
4
+
5
+ from fruit_classifier.logger.logger import logging
6
+ from fruit_classifier.exception.exception import CustomException
7
+
8
+
9
+ def read_yaml(path_to_yaml):
10
+
11
+ try:
12
+ with open(path_to_yaml) as yaml_file:
13
+ content = yaml.safe_load(yaml_file)
14
+ logging.info(f"yaml file: {path_to_yaml} loaded successfully")
15
+ return content
16
+ except Exception as e:
17
+ logging.info('Error occured while reading the yaml file.')
18
+ raise CustomException(e, sys)
19
+
20
+
21
+ def create_directories(path_to_directories):
22
+
23
+ for path in path_to_directories:
24
+ os.makedirs(path, exist_ok=True)
25
+ logging.info(f"created directory at: {path}")
26
+
27
+
28
+