from fastai.vision.all import *
from fastai.vision.widgets import *

from fruit_classifier.config.configuration import ConfigurationManager


class ModelTrainer:
    """Train a resnet18 image classifier with fastai and export it.

    The training configuration is obtained from
    ``ConfigurationManager().get_training_config()`` and stored on
    ``self.config``. The methods below read these attributes from it:
    ``training_data``, ``params_model_name``, ``params_patience``,
    ``params_epochs``, ``params_n_freeze_epochs`` and
    ``trained_model_path``.
    """

    def __init__(self):
        """Load the training configuration into ``self.config``."""
        self.config = ConfigurationManager().get_training_config()

    # Static: the helper uses no instance state. Declaring it a staticmethod
    # lets it be called both as ``self.random_seed(...)`` (as ``train_model``
    # does) and as ``ModelTrainer.random_seed(...)``.
    @staticmethod
    def random_seed(seed_value, use_cuda):
        """Seed the NumPy, PyTorch and Python random number generators.

        Args:
            seed_value: Seed passed to every generator.
            use_cuda: When true, also seed the CUDA generators and switch
                cuDNN to deterministic mode with benchmarking disabled.
        """
        np.random.seed(seed_value)     # NumPy RNG
        torch.manual_seed(seed_value)  # PyTorch RNG
        random.seed(seed_value)        # Python stdlib RNG
        if use_cuda:
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)  # every GPU
            # Trade cuDNN auto-tuning for repeatable results.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_dls(self):
        """Build the fastai ``DataLoaders`` from ``self.config.training_data``.

        Image files are collected with ``get_image_files``, labelled by
        their parent folder name (``parent_label``), and 20% are held out
        for validation by ``RandomSplitter``. Items are resized to 464
        before the batch-level ``aug_transforms(size=224, min_scale=0.5)``
        are applied.

        Returns:
            The ``DataLoaders`` produced by ``DataBlock.dataloaders``.
        """
        path = self.config.training_data

        fruit_db = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            # No explicit seed is passed to the splitter here.
            splitter=RandomSplitter(valid_pct=0.2),
            get_y=parent_label,
            item_tfms=Resize(464),
            batch_tfms=aug_transforms(size=224, min_scale=0.5),
        )

        return fruit_db.dataloaders(path, num_workers=0)

    def train_model(self):
        """Fine-tune a resnet18 learner and export the saved checkpoint.

        Steps:
            1. Seed the random number generators (seed 13, CUDA enabled).
            2. Build the data loaders via ``create_dls``.
            3. Fine-tune with label smoothing and MixUp, while
               ``SaveModelCallback`` and ``EarlyStoppingCallback`` both
               monitor ``accuracy``.
            4. Load the checkpoint named ``params_model_name`` into a new
               learner and export it as
               ``<trained_model_path>/<params_model_name>.pkl``.
        """
        self.random_seed(seed_value=13, use_cuda=True)

        dls = self.create_dls()

        learn = vision_learner(
            dls,
            resnet18,
            loss_func=LabelSmoothingCrossEntropy(),
            metrics=accuracy,
            cbs=[
                MixUp,
                SaveModelCallback(
                    monitor='accuracy',
                    fname=self.config.params_model_name,
                ),
                EarlyStoppingCallback(
                    monitor='accuracy',
                    patience=self.config.params_patience,
                ),
            ],
        )

        learn.fine_tune(
            self.config.params_epochs,
            freeze_epochs=self.config.params_n_freeze_epochs,
        )

        # The exported learner is built fresh, without the training-only
        # callbacks or the custom loss function passed above, and then
        # receives the checkpoint written under ``params_model_name``.
        export_learner = vision_learner(dls, resnet18)
        export_learner.load(self.config.params_model_name)
        export_learner.export(f'{self.config.trained_model_path}/{self.config.params_model_name}.pkl')