import atexit
import sacred
import argparse
import time
import math
import subprocess
import shutil
import os
import json
import threading
import requests
import glob
from configs import fetch_model_params
import socket
import queue
import sys
import signal


parser = argparse.ArgumentParser()
parser.add_argument('--tpu', type=str, required=True) # name of the TPU to train on
parser.add_argument('--model', type=str, required=True) # path to a JSON model config, or the name of a config in configs/
parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
parser.add_argument('--autostack', action="store_false") # autostack defaults to True; passing --autostack disables it (the flag is forwarded to main.py in train_thread)
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200) # grace period (seconds) before the heartbeat check is tightened to --heartbeat_timeout
parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
args = parser.parse_args()

params = fetch_model_params(args.model)

ex = sacred.Experiment(args.experiment_name)
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))


def get_open_port(lo=8000, hi=8100):
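    """Return the first port in [lo, hi) with no listener on localhost, or None if every port is busy."""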
    for i in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', i)) != 0:
                return i


def train_thread(args, tpu, id, q):
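    """Run main.py for this run on `tpu` and babysit it: poll the child process,
    honour 'kill' requests arriving on the queue `q`, and recreate the TPU when
    training exits abnormally (unless --no_delete_tpu is set)."""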
    print('starting training on', tpu)

    # pass boolean flags through to main.py
    opts = ''
    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval']:
        if getattr(args, flag):
            opts += ' --' + flag

    # autostack is store_false, so forward the flag only when it was explicitly disabled
    for flag in ['autostack']:
        if not getattr(args, flag):
            opts += ' --' + flag

    cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
    print('Running:', cmd)
    proc = subprocess.Popen(cmd, shell=True)

    # poll until it's exited
    while proc.poll() is None:
        time.sleep(60)
        try:
            nq, *nargs = q.get_nowait()
            if nq == 'kill':
                print('train thread received kill signal from logging thread')
                # first send SIGTERM
                proc.terminate()

                time.sleep(60)
                
                # if it still hasn't exited, we send SIGKILL
                if proc.poll() is None: 
                    print('SIGTERM not successful, sending SIGKILL')
                    proc.kill()

        except queue.Empty:
            pass

    print('exited training!')
    if proc.returncode == 0:
        print('exited gracefully')
        os.kill(os.getpid(), signal.SIGINT)  # interrupt the main process so the sacred run shuts down
        return
    
    if args.no_delete_tpu:
        print('not recreating tpu (--no_delete_tpu), exiting train_thread')
        return
    print("Recreating {} in 60sec...".format(tpu))
    time.sleep(60)
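    # `pu` is assumed to be an external TPU-management CLI available on PATH (e.g. tpunicorn)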
    os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
    print('recreate done, exiting train_thread')
    
    # clear out queue
    while True:
        try:
            q.get_nowait()
            print('dropped request in queue after pu recreate')
        except queue.Empty:
            break


def get_json(uri, params=None, timeout=15):
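    """GET `uri` and return the decoded JSON body, raising on HTTP errors."""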
    resp = requests.get(uri, params=params, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def get_tag_sets(base_uri):
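    """Return a mapping from TensorBoard run name to the scalar tags that run exposes."""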
    j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
    assert isinstance(j, dict)
    return {
        run: j[run].keys()
        for run in j.keys()
    }


def get_scalar_data(base_uri, run, tag):
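    """Fetch one scalar series; each element is a [wall_time, step, value] triple."""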
    j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
    assert isinstance(j, list)
    return j


def get_run_data(port):
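    """Scrape train/eval loss and LAMBADA metrics for this run from the local TensorBoard
    instance; returns whatever could be fetched (possibly an empty dict) on error."""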
    base_uri = f'http://localhost:{port}/'
    r = {}
    try:
        tag_sets = get_tag_sets(base_uri)
        runs = tag_sets.keys()
        if '.' in runs:
            if 'loss' in tag_sets['.']:
                r['loss'] = get_scalar_data(base_uri, '.', 'loss')
        if 'eval' in runs:
            if 'loss' in tag_sets['eval']:
                r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
        if 'eval_lambada' in runs:
            if 'lambada_acc' in tag_sets['eval_lambada']:
                r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
            if 'lambada_log_ppl' in tag_sets['eval_lambada']:
                r['lambada_ppl'] = [
                    [t, s, math.exp(lp)]
                    for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
                ]
    except Exception:
        import traceback
        traceback.print_exc()
    return r


@ex.main
def main(_run):
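    """Sacred entry point: copy the model config for this run, start TensorBoard in a
    detached screen session, then repeatedly launch the training thread while mirroring
    TensorBoard scalars, prediction files and eval_*.jsonl metrics into sacred, killing
    and restarting the run if it stops logging for too long."""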
    print('Starting run', _run._id)
    print('experiment main invoked with argv:', " ".join(sys.argv))
    print('WARNING: please remember to remove old metric log files from the model directory.')

    os.makedirs('run_configs', exist_ok=True)
    shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))

    tensorboard_port = get_open_port()
    print('Tensorboard at port:', tensorboard_port)
    print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
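    # start tensorboard inside a detached screen session; the second invocation is a fallback
    # if the first command fails (presumably for tensorboard versions lacking --bind_all)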
    os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
    atexit.register(goodbye, _run._id)

    curr_step = {}
    seen_predictions = set()

    heartbeat_timeout = args.initial_heartbeat_timeout * 2  # generous timeout at startup; tightened once the grace period has passed
    while True:
        last_tb_log_time = time.time()
        start_time = time.time()
        q = queue.Queue()
        trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
        trainthd.start()

        while trainthd.is_alive():
            time.sleep(60)

            if start_time + args.initial_heartbeat_timeout < time.time():
                # the initial grace period has elapsed, so tighten the heartbeat threshold
                heartbeat_timeout = args.heartbeat_timeout

            print('Polling tensorboard for metrics...')
            data = get_run_data(tensorboard_port)
            for k in data.keys():
                for ts, step, val in data[k]:
                    if step <= curr_step.get(k, -1):
                        continue
                    _run.log_scalar(k, val, step)
                    if k == 'loss':
                        _run.log_scalar('tb_ts', ts, step)
                        print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))
                    
                    # we got a new datapoint, so refresh the heartbeat timestamp
                    last_tb_log_time = time.time()

                    curr_step[k] = step

            for f in glob.glob('predictions_{}_*'.format(_run._id)):
                if f in seen_predictions:
                    continue
                print('collecting prediction file', f)
                ex.add_artifact(f)
                
                seen_predictions.add(f)
            
            # collect eval metrics from jsonl
            if os.path.exists(f'eval_{_run._id}.jsonl'):
                with open(f'eval_{_run._id}.jsonl') as fh:
                    for line in fh:
                        ob = json.loads(line)
                        val_step = ob['global_step']
                        val_task = ob['task']
                        for metr in ob.keys():
                            k = 'fs.' + val_task + '.' + metr
                            if metr in ['task', 'global_step']: continue
                            if val_step <= curr_step.get(k, -1): continue
                            _run.log_scalar(k, ob[metr], val_step)
                            curr_step[k] = val_step

            if time.time() - last_tb_log_time > heartbeat_timeout:
                # the run hasn't logged in a while, so we restart it
                q.put(('kill',))

                # give training thread some time to do its thing and recreate tpu
                while trainthd.is_alive():
                    print('logging thread waiting for the stalled run to be killed and for the tpu recreate to finish')
                    time.sleep(60)
                
                # reset heartbeat timeout to initial
                heartbeat_timeout = args.initial_heartbeat_timeout
                last_tb_log_time = time.time()


        if args.no_delete_tpu:
            break


def goodbye(id):
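    """atexit handler: tear down the TensorBoard screen session that belongs to this run."""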
    print("You are now leaving the Python sector.")
    print("Sie verlassen den pythonischen Sektor.")

    os.system("screen -S tensorboard_{} -X quit".format(id))

        
if __name__ == '__main__':
    # register every python source file with sacred
    for file in glob.glob("**/*", recursive=True):
        if file.endswith('.py'):
            print('Adding', file, 'to sacred')
            ex.add_source_file(file)

    ex.add_config({
        'tpu_name': args.tpu,
        **params
    })

    ex.run()