# Source: Hugging Face Space (page status at capture time: "Runtime error").
# File size: 2,140 bytes; commit 29793a4.
from pathlib import Path

from fastai.vision.all import *
from fastai.vision.widgets import *

from fruit_classifier.config.configuration import ConfigurationManager
class ModelTrainer:
    """Fine-tune a fastai image classifier and export it for inference.

    Paths and hyper-parameters come from the object returned by
    ``ConfigurationManager().get_training_config()``.
    """

    def __init__(self):
        config_manager = ConfigurationManager()
        self.config = config_manager.get_training_config()

    @staticmethod
    def random_seed(seed_value, use_cuda):
        """Seed the NumPy, PyTorch and Python RNGs for reproducible runs.

        Declared as a staticmethod so it works both as
        ``ModelTrainer.random_seed(...)`` and ``self.random_seed(...)``;
        as a plain method without ``self`` the instance call raised a
        ``TypeError`` (the instance was bound to ``seed_value``).

        Args:
            seed_value: Seed applied to every RNG.
            use_cuda: If true, also seed the CUDA RNGs and switch cuDNN to
                deterministic, non-benchmarking mode.
        """
        np.random.seed(seed_value)  # NumPy
        torch.manual_seed(seed_value)  # PyTorch
        random.seed(seed_value)  # Python stdlib
        if use_cuda:
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)
            # Deterministic cuDNN kernels give repeatable GPU results.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_dls(self):
        """Build train/validation DataLoaders from ``config.training_data``.

        Images are labelled by the name of their parent folder and 20% of
        them are held out for validation.

        Returns:
            The fastai ``DataLoaders`` built from the image folder.
        """
        path = self.config.training_data
        fruit_db = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            # No explicit seed: the split follows the global PyTorch RNG,
            # which train_model() seeds via random_seed() before calling this.
            splitter=RandomSplitter(valid_pct=0.2),
            get_y=parent_label,
            # Resize each item to 464px, then augment and crop to 224px
            # per batch.
            item_tfms=Resize(464),
            batch_tfms=aug_transforms(size=224, min_scale=0.5),
        )
        # num_workers=0: load batches in the main process.
        return fruit_db.dataloaders(path, num_workers=0)

    def train_model(self):
        """Fine-tune a ResNet-18 and export the best checkpoint as a ``.pkl``.

        ``SaveModelCallback`` saves the best epoch (by validation accuracy)
        under ``config.params_model_name``. That checkpoint is loaded into a
        fresh learner (built without the training callbacks or the custom
        loss) and exported to
        ``<config.trained_model_path>/<config.params_model_name>.pkl``.
        """
        self.random_seed(seed_value=13, use_cuda=True)
        dls = self.create_dls()
        learn = vision_learner(
            dls,
            resnet18,
            loss_func=LabelSmoothingCrossEntropy(),
            metrics=accuracy,
            cbs=[
                MixUp,
                SaveModelCallback(
                    monitor='accuracy',
                    fname=self.config.params_model_name,
                ),
                EarlyStoppingCallback(
                    monitor='accuracy',
                    patience=self.config.params_patience,
                ),
            ],
        )
        learn.fine_tune(
            self.config.params_epochs,
            freeze_epochs=self.config.params_n_freeze_epochs,
        )

        # Built on the same `dls`, so `load` reads the checkpoint from the
        # same location SaveModelCallback wrote it to.
        model = vision_learner(dls, resnet18)
        model.load(self.config.params_model_name)

        # Learner.export() writes to `learner.path / fname`, and learner.path
        # is the training-data folder here. Resolve the destination to an
        # absolute path first so a relative trained_model_path is taken
        # relative to the CWD rather than nested inside the data folder.
        export_dir = Path(self.config.trained_model_path).absolute()
        export_dir.mkdir(parents=True, exist_ok=True)
        model.export(export_dir / f'{self.config.params_model_name}.pkl')