#!/usr/bin/env python

from __future__ import print_function, absolute_import

import argparse
import os
import sys;sys.path.insert(1, ".")  # Do not remove this
import traceback
from sagify_base.training.training import train as train_function


# Root directory from which the default values of the command-line path arguments are built.
# These defaults are used when training happens in SageMaker; pass the arguments explicitly
# if you want to run/test this script locally as a normal Python script.
_DEFAULT_PREFIX_PATH = '/opt/ml/'


def _parse_args():
    """
    Parse the command-line options of this training script.

    Every option has a default located under _DEFAULT_PREFIX_PATH, so the script can be
    started without any argument.

    :return: [argparse.Namespace], exposing the attributes input_data_path, model_save_path,
    hyperparams_path and failure_output
    """
    def _under_prefix(relative_path):
        # Build a default location below the common prefix directory.
        return os.path.join(_DEFAULT_PREFIX_PATH, relative_path)

    arg_parser = argparse.ArgumentParser()

    # The two options below are listed under the 'required arguments' heading in the help
    # output. Both carry a default, so argparse does not actually enforce their presence.
    required_group = arg_parser.add_argument_group('required arguments')

    required_group.add_argument(
        '-i', '--input-data-dir',
        dest='input_data_path',
        type=str,
        default=_under_prefix('input/data/training'),
        help='input directory path where all the training file(s) reside in'
    )
    required_group.add_argument(
        '-m', '--model-save-dir',
        dest='model_save_path',
        type=str,
        default=_under_prefix('model'),
        help='directory path to save your model(s)'
    )

    arg_parser.add_argument(
        '-p', '--hyperparams-json-file',
        dest='hyperparams_path',
        type=str,
        default=_under_prefix('input/config/hyperparameters.json'),
        help='input path to hyperparams json file'
    )
    # NOTE(review): SageMaker's container contract appears to expect the failure description
    # at <prefix>/output/failure, whereas this default makes the file end up at
    # <prefix>/failure/failure -- confirm against the SageMaker documentation.
    arg_parser.add_argument(
        '-f', '--failure-dir',
        dest='failure_output',
        type=str,
        default=_under_prefix('failure'),
        help='output directory path to save your failure(s) files'
    )

    return arg_parser.parse_args()


def train(input_data_path, model_save_path, hyperparams_path=None, failure_output=None):
    """
    The function to execute the training.

    Delegates the actual work to ``train_function``. If it raises, the error message and
    traceback are printed to stderr, additionally written to a file named ``failure`` inside
    ``failure_output`` (when given), and the process exits with code 255.

    :param input_data_path: [str], input directory path where all the training file(s) reside in
    :param model_save_path: [str], directory path to save your model(s)
    :param hyperparams_path: [optional[str], default=None], input path to hyperparams json file.
    Example:
        {
            "max_leaf_nodes": 10,
            "n_estimators": 200
        }
    :param failure_output: [optional[str], default=None], output directory path to save your
    failure(s) files. When falsy, no failure file is written.

    :raises SystemExit: with code 255 when the training raises an exception
    """
    print('Starting the training.')
    try:
        train_function(
            input_data_path=input_data_path,
            model_save_path=model_save_path,
            hyperparams_path=hyperparams_path
        )
        print('Training complete.')
    except Exception as e:
        failure_message = 'Exception during training: ' + str(e) + '\n' + traceback.format_exc()

        # Always print the error so that it shows up in the training job logs, regardless of
        # whether a failure directory has been provided.
        print(failure_message, file=sys.stderr)

        if failure_output:
            # Write out an error file. This will be returned as the failureReason in the
            # DescribeTrainingJob result.
            try:
                with open(os.path.join(failure_output, 'failure'), 'w') as failure_file:
                    failure_file.write(failure_message)
            except (IOError, OSError) as write_error:
                # Not being able to persist the failure file (e.g. the directory is missing)
                # must neither hide the training error nor change the exit code below.
                print(
                    'Could not write the failure file: ' + str(write_error),
                    file=sys.stderr
                )

        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)

if __name__ == '__main__':
    cli_args = _parse_args()
    train(
        input_data_path=cli_args.input_data_path,
        model_save_path=cli_args.model_save_path,
        hyperparams_path=cli_args.hyperparams_path,
        failure_output=cli_args.failure_output
    )

    # train() exits with a non-zero code on failure, so reaching this line means success.
    # A zero exit code causes the job to be marked as Succeeded.
    sys.exit(0)
