File size: 2,140 Bytes
29793a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from fastai.vision.all import *
from fastai.vision.widgets import *

from fruit_classifier.config.configuration import ConfigurationManager


class ModelTrainer:
    """Train a ResNet-18 image classifier with fastai and export it for inference.

    Data location, hyper-parameters and output paths are read from the training
    config returned by ``ConfigurationManager().get_training_config()``.
    """

    def __init__(self):
        config_manager = ConfigurationManager()
        self.config = config_manager.get_training_config()

    @staticmethod
    def random_seed(seed_value, use_cuda):
        """Seed the Python, NumPy and PyTorch RNGs so a run is repeatable.

        Declared static because it reads no instance state; it can be called
        either as ``ModelTrainer.random_seed(...)`` or ``self.random_seed(...)``.

        Args:
            seed_value: Integer seed applied to every RNG.
            use_cuda: When true, also seed every CUDA device and put cuDNN into
                deterministic, non-benchmark mode.
        """
        random.seed(seed_value)        # Python stdlib RNG
        np.random.seed(seed_value)     # NumPy global RNG
        torch.manual_seed(seed_value)  # PyTorch CPU RNG
        if use_cuda:
            torch.cuda.manual_seed_all(seed_value)  # all visible GPUs
            # cuDNN autotuning may select nondeterministic kernels; turn it off.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_dls(self):
        """Build train/validation DataLoaders from ``config.training_data``.

        Images are labelled by their parent folder name and 20% are held out
        for validation. Each item is resized to 464 px before batch-level
        augmentation at 224 px.

        Returns:
            fastai ``DataLoaders`` over the image folder.
        """
        fruit_db = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            # No explicit seed: the split draws from the global torch RNG,
            # which train_model seeds via random_seed before calling this.
            splitter=RandomSplitter(valid_pct=0.2),
            get_y=parent_label,
            item_tfms=Resize(464),
            batch_tfms=aug_transforms(size=224, min_scale=0.5),
        )
        # num_workers=0: batches are loaded in the main process.
        return fruit_db.dataloaders(self.config.training_data, num_workers=0)

    def train_model(self):
        """Fine-tune a pretrained ResNet-18 and export the best checkpoint.

        The epoch with the best validation accuracy is saved as
        ``config.params_model_name``. Training stops early after
        ``config.params_patience`` epochs without improvement. That checkpoint
        is then loaded into a learner built without the custom loss/callbacks
        and exported to ``<trained_model_path>/<params_model_name>.pkl``.
        """
        from pathlib import Path

        self.random_seed(seed_value=13, use_cuda=torch.cuda.is_available())

        dls = self.create_dls()

        learn = vision_learner(
            dls,
            resnet18,
            loss_func=LabelSmoothingCrossEntropy(),
            metrics=accuracy,
            cbs=[
                MixUp,  # passed as a class; fastai instantiates it
                SaveModelCallback(
                    monitor='accuracy',
                    fname=self.config.params_model_name,
                ),
                EarlyStoppingCallback(
                    monitor='accuracy',
                    patience=self.config.params_patience,
                ),
            ],
        )

        learn.fine_tune(
            self.config.params_epochs,
            freeze_epochs=self.config.params_n_freeze_epochs,
        )

        # Presumably rebuilt so the exported pickle omits the training-only
        # loss function and callbacks; the saved best weights are loaded in.
        model = vision_learner(dls, resnet18)
        model.load(self.config.params_model_name)

        export_path = (
            Path(self.config.trained_model_path)
            / f'{self.config.params_model_name}.pkl'
        )
        # torch.save (used by export) fails if the directory does not exist.
        export_path.parent.mkdir(parents=True, exist_ok=True)
        model.export(export_path)