#!/usr/bin/env python
from __future__ import print_function
import sys
import psycopg2
import psycopg2.extras
import json
import records

from pg_elastic import replicate_es


class PgReplicateElastic(object):
    """Command-line entry point that replicates PostgreSQL into Elasticsearch.

    The object is built from the raw command-line arguments (without the
    program name).  ``run()`` validates them, loads and validates the JSON
    configuration file, optionally performs an initial full-table sync and
    then consumes the logical replication stream until interrupted.
    """

    _USAGE = 'Usage --config=<absolute path to config>'
    _CONFIG_PREFIX = '--config='

    # Exact key sets expected in the configuration file.  'inital_sync' is
    # spelled the way run() reads it from the configuration.
    _CONFIG_KEYS = (
        'replication_slot',
        'postgres',
        'es_connection',
        'inital_sync',
        'tables',
    )
    _REPLICATION_SLOT_KEYS = ('name', 'is_temp')
    _POSTGRES_KEYS = ('database', 'host', 'username', 'password', 'port')
    _TABLE_KEYS = ('name', 'primary_key')

    def __init__(self, args):
        """Store the raw argument list; nothing is validated or opened yet.

        :param args: command-line arguments, e.g. ``sys.argv[1:]``.
        """
        super(PgReplicateElastic, self).__init__()
        self.args = args
        self.config = None
        self.es = None

    def _elastic_consumer(self, msg):
        """Forward one replication message to Elasticsearch and acknowledge it.

        Replication is best-effort: a failure is printed and the message is
        still acknowledged, so one bad change does not stall the stream.
        """
        try:
            self.es.replicate(json.loads(msg.payload))
        except Exception as e:
            print(e)

        msg.cursor.send_feedback(flush_lsn=msg.data_start)

    def _validate_args(self):
        """Check for exactly one ``--config=<path>`` argument.

        On success ``self.args`` is replaced by the configuration path and
        True is returned.  On failure usage is printed, ``self.args`` is left
        untouched and False is returned.
        """
        if len(self.args) != 1:
            print(self._USAGE)
            return False

        argument = self.args[0]
        if not argument.startswith(self._CONFIG_PREFIX):
            print(self._USAGE)
            return False

        # Slice instead of split('=') so paths containing '=' stay intact.
        path = argument[len(self._CONFIG_PREFIX):]
        if not path:
            print(self._USAGE)
            return False

        self.args = path
        return True

    @staticmethod
    def _has_exact_keys(section, expected_keys):
        """Return True when ``section`` is a dict with exactly these keys."""
        return isinstance(section, dict) and set(section) == set(expected_keys)

    def _validate_config(self):
        """Return True when ``self.config`` has the expected structure.

        Prints the reason and returns False otherwise.
        """
        if not self._has_exact_keys(self.config, self._CONFIG_KEYS):
            print('Invalid configuration')
            return False

        if not self._has_exact_keys(self.config['replication_slot'],
                                    self._REPLICATION_SLOT_KEYS):
            print('Invalid replication slot configuration')
            return False

        if not self._has_exact_keys(self.config['postgres'],
                                    self._POSTGRES_KEYS):
            print('Invalid postgres configuration')
            return False

        tables = self.config['tables']
        if not isinstance(tables, list):
            print('Invalid tables configuration')
            return False

        for table in tables:
            if not self._has_exact_keys(table, self._TABLE_KEYS):
                print('Each table in tables configuration must have name and primary_key')
                return False
        return True

    def _load_config(self):
        """Load the JSON file at ``self.args`` into ``self.config``.

        :returns: True when the loaded configuration is valid, else False.
        :raises IOError/OSError: when the file cannot be opened.
        :raises ValueError: when the file does not contain valid JSON.
        """
        with open(self.args, 'r') as f:
            self.config = json.load(f)

        if self._validate_config():
            return True

        print('>>> Init failed!')
        return False

    def _initial_sync(self):
        """Copy every configured table in full into Elasticsearch."""
        try:
            from urllib.parse import quote
        except ImportError:  # Python 2
            from urllib import quote

        postgres = self.config['postgres']
        # URL form is user:password@host:port/database.  Credentials are
        # percent-encoded so characters such as '@' or '/' cannot break it.
        connection_string = 'postgresql://%s:%s@%s:%s/%s' % (
            quote(postgres['username'], safe=''),
            quote(postgres['password'], safe=''),
            postgres['host'],
            postgres['port'],
            postgres['database']
        )

        temp_client = records.Database(connection_string)
        try:
            for table in self.config['tables']:
                print('Synchronization of %s....' % table['name'])
                # NOTE(review): the table name is interpolated verbatim; it
                # comes from the configuration file, which must be trusted.
                rows = temp_client.query("SELECT * FROM %s" % table['name'])
                self.es.replicate(
                    rows.all(),
                    initial=self.config['inital_sync'],
                    initial_table=table['name']
                )
        finally:
            temp_client.close()

    def _start_replication(self, conn, cur):
        """Start streaming from the configured slot, creating it if needed."""
        slot = self.config['replication_slot']
        try:
            cur.start_replication(slot_name=slot['name'], decode=True)
            return
        except psycopg2.ProgrammingError:
            # Presumably the slot does not exist yet; create it and retry.
            pass

        if slot['is_temp']:
            cur.execute(
                "SELECT * FROM pg_create_logical_replication_slot(%s, 'wal2json', TRUE);",
                (slot['name'],)
            )
            conn.commit()
        else:
            cur.create_replication_slot(slot['name'], output_plugin='wal2json')
        cur.start_replication(slot_name=slot['name'], decode=True)

    def _warn_about_persistent_slot(self):
        """Remind the operator that a non-temporary slot outlives the process."""
        slot = self.config['replication_slot']
        if slot['is_temp']:
            return
        print("\nWARNING: Transaction logs will accumulate in pg_xlog until the slot is dropped."
              "\nThe slot '%s' still exists. Drop it with "
              "SELECT pg_drop_replication_slot('%s'); if no longer needed." % (slot['name'], slot['name']),
              file=sys.stderr)

    def run(self):
        """Validate input, then replicate until interrupted with Control-C.

        Argument or configuration problems are reported on stdout and make
        the method return early.  The cursor and the connection are closed
        on every exit path once the connection has been opened.
        """
        if not self._validate_args():
            print('Invalid arguments')
            return

        try:
            config_ok = self._load_config()
        except (IOError, OSError):
            print('Configuration file does not exist!')
            return
        except ValueError:
            print('Configuration file is not valid JSON')
            return

        if not config_ok:
            print('Invalid configuration file format')
            return

        postgres = self.config['postgres']
        # Keyword arguments avoid building a DSN string by hand, which breaks
        # on values containing spaces, and let the configured port be used.
        conn = psycopg2.connect(
            dbname=postgres['database'],
            host=postgres['host'],
            user=postgres['username'],
            password=postgres['password'],
            port=postgres['port'],
            connection_factory=psycopg2.extras.LogicalReplicationConnection
        )

        cur = None
        try:
            self.es = replicate_es.ElasticRepliaction(
                self.config['tables'],
                connection=self.config['es_connection']
            )

            cur = conn.cursor()

            if self.config['inital_sync']:
                self._initial_sync()

            self._start_replication(conn, cur)

            print("Starting streaming, press Control-C to end...", file=sys.stderr)

            try:
                cur.consume_stream(self._elastic_consumer)
            except KeyboardInterrupt:
                self._warn_about_persistent_slot()
        finally:
            if cur is not None:
                cur.close()
            conn.close()

if __name__ == "__main__":
    # Script entry point: hand everything after the program name to the app.
    PgReplicateElastic(sys.argv[1:]).run()
