#!/usr/bin/python3
# -*- coding: utf-8 -*-
#

# Copyright (C) 2011 Mattias Slabanja <slabanja@chalmers.se>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

import os
import sys
import optparse
import logging
import numpy as np

from itertools import count, islice
from functools import partial
from collections import deque

import dsf.filon as filon
from dsf.binner import fixed_bin_averager
from dsf.handythread import foreach
from dsf.index import section_index
from dsf.output import create_mfile, create_pfile
from dsf.reciprocal import reciprocal_isotropic, reciprocal_line
from dsf.trajectory import get_itraj, iwindow

from multiprocessing import cpu_count
from scipy.interpolate import interp1d


class averager:
    """Minimal per-slot accumulator used by dynasor to form averages.

    Every slot holds a running sum together with a counter of how many
    times the slot has been assigned to.  Since ``av[i] += x`` ends in an
    item assignment, each such statement registers exactly one sample.

    Ex:
    av = averager(2)
    av[0] += 10
    av[0] += 2
    av[1] += 3
    av.get_av() ->
    [[6.], [3.]]
    """
    def __init__(self, N_slots, initial=np.zeros(1)):
        assert N_slots >= 1
        self._N = N_slots
        # Every slot starts out from its own private copy of `initial`.
        slots = []
        for _ in range(N_slots):
            slots.append(np.array(initial))
        self._data = slots
        self._samples = np.zeros(N_slots)

    def __getitem__(self, key):
        # Reading a slot never touches the sample counter.
        return self._data[key]

    def __setitem__(self, key, val):
        # Storing into a slot counts as one more sample for that slot.
        self._data[key] = val
        self._samples[key] += 1

    def add(self, array, slot):
        """Accumulate `array` into `slot`, counting it as one sample."""
        accumulated = self[slot]
        accumulated += array
        self[slot] = accumulated

    def get_single_av(self, slot):
        """Return the accumulated sum of `slot` divided by its sample count."""
        return (1.0 / self._samples[slot]) * self._data[slot]

    def get_av(self):
        """Return the averages of all slots stacked into one array."""
        return np.array(list(map(self.get_single_av, range(self._N))))


# Constants shared by the script below.
pi = np.pi
two_pi = pi * 2.0
hbar = 0.658211928  # eV fs

