import glob
import os
import pickle

import torch


def _remove_files(files):
    for f in files:
        return os.remove(f)


def assert_dir_exits(path):
    if not os.path.exists(path):
        os.makedirs(path)


def save_model(model, epoch, out_path):
    assert_dir_exits(out_path)
    model_file = out_path + str(epoch) + '.pth'
    chk_files = glob.glob(out_path + '*.pth')
    _remove_files(chk_files)
    torch.save(model.state_dict(), model_file)
    print('model saved for epoch: {}'.format(epoch))
    return model_file


def save_objects(obj, epoch, out_path):
    assert_dir_exits(out_path)
    dat_files = glob.glob(out_path + '*.dat')
    _remove_files(dat_files)
    # object should be tuple
    with open(out_path + str(epoch) + '.dat', 'wb') as output:
        pickle.dump(obj, output)

    print('objects saved for epoch: {}'.format(epoch))


def restore_model(model, out_path):
    chk_file = glob.glob(out_path + '*.pth')

    if chk_file:
        chk_file = str(chk_file[0])
        print('found modeL {}, restoring'.format(chk_file))
        model.load_state_dict(torch.load(chk_file))
    else:
        print('Model not found, using untrained model')
    return model


def restore_objects(out_path, default):
    data_file = glob.glob(out_path + '*.dat')
    if data_file:
        data_file = str(data_file[0])
        print('found data {}, restoring'.format(data_file))
        with open(data_file, 'rb') as input_:
            obj = pickle.load(input_)

        return obj
    else:
        return default