Spaces:
Running
Running
import kaggle | |
from pathlib import Path, PurePosixPath | |
import json | |
try: | |
from __kaggle_login import kaggle_users | |
except ImportError: | |
raise ImportError("Please create a __kaggle_login.py file with a kaggle_users" + | |
"dict containing your Kaggle credentials.") | |
import argparse | |
import sys | |
import subprocess | |
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME | |
from train import get_parser as get_train_parser | |
from typing import Optional | |
from configuration import KAGGLE_DATASET_LIST, NB_ID, GIT_USER, GIT_REPO, TRAIN_SCRIPT | |
def get_git_branch_name(cwd: Optional[Path] = None) -> str:
    """Return the name of the currently checked-out Git branch.

    Runs ``git branch --show-current`` and returns its decoded, stripped output.
    Git prints nothing when HEAD is detached, so the result can be ``""``.

    Failures are reported through a return value rather than an exception,
    so callers can use this directly as a default value.

    Args:
        cwd: Directory to run git in; git locates the repository from there.
            Defaults to the process's current working directory.

    Returns:
        The branch name, or the sentinel string
        ``"Error: Could not determine the Git branch name."``. The sentinel is
        returned when git exits with an error (e.g. not inside a repository),
        or when git cannot be started at all (not installed, or ``cwd`` does
        not exist).
    """
    try:
        raw_output = subprocess.check_output(["git", "branch", "--show-current"], cwd=cwd)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed.
        # OSError (e.g. FileNotFoundError): the git executable or cwd is missing;
        # previously this escaped and crashed the caller.
        return "Error: Could not determine the Git branch name."
    return raw_output.decode().strip()
def prepare_notebook(
    output_nb_path: Path,
    exp: int,
    branch: str,
    git_user: Optional[str] = None,
    git_repo: Optional[str] = None,
    template_nb_path: Path = Path(__file__).parent/"remote_training_template.ipynb",
    wandb_flag: bool = False,
    output_dir: Path = "scripts/"+OUTPUT_FOLDER_NAME,
    dataset_files: Optional[list] = None,
    train_script: str = TRAIN_SCRIPT
):
    """Render the remote-training notebook by filling placeholders in a template.

    The template is read as UTF-8 text. Every ``!!!<name>!!!`` placeholder is
    replaced by Python source text for the matching argument. The supported
    names are ``exp``, ``branch``, ``git_user``, ``git_repo``, ``wandb_flag``,
    ``output_dir``, ``dataset_files`` and ``train_script``. The result is
    written as UTF-8 to ``output_nb_path``.

    An ``.ipynb`` file is JSON, so placeholders inside code cells sit within
    JSON string values. Each substituted literal is therefore JSON-escaped
    first; otherwise a value containing a double quote or a backslash would
    produce an invalid notebook.

    Args:
        output_nb_path: Where to write the rendered notebook.
        exp: Experiment id(s), substituted via ``str()``. ``main()`` passes
            ``args.exp``, which it also iterates as a sequence, so in practice
            this is a sequence despite the ``int`` annotation.
        branch: Git branch name, rendered as a quoted Python string.
        git_user: Repository owner (required).
        git_repo: Repository name (required).
        template_nb_path: Template notebook containing the placeholders.
        wandb_flag: Rendered as ``True`` / ``False``.
        output_dir: Rendered as a quoted Python string, or ``None``.
        dataset_files: Rendered as a list of strings, or ``None``.
        train_script: Rendered as a quoted Python string.

    Raises:
        ValueError: If ``git_user`` or ``git_repo`` is missing.
    """
    # Explicit checks instead of assert: asserts vanish under ``python -O``.
    if git_user is None:
        raise ValueError("Please provide a git username for the repo")
    if git_repo is None:
        raise ValueError("Please provide a git repo name for the repo")

    # Python source text for each placeholder. repr(str(x)) yields a correctly
    # quoted and escaped string literal. For plain values this is the same
    # 'value' form as before, and it also accepts Path objects.
    substitutions = {
        "exp": str(exp),
        "branch": repr(str(branch)),
        "git_user": repr(str(git_user)),
        "git_repo": repr(str(git_repo)),
        "wandb_flag": "True" if wandb_flag else "False",
        "output_dir": "None" if output_dir is None else repr(str(output_dir)),
        "dataset_files": "None" if dataset_files is None else repr([str(f) for f in dataset_files]),
        "train_script": repr(str(train_script)),
    }

    with open(template_nb_path, encoding="utf-8") as template_file:
        notebook_text = template_file.read()

    for name, literal in substitutions.items():
        # json.dumps(...)[1:-1] escapes quotes, backslashes and control
        # characters so the literal stays valid inside the enclosing JSON string.
        escaped = json.dumps(literal, ensure_ascii=False)[1:-1]
        notebook_text = notebook_text.replace(f"!!!{name}!!!", escaped)

    with open(output_nb_path, "w", encoding="utf-8") as output_file:
        output_file.write(notebook_text)
