#!/usr/bin/env python

#
# This tool deletes checkpoints found at given path that are no longer needed
#
# we have 2 parts to each checkpoints to cleanup
#
# 1. the original deepspeed checkpoint
# 2. the converted hf checkpoint
#
# we will start with a combined requirement for eval to be completed and s3 synced to nuke the checkpoint
#
# Example:
#
# ./cleanup-checkpoints.py checkpoints-path
#
# Use `-h` for more options

import argparse
import shutil  # noqa
import subprocess
import sys
import time
from pathlib import Path


repo_path = Path(__file__).parents[2]

# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
#
# If that file is there it means:
#
# 1. either the cleanup is still running
# 2. the cleanup got aborted (e.g. cpu-oom)
#
# to detect aborted cleanups we will check if the control file is older than a reasonable time to perform such a cleanup
control_file_name = "started-cleanup-checkpoint"
finished_uploading_file_name = "finished-upload-checkpoint"
# should fine tune - but surely 1h per checkpoint is plenty
reasonable_cleanup_time_in_secs = 1 * 60 * 60


def run_cmd(cmd, check=True):
    try:
        response = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        ).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    return response


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
    parser.add_argument("--skip-evals-check", action="store_true", help="skip evals done checks")
    return parser.parse_args()


def exit(msg):
    print(msg)
    sys.exit()


def should_process(path, control_file_path, args):
    """Heuristics to decide whether to cleanup this opt_step-XXX checkpoint or not"""

    s3_completed_path = path / finished_uploading_file_name
    eval_completed_paths = [
        path / "run_evals_0_shots_done",
        path / "run_evals_4_shots_done",
        path / "run_evals_perplexity_validation_done",
        path / "run_evals_0_shots_a_la_flamingo_done",
    ]

    # check s3 sync is completed
    if not s3_completed_path.exists():
        print(f"[N] {path} hasn't been synced to s3 yet. Skipping")
        return False

    # check evals are completed
    if not args.skip_evals_check:
        for eval_path in eval_completed_paths:
            if not eval_path.exists():
                print(f"[N] {path} hasn't been evaled yet. Skipping")
                return False

    # complicated checks - has another job already started processing? or did it crash?
    if control_file_path.exists():
        if control_file_path.stat().st_mtime < time.time() - reasonable_cleanup_time_in_secs:
            print(f"[Y] {path} looks stale - probably aborted cleanup job. Deleting")
            return True
        else:
            print(
                f"[N] {path} either another job is doing the cleanup or less than"
                f" {reasonable_cleanup_time_in_secs} secs has passed since it was launched. Skipping"
            )
            return False
    else:
        print(f"[Y] {path} completed s3 sync + eval. Deleting")
        return True


def main():
    args = get_args()

    checkpoints_path = Path(args.checkpoints_path)
    if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
        raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")

    checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
    if len(checkpoint_dirs) == 0:
        exit("No checkpoints found, exiting")

    # Check each checkpoint folder in real time to allow for overlapping jobs starting at different times
    # Additionally do not delete the last 2 checkpoints
    #
    # sort numerically to sort correctly different number of digits: opt_step-10, opt_step-100
    checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
    for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
        print(f"\n*** Checking {checkpoint_dir}")

        if i + 1 == len(checkpoint_dirs_sorted):
            print(f"[N] {checkpoint_dir} is a last checkpoint. Skipping")
            continue

        if i + 2 == len(checkpoint_dirs_sorted):
            print(f"[N] {checkpoint_dir} is a second to last checkpoint. Skipping")
            continue

        control_file_path = checkpoint_dir / "unwrapped_model" / control_file_name

        if not should_process(checkpoint_dir, control_file_path, args):
            continue

        print(f"Launching cleanup for {checkpoint_dir}")
        # we could use flock here, to avoid a race condition, but it'd be pointless since each
        # cronjob is likely to run on a different node and flock only works within a single node
        control_file_path.touch()

        # cleanup
        # XXX: enable the actual delete once tested a lot
        # The delete should be relatively safe since it'll only run if it finds 2 files:
        # save_dir/opt_step-XXX/s3_sync_is_completed save_dir/opt_step-XXX/eval_is_completed
        shutil.rmtree(checkpoint_dir, ignore_errors=True)
        print(f"Checkpoint {checkpoint_dir} deleted")


if __name__ == "__main__":
    main()