# Spaces status (scraped page header): Runtime error
from pathlib import Path

from fastai.vision.all import *
from fastai.vision.widgets import *

from fruit_classifier.config.configuration import ConfigurationManager
class ModelTrainer:
    """Train a fruit image classifier with fastai and export it for inference.

    Paths and hyper-parameters are read from the training config returned by
    ``ConfigurationManager().get_training_config()``.
    """

    def __init__(self):
        config_manager = ConfigurationManager()
        self.config = config_manager.get_training_config()

    @staticmethod
    def random_seed(seed_value, use_cuda):
        """Seed Python, NumPy and PyTorch RNGs so training runs are repeatable.

        Args:
            seed_value: Seed applied to every random number generator.
            use_cuda: When true, also seed all CUDA devices and force
                deterministic cuDNN kernels (with cuDNN autotuning disabled).
        """
        # A staticmethod because no instance state is used. The previous
        # signature lacked ``self``, so ``self.random_seed(seed_value=...)``
        # raised "got multiple values for argument 'seed_value'".
        np.random.seed(seed_value)  # NumPy (CPU)
        torch.manual_seed(seed_value)  # PyTorch (CPU)
        random.seed(seed_value)  # Python stdlib
        if use_cuda:
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)  # every visible GPU
            # Deterministic kernels and no autotuning are required for
            # repeatable GPU results.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_dls(self):
        """Build train/validation DataLoaders from ``config.training_data``.

        Each image's label is its parent folder name. 20% of the images are
        held out for validation.

        Returns:
            The fastai DataLoaders built by ``DataBlock.dataloaders``.
        """
        path = self.config.training_data
        fruit_db = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            # No explicit seed. NOTE(review): the split presumably follows the
            # global RNG state that train_model() sets via random_seed().
            splitter=RandomSplitter(valid_pct=0.2),
            get_y=parent_label,
            # Items are resized to 464px, then batch augmentations
            # crop/scale them down to 224px.
            item_tfms=Resize(464),
            batch_tfms=aug_transforms(size=224, min_scale=0.5),
        )
        # num_workers=0 loads batches in the main process. NOTE(review):
        # presumably chosen to avoid multiprocessing issues on the target host.
        dls = fruit_db.dataloaders(path, num_workers=0)
        return dls

    def train_model(self, seed_value=13):
        """Fine-tune a ResNet-18 and export the best checkpoint as a ``.pkl``.

        Args:
            seed_value: Seed passed to ``random_seed`` (defaults to 13, the
                value that was previously hard-coded).

        The exported learner is written to
        ``<config.trained_model_path>/<config.params_model_name>.pkl``
        (relative to the learner's ``path``).
        """
        # Seed CUDA only when a GPU is actually present.
        self.random_seed(seed_value=seed_value, use_cuda=torch.cuda.is_available())
        dls = self.create_dls()
        learn = vision_learner(
            dls,
            resnet18,
            loss_func=LabelSmoothingCrossEntropy(),
            metrics=accuracy,
            cbs=[
                MixUp,
                # Checkpoints the best-accuracy weights under params_model_name.
                SaveModelCallback(
                    monitor='accuracy',
                    fname=self.config.params_model_name,
                ),
                EarlyStoppingCallback(
                    monitor='accuracy',
                    patience=self.config.params_patience,
                ),
            ],
        )
        learn.fine_tune(
            self.config.params_epochs,
            freeze_epochs=self.config.params_n_freeze_epochs,
        )

        # Rebuild a learner without the training callbacks and load the saved
        # checkpoint into it before exporting. NOTE(review): this learner uses
        # fastai's default loss instead of LabelSmoothingCrossEntropy; that is
        # fine for inference, but confirm if the export is ever re-trained.
        model = vision_learner(dls, resnet18)
        model.load(self.config.params_model_name)

        export_rel = Path(self.config.trained_model_path) / f'{self.config.params_model_name}.pkl'
        # Learner.export resolves fname against model.path, so create the
        # directory relative to that same base.
        (model.path / export_rel).parent.mkdir(parents=True, exist_ok=True)
        model.export(export_rel)