def _download_kernel_output(kernel_ref: str, exp_str: str) -> None:
    """Download a kernel's output files and unpack their ``output.tgz`` into the current directory.

    Files are staged in ``ROOT_DIR/__tmp_<exp_str>``. That directory is removed
    afterwards, even if the download or the extraction fails.
    """
    import shutil
    import tarfile

    tmp_dir = ROOT_DIR/f"__tmp_{exp_str}"
    tmp_dir.mkdir(exist_ok=True, parents=True)
    try:
        kaggle.api.kernels_output_cli(kernel_ref, path=str(tmp_dir))
        # The stdlib tarfile module replaces the external `tar` command, which
        # is not guaranteed to exist (e.g. on Windows). It also raises on
        # failure instead of failing silently.
        with tarfile.open(tmp_dir/"output.tgz", "r:gz") as archive:
            if hasattr(tarfile, "data_filter"):
                # Where supported, refuse members that would escape the target
                # directory (absolute paths, "..", unsafe links, device files).
                archive.extractall(path=".", filter="data")
            else:
                archive.extractall(path=".")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _kernel_metadata(username: str, nb_id: str, use_gpu: bool) -> dict:
    """Return the ``kernel-metadata.json`` contents for a private notebook kernel."""
    return {
        "id": f"{username}/{nb_id}",
        "title": nb_id.lower(),
        "code_file": f"{nb_id}.ipynb",
        "language": "python",
        "kernel_type": "notebook",
        "is_private": "true",
        "enable_gpu": "true" if use_gpu else "false",
        "enable_tpu": "false",
        "enable_internet": "true",
        "dataset_sources": KAGGLE_DATASET_LIST,
        "competition_sources": [],
        "kernel_sources": [],
        "model_sources": [],
    }


def main(argv):
    """Command-line entry point: prepare, push or download a Kaggle training kernel.

    Without ``-d``: renders the training notebook and ``kernel-metadata.json``
    into ``ROOT_DIR/__nb_<kaggle username>/<exp ids>``. When ``-p`` is given,
    that directory is then pushed to Kaggle.

    With ``-d``: downloads the kernel's output and extracts its
    ``output.tgz`` into the current working directory.

    Args:
        argv: Command-line arguments, excluding the program name.
    """
    parser = argparse.ArgumentParser(description="Train a model on Kaggle using a script")
    parser.add_argument("-n", "--nb_id", type=str, help="Notebook name in kaggle", default=NB_ID)
    # required=True: a missing -u used to crash later with ``KeyError: None``.
    parser.add_argument("-u", "--user", type=str, help="Kaggle user",
                        choices=list(kaggle_users.keys()), required=True)
    # Resolved lazily below, so git only runs when a branch is actually needed.
    parser.add_argument("--branch", type=str, help="Git branch name", default=None)
    parser.add_argument("-p", "--push", action="store_true", help="Push")
    parser.add_argument("-d", "--download", action="store_true", help="Download results")
    # args.exp, args.cpu and args.no_wandb are read below but not defined here;
    # presumably they are added by the training script's parser.
    get_train_parser(parser)
    args = parser.parse_args(argv)

    nb_id = args.nb_id
    exp_str = "_".join(f"{exp:04d}" for exp in args.exp)
    kaggle_user = kaggle_users[args.user]
    uname_kaggle = kaggle_user["username"]
    # Private kaggle API: configure the client from the selected user's entry.
    kaggle.api._load_config(kaggle_user)

    if args.download:
        _download_kernel_output(f"{uname_kaggle}/{nb_id}", exp_str)
        return

    branch = args.branch if args.branch is not None else get_git_branch_name()
    # get_git_branch_name() reports failure with an "Error: ..." string. A real
    # branch can never start that way, because git forbids ':' in ref names.
    # It also yields "" for a detached HEAD. Stop here rather than bake a bogus
    # branch name into the notebook.
    if not branch or branch.startswith("Error:"):
        parser.error("could not determine the current git branch; pass --branch explicitly")

    kernel_path = ROOT_DIR/f"__nb_{uname_kaggle}"/exp_str
    kernel_path.mkdir(exist_ok=True, parents=True)

    # Append the suffix instead of using Path.with_suffix(). with_suffix()
    # replaces anything after a '.' in nb_id, so the file name would no longer
    # match the metadata's "code_file" entry.
    nb_path = kernel_path/f"{nb_id}.ipynb"
    prepare_notebook(nb_path, args.exp, branch,
                     git_user=GIT_USER, git_repo=GIT_REPO, wandb_flag=not args.no_wandb)
    if not nb_path.exists():
        raise FileNotFoundError(f"Notebook was not written to {nb_path}")

    with open(kernel_path/"kernel-metadata.json", "w", encoding="utf-8") as f:
        json.dump(_kernel_metadata(uname_kaggle, nb_id, use_gpu=not args.cpu), f, indent=4)

    if args.push:
        kaggle.api.kernels_push_cli(str(kernel_path))
if __name__ == "__main__":
    # Script entry: hand the CLI every argument after the program name.
    main(argv=sys.argv[1:])