if __name__ == '__main__':

    # Build the command line interface. Related options are collected in
    # option groups.
    parser = optparse.OptionParser()

    # Which files to read, and where to write the results.
    iogroup = optparse.OptionGroup(
        parser,
        'Input/output options',
        'Options controlling input and output, files and fileformats.')
    iogroup.add_option(
        '-f', '--trajectory',
        metavar='TRAJECTORY_FILE',
        help=('Molecular dynamics TRAJECTORY_FILE to be analyzed. '
              "Supported formats depends on VMD's molfile plugin or gmxlib. "
              'As a fallback, a lammps-trajectory parser implemented in '
              'Python is also available.'))
    iogroup.add_option(
        '-n', '--index',
        metavar='INDEX_FILE',
        help=('Optional index file (think Gromacs NDX-style) for specifying '
              'atom types. Atoms are indexed from 1 up to N (total number '
              'of atoms).  It is possible to index only a sub set of all '
              'atoms, and atoms can be indexed in more than one group. '
              'If no INDEX_FILE is provided, all atoms will be considered '
              'identical.'))
    iogroup.add_option(
        '--om',
        metavar='FILE',
        help='Write output to FILE as a Matlab style m-file.')
    iogroup.add_option(
        '--op',
        metavar='FILE',
        help='Write output to FILE as a Python pickled file.')
    parser.add_option_group(iogroup)

    # How reciprocal space is sampled and how the result is binned.
    kspace = optparse.OptionGroup(
        parser,
        'General k-space options',
        'Options controlling general aspects for how k-space should be '
        'sampled and collected.')
    kspace.add_option(
        '--k-sampling',
        default='isotropic', metavar='STYLE',
        help=('Possible values are "isotropic" (default) for sampling '
              'isotropic systems (as liquids), "line" to sample uniformely '
              'along a certain direction in k-space, "explicit" for '
              'sampling on an explicit set of k-points'))
    kspace.add_option(
        '--k-bins',
        type='int', default=80, metavar='BINS',
        help=('Number of "radial" bins to use (between 0 and largest '
              '|k|-value) when collecting resulting average. '
              'Default value is 80.'))
    parser.add_option_group(kspace)

    # Settings read only when the sampling style is "isotropic".
    kiso = optparse.OptionGroup(parser, 'Isotropic k-space sampling')
    kiso.add_option(
        '--max-k-points',
        type='int', default=20000, metavar='KPOINTS',
        help=('Maximum number of points used to sample k-space. Puts an '
              '(approximate) upper limit by randomly selecting. points. '
              'Default value is 20000.'))
    kiso.add_option(
        '--k-max',
        type='float', default=60, metavar='KMAX',
        help=('Largest k-value to consider (in "2*pi*nm^-1"). '
              'Default value for KMAX is 60. '))
    parser.add_option_group(kiso)

    # Settings read only when the sampling style is "line".
    kline = optparse.OptionGroup(parser, 'Line-style k-space sampling')
    kline.add_option(
        '--k-direction',
        metavar='KDIRECTION',
        help=('Point in k-space to consider. KPOINTS points will be evenly '
              'placed between 0,0,0 and KDIRECTION. Given as three comma '
              'separated values.'))
    kline.add_option(
        '--k-points',
        type='int', default=100, metavar='KPOINTS',
        help=('Numbers of uniformly distributed points between 0,0,0 and '
              'KDIRECTION to consider. Default value is 100.'))
    parser.add_option_group(kline)

    # kexpl = optparse.OptionGroup(parser,
    #                             'Explicit k-space sampling')
    # kexpl.add_option('','--k-points-file', metavar='KPOINTS-FILE',
    #                  help='KPOINTS-FILE should contain each kpoint to '
    #                  'consider. ')
    # parser.add_option_group(kexpl)

    # Time step, correlation window length and frame selection.
    tgroup = optparse.OptionGroup(
        parser,
        'Time-related options',
        'Options controlling timestep, length and shape of trajectory '
        'frame window, etc.')
    tgroup.add_option(
        '--nt',
        type='int', metavar='TIME_CORR_STEPS',
        help=('Determines the length of the trajectory frame window to use '
              'for time correlation calculation. TIME_CORR_STEPS is '
              'expressed in "number of frames" and e.g. determines the '
              'smallest frequency resolved. If no TIME_CORR_STEPS is '
              'provided, only static (t=0) correlations will be calculated'))
    tgroup.add_option(
        '--max-frames',
        type='int', default=100, metavar='FRAMES',
        help=('Limits the total number of trajectory frames read to '
              'FRAMES. The default value is set to 100.'))
    tgroup.add_option(
        '--step',
        type='int', default=1, metavar='STEP',
        help=('Only use every (STEP)th trajectory frame. Default STEP is 1, '
              'meaning every frame is processed. STEP affects dt and hence '
              'the smallest time resolved.'))
    tgroup.add_option(
        '--stride',
        type='int', default=1, metavar='STRIDE',
        help=('Stride STRIDE frames between consecutive trajectory windows. '
              'This does not affect dt. If e.g. STRIDE > TIME_CORR_STEPS, '
              'some frames will be completely unused.'))
    tgroup.add_option(
        '--dt',
        metavar='DELTATIME',
        help=('Explicitly sets the time difference between two '
              'consecutively processed trajectory frames to DELTATIME '
              '(femtoseconds). Useful when no time step information can be '
              'extracted from the trajectory file (e.g. when using '
              'molfileplugin).'))
    parser.add_option_group(tgroup)

    # Switches selecting optional parts of the calculation.
    processing = optparse.OptionGroup(parser, 'General processing options')
    processing.add_option(
        '--calculate-self',
        action='store_true', default=False,
        help='Calculate the self-part, F_s, ...')
    parser.add_option_group(processing)

    # Ungrouped options: threading and log verbosity.
    parser.add_option(
        '--threads',
        type='int', default=0,
        help=('Number of threads to use. The default value is taken from '
              'OMP_NUM_THREADS if it is set,  otherwise it is set to the '
              'number of available cpus.'))
    parser.add_option(
        '-q', '--quiet',
        action='count', default=0,
        help='Increase quietness (opposite of verbosity).')
    parser.add_option(
        '-v', '--verbose',
        action='count', default=0,
        help='Increase verbosity (opposite of quietness).')

    # From here on `options` holds the parsed command line values.
    options, args = parser.parse_args()
    # Net quietness: every -q adds one, every -v subtracts one.
    quietness = options.quiet - options.verbose

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
    logger = logging.getLogger('dynasor')
    logger.addHandler(handler)

    # Map the net quietness onto a log level; anything below 0 means DEBUG
    # and anything above 2 means CRITICAL.
    if quietness < 0:
        log_level = logging.DEBUG
    elif quietness > 2:
        log_level = logging.CRITICAL
    else:
        log_level = (logging.INFO, logging.WARNING, logging.ERROR)[quietness]
    logger.setLevel(log_level)

    # A trajectory file is the only mandatory argument.
    if options.trajectory is None:
        logger.error('A trajectory must be specified. Use option -f.\n')
        parser.print_help()
        sys.exit(1)

    # Running without an output file is allowed, but worth pointing out.
    if not options.om and not options.op:
        logger.warning('No output file specified. Results will not be saved')

    # Thread count precedence: --threads, then the OMP_NUM_THREADS
    # environment variable, then the number of cpus.
    if options.threads != 0:
        num_threads = options.threads
    elif 'OMP_NUM_THREADS' in os.environ:
        try:
            num_threads = int(os.environ['OMP_NUM_THREADS'])
        except ValueError:
            # Report a malformed value like any other bad input instead of
            # dying with a traceback.
            logger.error('Could not parse OMP_NUM_THREADS ({}) as int!'
                         .format(os.environ['OMP_NUM_THREADS']))
            sys.exit(1)
    else:
        num_threads = cpu_count()
    if num_threads < 1:
        logger.error('Number of threads must be > 0')
        sys.exit(1)

    # Export the chosen value through the environment.
    # Affects rho_j_k
    os.environ['OMP_NUM_THREADS'] = str(num_threads)

    # Read the two first frames to set up references values
    # box size, number of different particles, time step length, etc
    # (the unpacking raises ValueError unless exactly two frames arrive)
    try:
        f0, f1 = islice(get_itraj(options.trajectory, step=options.step),
                        2)
    except ValueError:
        logger.error('Failed to read two consecutive frames to determine '
                     'delta t. Is the trajectory long enough?')
        sys.exit(1)

    # Time between two consecutively processed frames, as reported by the
    # trajectory reader.
    delta_t = f1['time'] - f0['time']

    if options.dt:
        # An explicit --dt overrides whatever the trajectory reports.
        try:
            delta_t = float(options.dt)
        except ValueError:
            logger.error('Could not parse provided dt ({}) as float!'
                         .format(options.dt))
            sys.exit(1)
        logger.warning('-- Explicitly setting delta time to {} fs.'
                       .format(delta_t))

    elif delta_t == 0.0:
        # Both frames carry the same time stamp; fall back to a unit step.
        logger.warning('Delta time is 0.0 (meaning, cannot read info'
                       ' from trajectory)!')
        logger.warning('-- Defaulting to 1 fs'
                       ' [this will affect the output scale]')
        delta_t = 1

    # N_tc is the number of frames in each correlation window: the requested
    # --nt rounded up to the nearest odd number, or 1 (static correlations
    # only) when --nt was not given.
    if options.nt:
        if options.nt < 1:
            # Explicit check rather than an assert, which -O would strip.
            logger.error('Number of time correlation steps (--nt) '
                         'must be >= 1')
            sys.exit(1)
        N_tc = options.nt + (options.nt + 1) % 2
    else:
        N_tc = 1

    if N_tc > 1:
        # The window spans N_tc frames, i.e. N_tc - 1 time steps; report the
        # window actually used rather than the one derived from the raw --nt.
        logger.info('-- delta_t found to be {} [fs], time window {} [fs]'
                    .format(delta_t, delta_t * (N_tc - 1)))
        # NOTE(review): the frequency figures are still derived from the
        # requested nt; confirm against the grid used by filon.fourier_cos.
        dw = two_pi / (options.nt * delta_t)
        logger.info('-- delta_omega, max_omega = {}, {} [fs^-1]'
                    .format(dw, dw * options.nt))

    # Current correlations are calculated whenever the first frame carries
    # a 'v' (velocity) entry.
    calculate_current = 'v' in f0

    # Particle self correlations are opt-in through --calculate-self.
    calculate_self = bool(options.calculate_self)

    index = section_index(options.index, f0['N'])

    # The first frame provides the reference box; its volume is the triple
    # product of the three box vectors.
    reference_box = f0['box']
    a, b, c = reference_box
    reference_volume = abs(np.dot(np.cross(a, b), c))

    # Names, particle counts and number densities of each index section.
    particle_types = index.get_section_names()
    particle_counts = [len(section)
                       for section in index.get_section_indices()]
    particle_densities = [n / reference_volume for n in particle_counts]

    logger.info('Trajectory file: %s', options.trajectory)
    logger.info('-- With a total of %i particles, %i types.',
                f0['N'], len(particle_types))
    for ordinal, type_name in enumerate(particle_types, start=1):
        logger.info('-- %d. %d %s',
                    ordinal, particle_counts[ordinal - 1], type_name)

    logger.info('Simulation box is\n%s', str(reference_box))

    # Pick the reciprocal space sampler matching --k-sampling.
    style = options.k_sampling
    if style not in ('isotropic', 'line'):
        logger.error('Unknown style %s' % style)
        sys.exit(1)

    if style == 'line':
        # Sample on points along a line in k-space
        if options.k_direction is None:
            logger.error('k-direction must be specified in line mode.')
            sys.exit(1)

        try:
            k_direction = np.array(
                [float(component)
                 for component in options.k_direction.split(',')])
        except ValueError:
            k_direction = None
        # Reject unparsable input as well as anything that is not exactly
        # three components, as the message (and the option help) promises.
        if k_direction is None or len(k_direction) != 3:
            logger.error('k-direction must be specified as a valid comma '
                         'separated three-tuple.')
            sys.exit(1)

        rec = reciprocal_line(points=options.k_points,
                              k_direction=k_direction)

    elif style == 'isotropic':
        # Sample k-space without preference to direction
        rec = reciprocal_isotropic(reference_box,
                                   max_points=options.max_k_points,
                                   max_k=options.k_max)

    # Report the number of k-points; a single point (or none) is logged as
    # a warning instead of as plain information.
    n_kpoints = len(rec.k_distance)
    report = logger.info if n_kpoints > 1 else logger.warning
    report('N kpoints = %i' % n_kpoints)

    largest_k = rec.max_k
    logger.info('k_max = {} --> x_min = {} [nm]'
                .format(largest_k, two_pi / largest_k))

    # Validate explicitly (an assert would be stripped by python -O).
    if options.stride < 1:
        logger.error('Stride must be > 0')
        sys.exit(1)
    N_stride = options.stride

    # function to use to "calculate rho(k)"
    compute_rho_k = rec.get_frame_process_function()
    # function to split particles into different index groups (types);
    # its output is what compute_rho_k expects as input
    split_by_type = index.get_section_split_function()

    def element_processor(frame):
        """Split `frame` by particle type, then apply the rho(k) function."""
        return compute_rho_k(split_by_type(frame))

    # The trajectory window iterator: yields windows of N_tc processed
    # frames, N_stride frames apart.
    itraj_window = iwindow(get_itraj(options.trajectory,
                                     step=options.step,
                                     max_frames=options.max_frames),
                           width=N_tc,
                           stride=N_stride,
                           element_processor=element_processor)

    # TODO....
    # * Assert box is not changed during consecutive frames

    Ntypes = len(particle_types)

    # Every unordered pair of particle types (i <= j) gets a running pair
    # number m; pair_list holds (m, i, j) and pair_types the matching
    # 'A-B' style labels.
    type_pairs = ((i, j) for i in range(Ntypes) for j in range(i, Ntypes))
    pair_list = [(m, i, j) for m, (i, j) in enumerate(type_pairs)]
    pair_types = [particle_types[i] + '-' + particle_types[j]
                  for _, i, j in pair_list]

    # One averager per pair (per type for the self part), with one slot per
    # time lag; every slot starts as a zero array with one entry per
    # k-point. Uses rec.k_distance, the attribute the rest of the script
    # relies on for the k-point distances.
    z = np.zeros(len(rec.k_distance))
    F_k_t_avs = [averager(N_tc, z) for _ in pair_list]
    if calculate_current:
        Cl_k_t_avs = [averager(N_tc, z) for _ in pair_list]
        Ct_k_t_avs = [averager(N_tc, z) for _ in pair_list]
    if calculate_self:
        F_s_k_t_avs = [averager(N_tc, z) for _ in particle_types]

    def calc_corr(window, time_i):
        """Correlate the first frame of `window` with the frame at `time_i`.

        The results are accumulated into slot `time_i` of the averagers
        defined above (density, and optionally current and self parts).
        """
        first = window[0]
        later = window[time_i]

        # Density-density correlation for every pair of types.
        for pair, a, b in pair_list:
            rho_product = first['rho_ks'][a] * later['rho_ks'][b].conjugate()
            F_k_t_avs[pair][time_i] += np.real(rho_product)

        # Longitudinal and transversal current correlations.
        if calculate_current:
            for pair, a, b in pair_list:
                jz_product = (first['jz_ks'][a]
                              * later['jz_ks'][b].conjugate())
                Cl_k_t_avs[pair][time_i] += np.real(jz_product)

                jper_product = (first['jper_ks'][a]
                                * later['jper_ks'][b].conjugate())
                Ct_k_t_avs[pair][time_i] += 0.5 * \
                    np.real(np.sum(jper_product, axis=0))

        # Self part, computed from the per-type displacements between the
        # two frames.
        if calculate_self:
            displacements = [(xi - x0)
                             for xi, x0 in zip(later['xs'], first['xs'])]
            self_parts = rec.process_specific_xs(displacements)
            for type_i, F_s in enumerate(self_parts):
                F_s_k_t_avs[type_i][time_i] += np.real(F_s)

    # Main loop: go through every trajectory window and correlate its
    # first frame with each frame in the window.
    for window in itraj_window:
        first_index = window[0]['index']
        last_index = window[-1]['index']
        logger.debug("processing window step %i to %i" % (first_index,
                                                          last_index))
        # One task per time lag, handed to foreach with num_threads threads.
        correlate = partial(calc_corr, window)
        foreach(correlate, range(len(window)), threads=num_threads)

    # Turn the accumulated sums into averages over all windows, then
    # average those over the k-points falling into each 'radial' bin.
    k_bins = fixed_bin_averager(rec.max_k, options.k_bins, rec.k_distance)
    k_bin_averager = partial(k_bins.bin, axis=1)

    F_k_t = [k_bin_averager(av.get_av()) for av in F_k_t_avs]

    if calculate_current:
        Cl_k_t = [k_bin_averager(av.get_av()) for av in Cl_k_t_avs]
        Ct_k_t = [k_bin_averager(av.get_av()) for av in Ct_k_t_avs]

    if calculate_self:
        F_s_k_t = [k_bin_averager(av.get_av()) for av in F_s_k_t_avs]
        # Scale the self part of each type by 1/sqrt(particle count).
        for type_i, n_particles in enumerate(particle_counts):
            F_s_k_t[type_i] *= (1.0 / np.sqrt(n_particles))

    # Axes that go with the binned results.
    t = delta_t * np.arange(N_tc)
    k = k_bins.x.copy()
    k_bin_count = k_bins.bin_count.copy()

    # Everything to be written is collected as (data, name, description)
    # entries. k is a wave vector magnitude, hence the inverse length unit.
    output = [(k, 'k', 'k-values (technically, bin centers) [nm^-1]'),
              (t, 't', 'time values [fs]'),
              (k_bin_count, 'k_bin_count', 'Number of k-points per bin')]
    output += [(F_k_t[m], 'F_k_t_%i_%i' % (i, j),
                'Partial intermediate scattering function [time, k] ({})'
                .format(pair_types[m]))
               for m, i, j in pair_list]

    if calculate_current:
        output += [(Cl_k_t[m], 'Cl_k_t_%i_%i' % (i, j),
                    'Longitudinal current correlation [time, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]
        output += [(Ct_k_t[m], 'Ct_k_t_%i_%i' % (i, j),
                    'Transversal current correlation [time, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

    if calculate_self:
        output += [(F_s_k_t[i], 'F_s_k_t_%i' % i,
                    'Self part of intermediate scattring function [time, k]')
                   for i in range(index.N_sections())]

    if len(k) > 1:
        # Linearly spaced k-points restricted to [k[1], k[-1]]; the last
        # point is dropped when needed so that an odd number remains.
        k_lin = k_bins.x_linspace
        k_lin = k_lin[(k_lin >= k[1]) & (k_lin <= k[-1])]
        if len(k_lin) % 2 == 0:
            k_lin = k_lin[:-1]

        # r-grid on which G(r, t) is evaluated.
        dr = two_pi / k[-1]
        r = np.arange(5 * dr, pi / k[1], dr)

        def F_to_G(F, pair_index):
            """Sine-transform F(k, t) - 1 of one type pair into G(r, t)."""
            j = pair_list[pair_index][2]
            prefactor = 1 / (r * 2 * pi ** 2 * particle_densities[j])
            # k * (F - 1), with F interpolated onto the linear k-grid
            integrand = k_lin * interp1d(k, F - 1)(k_lin)
            return prefactor * filon.sin_integral(integrand,
                                                  k_lin[1] - k_lin[0], r,
                                                  k_lin[0], axis=1) + 1
        G_r_t = [F_to_G(F_pair, pair_index)
                 for pair_index, F_pair in enumerate(F_k_t)]

        output += [(r, 'r', 'r-values for calculated G(r,t) [nm]')]
        output += [(G_r_t[m], 'G_r_t_%i_%i' % (i, j),
                    'Calculated partial van Hove function [time, r] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

    if len(t) > 2:
        # Transform every time correlation function to the frequency domain.
        # Each filon.fourier_cos call is unpacked as an (omega, transform)
        # pair; the omega values of the first pair are the ones exported.
        omega_per_pair, S_k_w = zip(*(filon.fourier_cos(F, delta_t)
                                      for F in F_k_t))
        w = omega_per_pair[0]
        output += [(w, 'w', 'omega [fs^-1]')]
        output += [(S_k_w[m], 'S_k_w_%i_%i' % (i, j),
                    'Partial dynamical structure factor [omega, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

        if calculate_current:
            _, Cl_k_w = zip(*(filon.fourier_cos(C, delta_t)
                              for C in Cl_k_t))
            _, Ct_k_w = zip(*(filon.fourier_cos(C, delta_t)
                              for C in Ct_k_t))
            output += [(Cl_k_w[m], 'Cl_k_w_%i_%i' % (i, j),
                        'Longitudinal partial current correlation'
                        ' [omega, k] ({})'.format(pair_types[m]))
                       for m, i, j in pair_list]
            output += [(Ct_k_w[m], 'Ct_k_w_%i_%i' % (i, j),
                        'Transversal partial current correlation'
                        ' [omega, k] ({})'.format(pair_types[m]))
                       for m, i, j in pair_list]

        if calculate_self:
            _, S_s_k_w = zip(*(filon.fourier_cos(F, delta_t)
                               for F in F_s_k_t))
            output += [(S_s_k_w[i], 'S_s_k_w_%i' % i,
                        'Self part of partial dynamical structure factor'
                        ' [omega, k]')
                       for i in range(len(particle_types))]

    # Write the collected results in each requested format. The m-file
    # also records the command line that produced it.
    comment = 'Command line: ' + ' '.join(sys.argv)
    if options.om:
        create_mfile(options.om, output, comment=comment)
    if options.op:
        create_pfile(options.op, output)
