image-deblurring / scripts /remote_training.py
balthou's picture
initiate demo
cec5823
import kaggle
from pathlib import Path, PurePosixPath
import json
# Kaggle credentials are read from a local ``__kaggle_login`` module; when it is
# missing, fail with an actionable message instead of a bare import error.
try:
    from __kaggle_login import kaggle_users
except ImportError as err:
    # The original literals were concatenated without a separating space
    # ("kaggle_usersdict"); the trailing space below fixes the message.
    raise ImportError(
        "Please create a __kaggle_login.py file with a kaggle_users "
        "dict containing your Kaggle credentials."
    ) from err
import argparse
import sys
import subprocess
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME
from train import get_parser as get_train_parser
from typing import Optional
from configuration import KAGGLE_DATASET_LIST, NB_ID, GIT_USER, GIT_REPO, TRAIN_SCRIPT
def get_git_branch_name():
    """Return the name of the currently checked-out Git branch.

    Runs ``git branch --show-current`` in the current working directory and
    returns its output stripped of surrounding whitespace.

    Returns:
        The branch name as a string. If the command exits with a non-zero
        status, or cannot be started at all (for example because ``git`` is
        not installed), the fixed message
        ``"Error: Could not determine the Git branch name."`` is returned
        instead of raising. Callers receive that message in place of a branch
        name and must be prepared for it.
    """
    try:
        raw_output = subprocess.check_output(["git", "branch", "--show-current"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g. not inside a repository).
        # OSError (incl. FileNotFoundError): the git executable could not be
        # launched. This function is evaluated while building the argument
        # parser, so letting that propagate would crash the whole CLI.
        return "Error: Could not determine the Git branch name."
    return raw_output.strip().decode()
def prepare_notebook(
    output_nb_path: Path,
    exp: int,
    branch: str,
    git_user: Optional[str] = None,
    git_repo: Optional[str] = None,
    template_nb_path: Path = Path(__file__).parent/"remote_training_template.ipynb",
    wandb_flag: bool = False,
    output_dir: Path = "scripts/"+OUTPUT_FOLDER_NAME,
    dataset_files: Optional[list] = None,
    train_script: str = TRAIN_SCRIPT
):
    """Fill in the notebook template and write the result to ``output_nb_path``.

    Every ``!!!name!!!`` marker found in the template text is replaced by a
    textual rendering of the matching argument. String-like values are wrapped
    in single quotes, ``wandb_flag`` becomes ``True``/``False``, and a ``None``
    ``output_dir`` or ``dataset_files`` becomes the bare word ``None``.

    Args:
        output_nb_path: Destination file for the rendered notebook.
        exp: Experiment identifier(s), inserted via ``str()`` formatting.
            NOTE(review): ``main`` in this file passes ``args.exp``, which it
            iterates as a sequence of ints, so the ``int`` annotation looks
            narrower than actual usage - confirm.
        branch: Git branch name substituted for ``!!!branch!!!``.
        git_user: Git user name substituted for ``!!!git_user!!!``. Required.
        git_repo: Git repository name substituted for ``!!!git_repo!!!``.
            Required.
        template_nb_path: Template notebook to read.
        wandb_flag: Value substituted for ``!!!wandb_flag!!!``.
        output_dir: Value substituted for ``!!!output_dir!!!``; a plain string
            is accepted as well (the default is one), and ``None`` is allowed.
        dataset_files: Optional list substituted for ``!!!dataset_files!!!``
            using its ``str()`` form.
        train_script: Value substituted for ``!!!train_script!!!``.

    Raises:
        ValueError: If ``git_user`` or ``git_repo`` is ``None``.
    """
    # Explicit checks instead of ``assert`` so the validation survives ``python -O``.
    if git_user is None:
        raise ValueError("Please provide a git username for the repo")
    if git_repo is None:
        raise ValueError("Please provide a git repo name for the repo")
    replacements = [
        ("exp", f"{exp}"),
        ("branch", f"\'{branch}\'"),
        ("git_user", f"\'{git_user}\'"),
        ("git_repo", f"\'{git_repo}\'"),
        ("wandb_flag", "True" if wandb_flag else "False"),
        ("output_dir", "None" if output_dir is None else f"\'{output_dir}\'"),
        ("dataset_files", "None" if dataset_files is None else f"{dataset_files}"),
        ("train_script", "\'"+train_script+"\'"),
    ]
    # Notebook files are UTF-8 JSON; do not depend on the platform's default
    # text encoding when reading the template or writing the result.
    with open(template_nb_path, encoding="utf-8") as template_file:
        notebook_text = template_file.read()
    # ``str.replace`` is a no-op when the marker is absent, so no prior
    # membership test is needed.
    for name, rendered_value in replacements:
        notebook_text = notebook_text.replace(f"!!!{name}!!!", rendered_value)
    with open(output_nb_path, "w", encoding="utf-8") as output_file:
        output_file.write(notebook_text)
def main(argv):
    """Prepare, push, or download the results of a Kaggle training notebook.

    The command line combines the options declared here with those added by
    ``get_train_parser``. ``args.exp``, ``args.cpu`` and ``args.no_wandb`` are
    read below but not declared here - presumably they come from that parser.

    Modes:
        * ``--download``: fetch the notebook output into a temporary folder
          under ``ROOT_DIR``, unpack ``output.tgz`` with ``tar`` (run without
          a target directory, i.e. relative to the current working directory),
          remove the temporary folder and return.
        * otherwise: render the notebook and ``kernel-metadata.json`` into
          ``ROOT_DIR/__nb_<username>/<exp ids>/`` and, with ``--push``, push
          that folder to Kaggle.

    Args:
        argv: Command-line arguments, without the program name.

    Raises:
        SystemExit: From argparse on invalid arguments, including a missing
            ``--user``.
        subprocess.CalledProcessError: If unpacking the downloaded archive
            fails.
    """
    parser = argparse.ArgumentParser(description="Train a model on Kaggle using a script")
    parser.add_argument("-n", "--nb_id", type=str, help="Notebook name in kaggle", default=NB_ID)
    parser.add_argument("-u", "--user", type=str, help="Kaggle user", choices=list(kaggle_users.keys()))
    parser.add_argument("--branch", type=str, help="Git branch name", default=get_git_branch_name())
    parser.add_argument("-p", "--push", action="store_true", help="Push")
    parser.add_argument("-d", "--download", action="store_true", help="Download results")
    get_train_parser(parser)
    args = parser.parse_args(argv)
    # ``--user`` has no default, so omitting it used to surface as a bare
    # ``KeyError: None``; report it as a normal usage error instead.
    if args.user not in kaggle_users:
        parser.error(
            "a Kaggle user is required: pass -u/--user with one of "
            f"{list(kaggle_users.keys())}"
        )
    nb_id = args.nb_id
    exp_str = "_".join(f"{exp:04d}" for exp in args.exp)
    kaggle_user = kaggle_users[args.user]
    uname_kaggle = kaggle_user["username"]
    kaggle.api._load_config(kaggle_user)

    if args.download:
        import shutil
        tmp_dir = ROOT_DIR/f"__tmp_{exp_str}"
        tmp_dir.mkdir(exist_ok=True, parents=True)
        try:
            kaggle.api.kernels_output_cli(f"{uname_kaggle}/{nb_id}", path=str(tmp_dir))
            # @FIXME: windows probably does not have tar command
            # check=True: a failed extraction must not pass silently right
            # before the downloaded archive is deleted.
            subprocess.run(["tar", "-xzf", str(tmp_dir/"output.tgz")], check=True)
        finally:
            # Always drop the scratch folder, even when download/extract fails.
            shutil.rmtree(tmp_dir, ignore_errors=True)
        return

    kernel_root = ROOT_DIR/f"__nb_{uname_kaggle}"
    kernel_root.mkdir(exist_ok=True, parents=True)
    kernel_path = kernel_root/exp_str
    kernel_path.mkdir(exist_ok=True, parents=True)
    branch = args.branch
    code_file = f"{nb_id}.ipynb"
    config = {
        "id": str(PurePosixPath(f"{uname_kaggle}")/nb_id),
        "title": nb_id.lower(),
        "code_file": code_file,
        "language": "python",
        "kernel_type": "notebook",
        "is_private": "true",
        "enable_gpu": "true" if not args.cpu else "false",
        "enable_tpu": "false",
        "enable_internet": "true",
        "dataset_sources": KAGGLE_DATASET_LIST,
        "competition_sources": [],
        "kernel_sources": [],
        "model_sources": []
    }
    # Build the notebook path from the same name the metadata advertises.
    # ``with_suffix`` would drop everything after the last dot of ``nb_id``
    # and then disagree with ``code_file`` for ids that contain a dot.
    notebook_path = kernel_path/code_file
    prepare_notebook(notebook_path, args.exp, branch,
                     git_user=GIT_USER, git_repo=GIT_REPO, wandb_flag=not args.no_wandb)
    assert notebook_path.exists()
    with open(kernel_path/"kernel-metadata.json", "w") as f:
        json.dump(config, f, indent=4)
    if args.push:
        kaggle.api.kernels_push_cli(str(kernel_path))
# Script entry point: hand the command-line arguments (program name excluded)
# to main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main(sys.argv[1:])