################################################################################
#                                                                              #
#  Copyright 2015-2019 Max Planck Institute for Dynamics and Self-Organization #
#                                                                              #
#  This file is part of TurTLE.                                                #
#                                                                              #
#  TurTLE is free software: you can redistribute it and/or modify              #
#  it under the terms of the GNU General Public License as published           #
#  by the Free Software Foundation, either version 3 of the License,           #
#  or (at your option) any later version.                                      #
#                                                                              #
#  TurTLE is distributed in the hope that it will be useful,                   #
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
#  GNU General Public License for more details.                                #
#                                                                              #
#  You should have received a copy of the GNU General Public License           #
#  along with TurTLE.  If not, see <http://www.gnu.org/licenses/>              #
#                                                                              #
# Contact: Cristian.Lalescu@ds.mpg.de                                          #
#                                                                              #
################################################################################



import os
import sys
import shutil
import subprocess
import argparse
import h5py
import math
import numpy as np
import warnings

import TurTLE
from ._code import _code
from TurTLE import tools

class DNS(_code):
    """This class is meant to stitch together the C++ code into a final source file,
    compile it, and handle all job launching.
    """
    def __init__(
            self,
            work_dir = './',
            simname = 'test'):
        """Set up a DNS instance.

        Delegates directory/name bookkeeping to the `_code` base class,
        then installs the default parameter set and an empty statistics
        cache (filled lazily by :func:`compute_statistics`).
        """
        super().__init__(work_dir = work_dir, simname = simname)
        self.generate_default_parameters()
        self.statistics = {}
        return None
    def set_precision(
            self,
            fluid_dtype):
        """Configure the floating point precision of the fluid solver.

        :param fluid_dtype: either a numpy real dtype
            (``np.float32``/``np.float64``) or one of the strings
            ``'single'``/``'double'``.
        :raises ValueError: if `fluid_dtype` is not recognized
            (the original implementation silently did nothing, leading
            to an `AttributeError` much later).

        Sets ``self.fluid_dtype``/``self.rtype`` (real dtype),
        ``self.ctype`` (matching complex dtype), ``self.C_field_dtype``
        (C type name used in the generated source) and
        ``self.fluid_precision``.
        """
        if fluid_dtype in [np.float32, np.float64]:
            # normalize to an np.dtype instance for consistency with the
            # string branch below; dtype comparisons are unaffected
            self.fluid_dtype = np.dtype(fluid_dtype)
        elif fluid_dtype == 'single':
            self.fluid_dtype = np.dtype(np.float32)
        elif fluid_dtype == 'double':
            self.fluid_dtype = np.dtype(np.float64)
        else:
            raise ValueError(
                    'invalid fluid_dtype {0}; '
                    'expected np.float32/np.float64 '
                    'or "single"/"double"'.format(repr(fluid_dtype)))
        self.rtype = self.fluid_dtype
        if self.rtype == np.float32:
            self.ctype = np.dtype(np.complex64)
            self.C_field_dtype = 'float'
            self.fluid_precision = 'single'
        elif self.rtype == np.float64:
            self.ctype = np.dtype(np.complex128)
            self.C_field_dtype = 'double'
            self.fluid_precision = 'double'
        return None
    def write_src(
            self):
        """Generate the C++ source file for the current DNS type.

        Writes ``self.name + '.cpp'``: a version banner, the include
        directives, and a ``main`` that forwards to the templated
        ``main_code`` entry point, with floating point exceptions
        togglable through the ``TURTLE_FPE_OFF`` environment variable.
        """
        self.version_message = (
                '/***********************************************************************\n' +
                '* this code automatically generated by TurTLE\n' +
                '* version {0}\n'.format(TurTLE.__version__) +
                '***********************************************************************/\n\n\n')
        self.include_list = [
                '"base.hpp"',
                '"scope_timer.hpp"',
                '"fftw_interface.hpp"',
                '"full_code/main_code.hpp"',
                '<cmath>',
                '<iostream>',
                '<hdf5.h>',
                '<string>',
                '<cstring>',
                '<fftw3-mpi.h>',
                '<omp.h>',
                '<cfenv>',
                '<cstdlib>',
                '"full_code/{0}.hpp"\n'.format(self.dns_type)]
        # e.g. "NSVE<float>" --- the solver class instantiated at the
        # precision chosen through set_precision
        template_instance = self.dns_type + '<{0}>'.format(self.C_field_dtype)
        self.main = """
            int main(int argc, char *argv[])
            {{
                bool fpe = (
                    (getenv("TURTLE_FPE_OFF") == nullptr) ||
                    (getenv("TURTLE_FPE_OFF") != std::string("TRUE")));
                return main_code< {0} >(argc, argv, fpe);
            }}
            """.format(template_instance)
        self.includes = '\n'.join('#include ' + hh
                                  for hh in self.include_list)
        with open(self.name + '.cpp', 'w') as outfile:
            for chunk in [self.version_message, self.includes, self.main]:
                outfile.write(chunk + ('\n' if chunk is self.main else '\n\n'))
        return None
    def generate_default_parameters(self):
        """Install default parameter values.

        ``self.parameters`` receives the settings relevant for all DNS
        classes; ``self.NSVEp_extra_parameters`` holds the additional
        defaults used by the particle-carrying solvers.
        """
        # parameters relevant for all DNS classes
        self.parameters.update({
                'fftw_plan_rigor' : 'FFTW_ESTIMATE',
                'dealias_type' : 1,
                'dkx' : 1.0,
                'dky' : 1.0,
                'dkz' : 1.0,
                'niter_todo' : 8,
                'niter_stat' : 1,
                'niter_out' : 8,
                'checkpoints_per_file' : 1,
                'dt' : 0.01,
                'nu' : 0.1,
                'fmode' : 1,
                'famplitude' : 0.5,
                'friction_coefficient' : 0.5,
                'energy' : 0.5,
                'injection_rate' : 0.4,
                'fk0' : 2.0,
                'fk1' : 4.0,
                'forcing_type' : 'fixed_energy_injection_rate',
                'histogram_bins' : 256,
                'max_velocity_estimate' : 1.0,
                'max_vorticity_estimate' : 1.0,
                })
        # parameters specific to the particle versions of the solvers
        self.NSVEp_extra_parameters = {
                'niter_part' : 1,
                'niter_part_fine_period' : 10,
                'niter_part_fine_duration' : 0,
                'nparticles' : 10,
                'tracers0_integration_steps' : 4,
                'tracers0_neighbours' : 1,
                'tracers0_smoothness' : 1,
                'tracers0_enable_p2p' : 0,
                'tracers0_enable_inner' : 0,
                'tracers0_enable_vorticity_omega' : 0,
                'tracers0_cutoff' : 1.0,
                'tracers0_inner_v0' : 1.0,
                'tracers0_lambda' : 1.0,
                }
        return None
    def get_kspace(self):
        """Describe the Fourier grid.

        Returns a dict with the maximum resolved wavenumber ``kM``, the
        shell width ``dk``, zero-initialized shell arrays (``nshell``,
        ``kshell``), and the wavenumber arrays ``kx``, ``ky``, ``kz``
        (``ky``/``kz`` rolled into FFT layout, negative frequencies
        after the positive ones).
        """
        p = self.parameters
        # dealias_type == 1 keeps modes up to n//2 (smooth filtering);
        # otherwise the standard 2/3 rule limits modes to n//3
        divisor = 2 if p['dealias_type'] == 1 else 3
        kM_candidates = [p['dk' + c]*(p['n' + c]//divisor - 1)
                         for c in ['x', 'y', 'z']]
        kspace = {'kM' : max(kM_candidates),
                  'dk' : min(p['dkx'], p['dky'], p['dkz'])}
        nshells = int(kspace['kM'] / kspace['dk']) + 2
        kspace['nshell'] = np.zeros(nshells, dtype = np.int64)
        kspace['kshell'] = np.zeros(nshells, dtype = np.float64)
        # x direction only stores the non-negative half (r2c transform)
        kspace['kx'] = np.arange(
                0, p['nx']//2 + 1).astype(np.float64)*p['dkx']
        for c in ['y', 'z']:
            npoints = p['n' + c]
            wavenumbers = np.arange(
                    (-npoints)//2 + 1,
                    npoints//2 + 1).astype(np.float64)*p['dk' + c]
            kspace['k' + c] = np.roll(wavenumbers, npoints//2 + 1)
        return kspace
    def get_data_file_name(self):
        """Return the path of the main ``.h5`` data file for this run."""
        fname = self.simname + '.h5'
        return os.path.join(self.work_dir, fname)
    def get_data_file(self):
        """Open the main data file read-only (caller must close it)."""
        fname = self.get_data_file_name()
        return h5py.File(fname, 'r')
    def get_particle_file_name(self):
        """Return the path of the particle ``.h5`` file for this run."""
        fname = self.simname + '_particles.h5'
        return os.path.join(self.work_dir, fname)
    def get_particle_file(self):
        """Open the particle file read-only (caller must close it)."""
        fname = self.get_particle_file_name()
        return h5py.File(fname, 'r')
    def get_cache_file_name(self):
        """Return the path of the postprocessing cache ``.h5`` file."""
        fname = self.simname + '_cache.h5'
        return os.path.join(self.work_dir, fname)
    def get_cache_file(self):
        """Open the cache file read-only (caller must close it)."""
        fname = self.get_cache_file_name()
        return h5py.File(fname, 'r')
    def get_postprocess_file_name(self):
        """Return the postprocessing file path.

        Currently an alias for :func:`get_cache_file_name`.
        """
        return self.get_cache_file_name()
    def get_postprocess_file(self):
        """Open the postprocessing file read-only (caller must close it)."""
        fname = self.get_postprocess_file_name()
        return h5py.File(fname, 'r')
    def compute_statistics(self, iter0 = 0, iter1 = None):
        """Run basic postprocessing on raw data.

        The energy spectrum :math:`E(t, k)` and the enstrophy spectrum
        :math:`\\frac{1}{2}\\omega^2(t, k)` are computed from the

        .. math::

            \\sum_{k \\leq \\|\\mathbf{k}\\| \\leq k+dk}\\hat{u_i} \\hat{u_j}^*, \\hskip .5cm
            \\sum_{k \\leq \\|\\mathbf{k}\\| \\leq k+dk}\\hat{\\omega_i} \\hat{\\omega_j}^*

        tensors, and the enstrophy spectrum is also used to
        compute the dissipation :math:`\\varepsilon(t)`.
        These basic quantities are stored in a newly created HDF5 file,
        ``simname_cache.h5``.

        :param iter0: first iteration to include
        :param iter1: last iteration to include (default: latest
            available iteration)
        """
        # statistics are computed at most once per instance
        if len(list(self.statistics.keys())) > 0:
            return None
        if not os.path.exists(self.get_data_file_name()):
            # no raw data; fall back to a previously generated cache.
            # the original code fell through to an unbound `pp_file`
            # (NameError) when the cache was missing too --- bail out
            # explicitly instead.
            if not os.path.exists(self.get_cache_file_name()):
                return None
            self.read_parameters(fname = self.get_cache_file_name())
            pp_file = self.get_cache_file()
            self.statistics['kM'] = pp_file['kspace/kM'][...]
            self.statistics['dk'] = pp_file['kspace/dk'][...]
            self.statistics['kshell'] = pp_file['kspace/kshell'][...]
            self.statistics['nshell'] = pp_file['kspace/nshell'][...]
        else:
            self.read_parameters()
            with self.get_data_file() as data_file:
                if 'moments' not in data_file['statistics'].keys():
                    return None
                # clip iter0 to the range actually present in the file
                iter0 = min((data_file['statistics/moments/velocity'].shape[0] *
                             self.parameters['niter_stat']-1),
                            iter0)
                if iter1 is None:
                    iter1 = data_file['iteration'][...]
                else:
                    iter1 = min(data_file['iteration'][...], iter1)
                ii0 = iter0 // self.parameters['niter_stat']
                ii1 = iter1 // self.parameters['niter_stat']
                self.statistics['kshell'] = data_file['kspace/kshell'][...]
                self.statistics['nshell'] = data_file['kspace/nshell'][...]
                # empty trailing shells are marked with nan so they drop
                # out of later spectral sums
                for kk in [-1, -2]:
                    if (self.statistics['kshell'][kk] == 0):
                        self.statistics['kshell'][kk] = np.nan
                self.statistics['kM'] = data_file['kspace/kM'][...]
                self.statistics['dk'] = data_file['kspace/dk'][...]
                computation_needed = True
                pp_file = h5py.File(self.get_postprocess_file_name(), 'a')
                if not ('parameters' in pp_file.keys()):
                    data_file.copy('parameters', pp_file)
                    data_file.copy('kspace', pp_file)
                if 'ii0' in pp_file.keys():
                    # reuse cached results only when they cover exactly
                    # the requested iteration range
                    computation_needed =  not (ii0 == pp_file['ii0'][...] and
                                               ii1 == pp_file['ii1'][...])
                    if computation_needed:
                        for k in ['t', 'vel_max(t)',
                                  'renergy(t)',
                                  'renstrophy(t)',
                                  'energy(t)', 'enstrophy(t)',
                                  'energy(k)', 'enstrophy(k)',
                                  'energy(t, k)',
                                  'enstrophy(t, k)',
                                  'R_ij(t)',
                                  'ii0', 'ii1', 'iter0', 'iter1']:
                            if k in pp_file.keys():
                                del pp_file[k]
                if computation_needed:
                    #TODO figure out whether normalization is sane or not
                    pp_file['iter0'] = iter0
                    pp_file['iter1'] = iter1
                    pp_file['ii0'] = ii0
                    pp_file['ii1'] = ii1
                    # `np.float` was removed in NumPy 1.24; use the
                    # equivalent explicit dtype
                    pp_file['t'] = (self.parameters['dt']*
                                    self.parameters['niter_stat']*
                                    (np.arange(ii0, ii1+1).astype(np.float64)))
                    # we have an extra division by shell_width because of the Dirac delta restricting integration to the shell
                    phi_ij = data_file['statistics/spectra/velocity_velocity'][ii0:ii1+1] / self.statistics['dk']
                    pp_file['R_ij(t)'] = np.sum(phi_ij*self.statistics['dk'], axis = 1)
                    energy_tk = (
                        phi_ij[:, :, 0, 0] +
                        phi_ij[:, :, 1, 1] +
                        phi_ij[:, :, 2, 2])/2
                    pp_file['energy(t)'] = np.sum(energy_tk*self.statistics['dk'], axis = 1)
                    # normalization factor is (4 pi * shell_width * kshell^2) / (nmodes in shell * dkx*dky*dkz)
                    norm_factor = (4*np.pi*self.statistics['dk']*self.statistics['kshell']**2) / (self.parameters['dkx']*self.parameters['dky']*self.parameters['dkz']*self.statistics['nshell'])
                    pp_file['energy(k)'] = np.mean(energy_tk, axis = 0)*norm_factor
                    phi_vorticity_ij = data_file['statistics/spectra/vorticity_vorticity'][ii0:ii1+1] / self.statistics['dk']
                    enstrophy_tk = (
                        phi_vorticity_ij[:, :, 0, 0] +
                        phi_vorticity_ij[:, :, 1, 1] +
                        phi_vorticity_ij[:, :, 2, 2])/2
                    pp_file['enstrophy(t)'] = np.sum(enstrophy_tk*self.statistics['dk'], axis = 1)
                    pp_file['enstrophy(k)'] = np.mean(enstrophy_tk, axis = 0)*norm_factor
                    pp_file['vel_max(t)'] = data_file['statistics/moments/velocity'][ii0:ii1+1, 9, 3]
                    pp_file['renergy(t)'] = data_file['statistics/moments/velocity'][ii0:ii1+1, 2, 3]/2
                    pp_file['renstrophy(t)'] = data_file['statistics/moments/vorticity'][ii0:ii1+1, 2, 3]/2
        for k in ['t',
                  'energy(t)',
                  'energy(k)',
                  'enstrophy(t)',
                  'enstrophy(k)',
                  'R_ij(t)',
                  'vel_max(t)',
                  'renergy(t)',
                  'renstrophy(t)']:
            if k in pp_file.keys():
                self.statistics[k] = pp_file[k][...]
        # `[...]` reads yield in-memory copies, so the file can be closed
        pp_file.close()
        # sanity check --- Parseval theorem check
        assert(np.max(np.abs(
                self.statistics['renergy(t)'] -
                self.statistics['energy(t)']) / self.statistics['energy(t)']) < 1e-5)
        assert(np.max(np.abs(
                self.statistics['renstrophy(t)'] -
                self.statistics['enstrophy(t)']) / self.statistics['enstrophy(t)']) < 1e-5)
        self.compute_time_averages()
        return None
    def compute_Reynolds_stress_invariants(
            self):
        """Compute invariants of the Reynolds stress anisotropy tensor.

        See Choi and Lumley, JFM v436 p59 (2001).
        Stores ``self.statistics['I2(t)']`` and
        ``self.statistics['I3(t)']``.

        The computation operates on a copy of ``R_ij(t)``; the original
        implementation normalized the stored array in place, which
        corrupted ``self.statistics['R_ij(t)']`` and made repeated
        calls return wrong values.
        """
        Rij = np.array(self.statistics['R_ij(t)'], copy = True)
        Rij /= (2*self.statistics['energy(t)'][:, None, None])
        # subtract the isotropic part to get the anisotropy tensor
        for i in range(3):
            Rij[:, i, i] -= 1./3
        self.statistics['I2(t)'] = np.sqrt(np.einsum('...ij,...ij', Rij, Rij, optimize = True) / 6)
        self.statistics['I3(t)'] = np.cbrt(np.einsum('...ij,...jk,...ki', Rij, Rij, Rij, optimize = True) / 6)
        return None
    def compute_time_averages(self):
        """Compute easy stats.

        Further computation of statistics based on the contents of
        ``simname_cache.h5``.
        Standard quantities are as follows
        (consistent with [Ishihara]_):

        .. math::

            U_{\\textrm{int}}(t) = \\sqrt{\\frac{2E(t)}{3}}, \\hskip .5cm
            L_{\\textrm{int}} = \\frac{\\pi}{2U_{int}^2} \\int \\frac{dk}{k} E(k), \\hskip .5cm
            T_{\\textrm{int}} =
            \\frac{L_{\\textrm{int}}}{U_{\\textrm{int}}}

            \\eta_K = \\left(\\frac{\\nu^3}{\\varepsilon}\\right)^{1/4}, \\hskip .5cm
            \\tau_K = \\left(\\frac{\\nu}{\\varepsilon}\\right)^{1/2}, \\hskip .5cm
            \\lambda = \\sqrt{\\frac{15 \\nu U_{\\textrm{int}}^2}{\\varepsilon}}

            Re = \\frac{U_{\\textrm{int}} L_{\\textrm{int}}}{\\nu}, \\hskip
            .5cm
            R_{\\lambda} = \\frac{U_{\\textrm{int}} \\lambda}{\\nu}

        .. [Ishihara] T. Ishihara et al,
                      *Small-scale statistics in high-resolution direct numerical
                      simulation of turbulence: Reynolds number dependence of
                      one-point velocity gradient statistics*.
                      J. Fluid Mech.,
                      **592**, 335-366, 2007
        """
        stats = self.statistics
        nu = self.parameters['nu']
        stats['Uint(t)'] = np.sqrt(2*stats['energy(t)'] / 3)
        # time-average the scalar signals that are present
        for key in ['energy', 'enstrophy', 'mean_trS2', 'Uint']:
            signal = key + '(t)'
            if signal in stats.keys():
                stats[key] = np.average(stats[signal], axis = 0)
        stats['vel_max'] = np.max(stats['vel_max(t)'])
        # derived quantities, both time-resolved ('(t)') and averaged ('')
        for suffix in ['', '(t)']:
            diss = nu * stats['enstrophy' + suffix] * 2
            stats['diss'    + suffix] = diss
            stats['etaK'    + suffix] = (nu**3 / diss)**.25
            stats['tauK'    + suffix] = (nu / diss)**.5
            stats['lambda'  + suffix] = (15 * nu *
                                         stats['Uint' + suffix]**2 / diss)**.5
            stats['Rlambda' + suffix] = (stats['Uint' + suffix] *
                                         stats['lambda' + suffix] / nu)
            kMeta = stats['kM'] * stats['etaK' + suffix]
            if self.parameters['dealias_type'] == 1:
                # smooth dealiasing: effective resolution is reduced
                kMeta *= 0.8
            stats['kMeta' + suffix] = kMeta
        # skip shell 0 and the (possibly nan-marked) last two shells
        stats['Lint'] = ((np.pi / (2*stats['Uint']**2)) *
                         np.sum(stats['energy(k)'][1:-2] /
                                stats['kshell'][1:-2]))
        stats['Re'] = stats['Uint'] * stats['Lint'] / nu
        stats['Tint'] = stats['Lint'] / stats['Uint']
        stats['Taylor_microscale'] = stats['lambda']
        return None
    def set_plt_style(
            self,
            style = None):
        """Merge `style` into the plotting style dictionary.

        :param style: dict of plotting settings merged into
            ``self.style``; defaults to ``{'dashes': (None, None)}``.

        The original signature used a mutable default argument; a
        ``None`` sentinel is used instead (passing ``None`` now selects
        the default rather than raising).
        """
        if style is None:
            style = {'dashes' : (None, None)}
        self.style.update(style)
        return None
    def convert_complex_from_binary(
            self,
            field_name = 'vorticity',
            iteration = 0,
            file_name = None):
        """read the Fourier representation of a vector field.

        Read the binary file containing iteration ``iteration`` of the
        field ``field_name``, and write it in a ``.h5`` file.

        :param field_name: base name of the field (e.g. 'vorticity')
        :param iteration: iteration index encoded in the file name
        :param file_name: destination ``.h5`` file; defaults to a name
            derived from `simname`, `field_name` and `iteration` inside
            `work_dir`.
        """
        data = np.memmap(
                os.path.join(self.work_dir,
                             self.simname + '_{0}_i{1:0>5x}'.format('c' + field_name, iteration)),
                dtype = self.ctype,
                mode = 'r',
                shape = (self.parameters['ny'],
                         self.parameters['nz'],
                         self.parameters['nx']//2+1,
                         3))
        if file_name is None:
            file_name = self.simname + '_{0}_i{1:0>5x}.h5'.format('c' + field_name, iteration)
            file_name = os.path.join(self.work_dir, file_name)
        # context manager guarantees the file is closed even on error
        # (the original closed it manually)
        with h5py.File(file_name, 'a') as f:
            f[field_name + '/complex/{0}'.format(iteration)] = data
        return None
    def write_par(
            self,
            iter0 = 0):
        """Write parameters and create statistics datasets in the run file.

        Validates that the various iteration strides divide each other,
        delegates the parameter dump to ``_code.write_par``, then stores
        the k-space description and creates the extendable spectra /
        moments / histogram datasets.

        :param iter0: initial iteration, forwarded to ``_code.write_par``.
        """
        assert (self.parameters['niter_todo'] % self.parameters['niter_stat'] == 0)
        assert (self.parameters['niter_todo'] % self.parameters['niter_out']  == 0)
        assert (self.parameters['niter_out']  % self.parameters['niter_stat'] == 0)
        if self.dns_type in ['NSVEparticles_no_output', 'NSVEcomplex_particles', 'NSVEparticles', 'static_field', 'kraichnan_field']:
            assert (self.parameters['niter_todo'] % self.parameters['niter_part'] == 0)
            assert (self.parameters['niter_out']  % self.parameters['niter_part'] == 0)
        _code.write_par(self, iter0 = iter0)
        with h5py.File(self.get_data_file_name(), 'r+') as ofile:
            ofile['code_info/exec_name'] = self.name
            # the original computed kspace/nshells twice; once is enough
            kspace = self.get_kspace()
            for k in kspace.keys():
                ofile['kspace/' + k] = kspace[k]
            nshells = kspace['nshell'].shape[0]
            vec_stat_datasets = ['velocity', 'vorticity']
            for k in vec_stat_datasets:
                # chunk sizes target roughly 1 MiB per chunk
                time_chunk = max(2**20//(8*3*3*nshells), 1)
                ofile.create_dataset('statistics/spectra/' + k + '_' + k,
                                     (1, nshells, 3, 3),
                                     chunks = (time_chunk, nshells, 3, 3),
                                     maxshape = (None, nshells, 3, 3),
                                     dtype = np.float64)
                time_chunk = max(2**20//(8*4*10), 1)
                ofile.create_dataset('statistics/moments/' + k,
                                     (1, 10, 4),
                                     chunks = (time_chunk, 10, 4),
                                     maxshape = (None, 10, 4),
                                     dtype = np.float64)
                time_chunk = max(2**20//(8*4*self.parameters['histogram_bins']), 1)
                ofile.create_dataset('statistics/histograms/' + k,
                                     (1,
                                      self.parameters['histogram_bins'],
                                      4),
                                     chunks = (time_chunk,
                                               self.parameters['histogram_bins'],
                                               4),
                                     maxshape = (None,
                                                 self.parameters['histogram_bins'],
                                                 4),
                                     dtype = np.int64)
            ofile['checkpoint'] = int(0)
        return None
    def job_parser_arguments(
            self,
            parser):
        """Register job-control command line options on `parser`."""
        argument_table = [
                (['--ncpu'],
                 dict(type = int, dest = 'ncpu', default = -1)),
                (['--np', '--nprocesses'],
                 dict(metavar = 'NPROCESSES',
                      help = 'number of mpi processes to use',
                      type = int,
                      dest = 'nb_processes',
                      default = 4)),
                (['--ntpp', '--nthreads-per-process'],
                 dict(type = int,
                      dest = 'nb_threads_per_process',
                      metavar = 'NTHREADS_PER_PROCESS',
                      help = 'number of threads to use per MPI process',
                      default = 1)),
                (['--no-debug'],
                 dict(action = 'store_true', dest = 'no_debug')),
                (['--no-submit'],
                 dict(action = 'store_true', dest = 'no_submit')),
                (['--environment'],
                 dict(type = str, dest = 'environment', default = None)),
                (['--minutes'],
                 dict(type = int,
                      dest = 'minutes',
                      default = 5,
                      help = 'If environment supports it, this is the requested wall-clock-limit.')),
                (['--njobs'],
                 dict(type = int, dest = 'njobs', default = 1)),
                ]
        for flags, options in argument_table:
            parser.add_argument(*flags, **options)
        return None
    def simulation_parser_arguments(
            self,
            parser):
        """Register generic simulation command line options on `parser`."""
        parser.add_argument(
                '--simname',
                type = str, dest = 'simname', default = 'test')
        parser.add_argument(
                '-n', '--grid-size',
                type = int, dest = 'n', default = 32, metavar = 'N',
                help = 'code is run by default in a grid of NxNxN')
        # one box-length option per coordinate direction
        for axis in ['x', 'y', 'z']:
            parser.add_argument(
                    '--L{0}'.format(axis),
                    '--box-length-{0}'.format(axis),
                    type = float,
                    dest = 'L{0}'.format(axis),
                    default = 2.0,
                    metavar = 'length{0}'.format(axis),
                    help = 'length of the box in the {0} direction will be `length{0} x pi`'.format(axis))
        remaining_arguments = [
                (['--wd'],
                 dict(type = str, dest = 'work_dir', default = './')),
                (['--precision'],
                 dict(choices = ['single', 'double'],
                      type = str,
                      default = 'single')),
                (['--src-wd'],
                 dict(type = str, dest = 'src_work_dir', default = '')),
                (['--src-simname'],
                 dict(type = str, dest = 'src_simname', default = '')),
                (['--src-iteration'],
                 dict(type = int, dest = 'src_iteration', default = 0)),
                (['--kMeta'],
                 dict(type = float, dest = 'kMeta', default = 2.0)),
                (['--dtfactor'],
                 dict(type = float,
                      dest = 'dtfactor',
                      default = 0.5,
                      help = 'dt is computed as DTFACTOR / N')),
                ]
        for flags, options in remaining_arguments:
            parser.add_argument(*flags, **options)
        return None
    def particle_parser_arguments(
            self,
            parser):
        """Register particle-related command line options on `parser`."""
        argument_table = [
                (['--particle-rand-seed'],
                 dict(type = int,
                      dest = 'particle_rand_seed',
                      default = None)),
                (['--pclouds'],
                 dict(type = int,
                      dest = 'pclouds',
                      default = 1,
                      help = ('number of particle clouds. Particle "clouds" '
                              'consist of particles distributed according to '
                              'pcloud-type.'))),
                (['--pcloud-type'],
                 dict(choices = ['random-cube', 'regular-cube'],
                      dest = 'pcloud_type',
                      default = 'random-cube')),
                (['--particle-cloud-size'],
                 dict(type = float,
                      dest = 'particle_cloud_size',
                      default = 2*np.pi)),
                ]
        for flags, options in argument_table:
            parser.add_argument(*flags, **options)
        return None
    def add_parser_arguments(
            self,
            parser):
        """Attach one subcommand per DNS type to `parser`.

        Plain fluid solvers get simulation/job/parameter options;
        particle-carrying solvers additionally get the particle options
        and the ``NSVEp_extra_parameters`` defaults.

        The original implementation dispatched over local variable names
        with ``eval``; subparser objects are kept in a dict and called
        directly instead.
        """
        subparsers = parser.add_subparsers(
                dest = 'DNS_class',
                help = 'type of simulation to run')
        subparsers.required = True
        parser_NSVE = subparsers.add_parser(
                'NSVE',
                help = 'plain Navier-Stokes vorticity formulation')
        self.simulation_parser_arguments(parser_NSVE)
        self.job_parser_arguments(parser_NSVE)
        self.parameters_to_parser_arguments(parser_NSVE)

        parser_NSVE_no_output = subparsers.add_parser(
                'NSVE_no_output',
                help = 'plain Navier-Stokes vorticity formulation, checkpoints are NOT SAVED')
        self.simulation_parser_arguments(parser_NSVE_no_output)
        self.job_parser_arguments(parser_NSVE_no_output)
        self.parameters_to_parser_arguments(parser_NSVE_no_output)

        # subcommands that carry particles; registration order matters
        # for the --help listing, so it matches the original
        particle_subparsers = {}
        particle_subparsers['NSVEparticles_no_output'] = subparsers.add_parser(
                'NSVEparticles_no_output',
                help = 'plain Navier-Stokes vorticity formulation, with basic fluid tracers, checkpoints are NOT SAVED')
        particle_subparsers['static_field'] = subparsers.add_parser(
                'static_field',
                help = 'static field with basic fluid tracers')
        particle_subparsers['kraichnan_field'] = subparsers.add_parser(
                'kraichnan_field',
                help = 'Kraichnan field with basic fluid tracers')
        particle_subparsers['NSVEparticles'] = subparsers.add_parser(
                'NSVEparticles',
                help = 'plain Navier-Stokes vorticity formulation, with basic fluid tracers')
        particle_subparsers['NSVEcomplex_particles'] = subparsers.add_parser(
                'NSVEcomplex_particles',
                help = 'plain Navier-Stokes vorticity formulation, with oriented active particles')
        particle_subparsers['NSVEp_extra_sampling'] = subparsers.add_parser(
                'NSVEp_extra_sampling',
                help = 'plain Navier-Stokes vorticity formulation, with basic fluid tracers, that sample velocity gradient, as well as pressure and its derivatives.')
        for sub in particle_subparsers.values():
            self.simulation_parser_arguments(sub)
            self.job_parser_arguments(sub)
            self.particle_parser_arguments(sub)
            self.parameters_to_parser_arguments(sub)
            self.parameters_to_parser_arguments(sub, self.NSVEp_extra_parameters)
        return None
    def prepare_launch(
            self,
            args = [],
            extra_parameters = None):
        """Set up reasonable parameters.

        With the default Lundgren forcing applied in the band [2, 4],
        we can estimate the dissipation, therefore we can estimate
        :math:`k_M \\eta_K` and constrain the viscosity.

        In brief, the command line parameter :math:`k_M \\eta_K` is
        used in the following formula for :math:`\\nu` (:math:`N` is the
        number of real space grid points per coordinate):

        .. math::

            \\nu = \\left(\\frac{2 k_M \\eta_K}{N} \\right)^{4/3}

        With this choice, the average dissipation :math:`\\varepsilon`
        will be close to 0.4, and the integral scale velocity will be
        close to 0.77, yielding the approximate value for the Taylor
        microscale and corresponding Reynolds number:

        .. math::

            \\lambda \\approx 4.75\\left(\\frac{2 k_M \\eta_K}{N} \\right)^{4/6}, \\hskip .5in
            R_\\lambda \\approx 3.7 \\left(\\frac{N}{2 k_M \\eta_K} \\right)^{4/6}

        :param args: command line arguments, forwarded to the parser
        :param extra_parameters: optional dictionary, indexed by DNS
            class name, of parameter values that override the defaults
        :returns: the parsed option namespace
        """
        opt = _code.prepare_launch(self, args = args)
        self.set_precision(opt.precision)
        self.dns_type = opt.DNS_class
        self.name = self.dns_type + '-' + self.fluid_precision + '-v' + TurTLE.__version__
        # merge in the particle defaults for all particle-carrying solvers
        if self.dns_type in ['NSVEparticles', 'NSVEcomplex_particles', 'NSVEparticles_no_output', 'NSVEp_extra_sampling', 'static_field', 'kraichnan_field']:
            self.parameters.update(self.NSVEp_extra_parameters)
        # merge in caller-supplied overrides for this solver, if any
        if extra_parameters is not None:
            if self.dns_type in extra_parameters:
                self.parameters.update(extra_parameters[self.dns_type])
        # make sure the last iteration is also an output iteration
        if ((self.parameters['niter_todo'] % self.parameters['niter_out']) != 0):
            self.parameters['niter_out'] = self.parameters['niter_todo']
        if len(opt.src_work_dir) == 0:
            opt.src_work_dir = os.path.realpath(opt.work_dir)
        # fill in defaults for command line options the user left unset
        if opt.dkx is None:
            opt.dkx = 2. / opt.Lx
        if opt.dky is None:
            opt.dky = 2. / opt.Ly
        if opt.dkz is None:
            opt.dkz = 2. / opt.Lz
        if opt.nx is None:
            opt.nx = opt.n
        if opt.ny is None:
            opt.ny = opt.n
        if opt.nz is None:
            opt.nz = opt.n
        if opt.fk0 is None:
            opt.fk0 = self.parameters['fk0']
        if opt.fk1 is None:
            opt.fk1 = self.parameters['fk1']
        if opt.injection_rate is None:
            opt.injection_rate = self.parameters['injection_rate']
        if opt.dealias_type is None:
            opt.dealias_type = self.parameters['dealias_type']
        if (opt.nx > opt.n or
            opt.ny > opt.n or
            opt.nz > opt.n):
            opt.n = min(opt.nx, opt.ny, opt.nz)
            print("Warning: '-n' parameter changed to minimum of nx, ny, nz. This affects the computation of nu.")
        self.parameters['dt'] = (opt.dtfactor / opt.n)
        # viscosity chosen from the docstring formula above
        self.parameters['nu'] = (opt.kMeta * 2 / opt.n)**(4./3)
        # check value of kMax
        kM = opt.n * 0.5
        if opt.dealias_type == 1:
            kM *= 0.8
        # tweak forcing/viscosity based on forcing type
        if opt.forcing_type == 'linear':
            # custom famplitude for 288 and 576
            if opt.n == 288:
                self.parameters['famplitude'] = 0.45
            elif opt.n == 576:
                self.parameters['famplitude'] = 0.47
        elif opt.forcing_type == 'fixed_energy_injection_rate':
            # use the fact that mean dissipation rate is equal to injection rate
            self.parameters['nu'] = (
                    opt.injection_rate *
                    (opt.kMeta / kM)**4)**(1./3)
        elif opt.forcing_type == 'fixed_energy':
            kf = 1. / (1./opt.fk0 +
                       1./opt.fk1)
            self.parameters['nu'] = (
                    (opt.kMeta / kM)**(4./3) *
                    (np.pi / kf)**(1./3) *
                    (2*self.parameters['energy'] / 3)**0.5)
        if opt.checkpoints_per_file is None:
            # hardcoded FFTW complex representation size
            field_size = 3*(opt.nx+2)*opt.ny*opt.nz*self.fluid_dtype.itemsize
            checkpoint_size = field_size
            if self.dns_type in ['kraichnan_field', 'static_field', 'NSVEparticles', 'NSVEcomplex_particles', 'NSVEparticles_no_output', 'NSVEp_extra_sampling']:
                rhs_size = self.parameters['tracers0_integration_steps']
                if opt.tracers0_integration_steps is not None:
                    rhs_size = opt.tracers0_integration_steps
                nparticles = opt.nparticles
                if nparticles is None:
                    nparticles = self.NSVEp_extra_parameters['nparticles']
                # state + rhs history, 3 components, double precision
                particle_size = (1+rhs_size)*3*nparticles*8
                checkpoint_size += particle_size
            # aim for roughly 1 GB per checkpoint file
            if checkpoint_size < 1e9:
                opt.checkpoints_per_file = int(1e9 / checkpoint_size)
        self.pars_from_namespace(opt)
        return opt
    def launch(
            self,
            args = [],
            **kwargs):
        """Parse `args` into simulation parameters, then run the jobs."""
        options = self.prepare_launch(args = args)
        self.launch_jobs(opt = options, **kwargs)
        return None
    def get_checkpoint_0_fname(self):
        return os.path.join(
                    self.work_dir,
                    self.simname + '_checkpoint_0.h5')
    def get_checkpoint_fname(self, iteration = 0):
        checkpoint = (iteration // self.parameters['niter_out']) // self.parameters['checkpoints_per_file']
        return os.path.join(
                    self.work_dir,
                    self.simname + '_checkpoint_{0}.h5'.format(checkpoint))
    def generate_tracer_state(
            self,
            rseed = None,
            species = 0,
            integration_steps = None,
            ncomponents = 3):
        """Create initial particle state datasets in checkpoint file 0.

        Positions (components 0:3) are drawn uniformly in
        :math:`[0, 2\\pi)^3`; for `NSVEcomplex_particles` species 0 the
        state additionally carries a random unit orientation vector
        (components 3:6).  The `rhs` dataset is left zero-filled.

        :param rseed: optional seed for numpy's random number generator
        :param species: tracer family index (`tracers{species}` group)
        :param integration_steps: number of stored rhs time levels;
            defaults to the value found in `self.parameters` (or the
            class-level particle defaults)
        :param ncomponents: number of state components per particle
        """
        try:
            if integration_steps is None:
                integration_steps = self.NSVEp_extra_parameters['tracers0_integration_steps']
            # a per-species parameter, when present, takes precedence
            if 'tracers{0}_integration_steps'.format(species) in self.parameters.keys():
                integration_steps = self.parameters['tracers{0}_integration_steps'.format(species)]
            # oriented particles store 3 position + 3 orientation components
            if self.dns_type == 'NSVEcomplex_particles' and species == 0:
                ncomponents = 6
            with h5py.File(self.get_checkpoint_0_fname(), 'a') as data_file:
                nn = self.parameters['nparticles']
                if 'tracers{0}'.format(species) not in data_file.keys():
                    data_file.create_group('tracers{0}'.format(species))
                    data_file.create_group('tracers{0}/rhs'.format(species))
                    data_file.create_group('tracers{0}/state'.format(species))
                # `np.float` was removed in NumPy 1.24; `np.float64` is the
                # exact meaning of the old alias.
                data_file['tracers{0}/rhs'.format(species)].create_dataset(
                        '0',
                        shape = (integration_steps, nn, ncomponents,),
                        dtype = np.float64)
                dset = data_file['tracers{0}/state'.format(species)].create_dataset(
                        '0',
                        shape = (nn, ncomponents,),
                        dtype = np.float64)
                if rseed is not None:
                    np.random.seed(rseed)
                cc = int(0)
                batch_size = int(1e6)
                def get_random_phases(npoints):
                    # uniform positions in [0, 2 pi)^3
                    return np.random.random(
                                (npoints, 3))*2*np.pi
                def get_random_versors(npoints):
                    # isotropically distributed unit vectors
                    bla = np.random.normal(
                            size = (npoints, 3))
                    bla  /= np.sum(bla**2, axis = 1)[:, None]**.5
                    return bla
                # fill the state dataset in batches to bound memory use
                while nn > 0:
                    if nn > batch_size:
                        dset[cc*batch_size:(cc+1)*batch_size, :3] = get_random_phases(batch_size)
                        if dset.shape[1] == 6:
                            dset[cc*batch_size:(cc+1)*batch_size, 3:] = get_random_versors(batch_size)
                        nn -= batch_size
                    else:
                        dset[cc*batch_size:cc*batch_size+nn, :3] = get_random_phases(nn)
                        if dset.shape[1] == 6:
                            dset[cc*batch_size:cc*batch_size+nn, 3:] = get_random_versors(nn)
                        nn = 0
                    cc += 1
        except Exception as e:
            # best effort: report and continue, matching original behavior
            print(e)
        return None
    def generate_vector_field(
            self,
            rseed = 7547,
            spectra_slope = 1.,
            amplitude = 1.,
            iteration = 0,
            field_name = 'vorticity',
            write_to_file = False,
            # to switch to constant field, use generate_data_3D_uniform
            # for scalar_generator
            scalar_generator = tools.generate_data_3D):
        """generate vector field.

        The generated field is not divergence free, but it has the proper
        shape.

        :param rseed: seed for random number generator
        :param spectra_slope: spectrum of field will look like k^(-p)
        :param amplitude: all amplitudes are multiplied with this value
        :param iteration: the field is written at this iteration
        :param field_name: the name of the field being generated
        :param write_to_file: should we write the field to file?
        :param scalar_generator: which function to use for generating the
            individual components.
            Possible values: TurTLE.tools.generate_data_3D,
            TurTLE.tools.generate_data_3D_uniform
        :type rseed: int
        :type spectra_slope: float
        :type amplitude: float
        :type iteration: int
        :type field_name: str
        :type write_to_file: bool
        :type scalar_generator: function

        :returns: ``Kdata``, a complex valued 4D ``numpy.array`` that uses the
            transposed FFTW layout.
            Kdata[ky, kz, kx, i] is the amplitude of mode (kx, ky, kz) for
            the i-th component of the field.
            (i.e. x is the fastest index and z the slowest index in the
            real-space representation).
        """
        np.random.seed(rseed)
        # draw the three components one after the other, so the random
        # number stream is consumed exactly as by three explicit calls
        component_list = [
                scalar_generator(
                        self.parameters['nz'],
                        self.parameters['ny'],
                        self.parameters['nx'],
                        p = spectra_slope,
                        amplitude = amplitude).astype(self.ctype)
                for component in range(3)]
        # pack the components into the last axis (dtype is preserved)
        Kdata0 = np.stack(component_list, axis = -1)
        Kdata1 = tools.padd_with_zeros(
                Kdata0,
                self.parameters['nz'],
                self.parameters['ny'],
                self.parameters['nx'])
        if write_to_file:
            Kdata1.tofile(
                    os.path.join(self.work_dir,
                                 self.simname + "_c{0}_i{1:0>5x}".format(field_name, iteration)))
        return Kdata1
    def copy_complex_field(
            self,
            src_file_name,
            src_dset_name,
            dst_file,
            dst_dset_name,
            make_link = True):
        """Copy (or link to) a complex Fourier-space field.

        If the source dataset already has the destination shape, an HDF5
        external link is created instead of copying any data.  Otherwise
        a destination dataset is created, zero-filled, and the Fourier
        modes present at both resolutions are copied explicitly,
        assuming the transposed FFTW layout (ky, kz, kx, component).

        .. note::
            only trust this method for the case of
            increasing/decreasing by the same factor in all directions;
            a fully generic version would be more complicated.
        """
        # NOTE(review): `make_link` is currently unused; a link is made
        # whenever the shapes match exactly.
        dst_shape = (self.parameters['ny'],
                     self.parameters['nz'],
                     (self.parameters['nx']+2) // 2,
                     3)
        # close the source file when done (the original leaked the handle)
        with h5py.File(src_file_name, 'r') as src_file:
            if (src_file[src_dset_name].shape == dst_shape):
                # same resolution: an external link avoids duplicating data
                dst_file[dst_dset_name] = h5py.ExternalLink(
                        src_file_name,
                        src_dset_name)
                return None
            src_shape = src_file[src_dset_name].shape
            # extent of the mode region present at both resolutions
            min_shape = (min(dst_shape[0], src_shape[0]),
                         min(dst_shape[1], src_shape[1]),
                         min(dst_shape[2], src_shape[2]),
                         3)
            dst_file.create_dataset(
                    dst_dset_name,
                    shape = dst_shape,
                    dtype = np.dtype(self.ctype),
                    fillvalue = complex(0))
            for kz in range(min_shape[0]//2):
                # positive frequencies along axis 1
                dst_file[dst_dset_name][kz, :min_shape[1]//2, :min_shape[2]] = \
                        src_file[src_dset_name][kz, :min_shape[1]//2, :min_shape[2]]
                # negative frequencies along axis 1.
                # bugfix: the source selection was missing the trailing
                # `:`, so one single row was broadcast over the whole
                # destination block instead of copying the matching rows.
                dst_file[dst_dset_name][kz,
                                        dst_shape[1] - min_shape[1]//2+1:,
                                        :min_shape[2]] = \
                        src_file[src_dset_name][kz,
                                                src_shape[1] - min_shape[1]//2+1:,
                                                :min_shape[2]]
                if kz > 0:
                    # mirror modes at negative kz
                    dst_file[dst_dset_name][-kz, :min_shape[1]//2, :min_shape[2]] = \
                            src_file[src_dset_name][-kz, :min_shape[1]//2, :min_shape[2]]
                    dst_file[dst_dset_name][-kz,
                                            dst_shape[1] - min_shape[1]//2+1:,
                                            :min_shape[2]] = \
                            src_file[src_dset_name][-kz,
                                                    src_shape[1] - min_shape[1]//2+1:,
                                                    :min_shape[2]]
        return None
    def generate_particle_data(
            self,
            opt = None):
        """Initialize tracer state and create the particle output file.

        Does nothing when `nparticles` is not positive; leaves an
        existing particle file untouched.
        """
        if self.parameters['nparticles'] <= 0:
            return None
        self.generate_tracer_state(
                species = 0,
                rseed = opt.particle_rand_seed)
        if os.path.exists(self.get_particle_file_name()):
            return None
        # output groups common to all particle solvers
        group_names = ['tracers0/position',
                       'tracers0/velocity',
                       'tracers0/acceleration']
        if self.dns_type in ['NSVEcomplex_particles']:
            group_names += ['tracers0/orientation',
                            'tracers0/velocity_gradient']
        if self.dns_type in ['NSVEp_extra_sampling']:
            group_names += ['tracers0/velocity_gradient',
                            'tracers0/pressure',
                            'tracers0/pressure_gradient',
                            'tracers0/pressure_Hessian']
        with h5py.File(self.get_particle_file_name(), 'w') as particle_file:
            for group_name in group_names:
                particle_file.create_group(group_name)
        return None
    def generate_initial_condition(
            self,
            opt = None):
        """Ensure checkpoint 0 contains a vorticity field (and particles).

        If checkpoint 0 already holds a correctly shaped
        `vorticity/complex/0` dataset, it is kept.  Otherwise the field
        is either copied/linked from a source simulation
        (`opt.src_simname`) or generated randomly.
        """
        # take care of fields' initial condition
        # first, check if initial field exists
        need_field = False
        if not os.path.exists(self.get_checkpoint_0_fname()):
            need_field = True
        else:
            with h5py.File(self.get_checkpoint_0_fname(), 'r') as f:
                try:
                    dset = f['vorticity/complex/0']
                    # regenerate only when the stored field has the WRONG
                    # shape.  bugfix: the original used `==` here, which
                    # forced regeneration precisely when the field was
                    # usable, and kept wrongly-shaped fields.
                    need_field = (dset.shape != (self.parameters['ny'],
                                                 self.parameters['nz'],
                                                 self.parameters['nx']//2+1,
                                                 3))
                except KeyError:
                    # dataset missing from an existing checkpoint file
                    need_field = True
        if need_field:
            f = h5py.File(self.get_checkpoint_0_fname(), 'a')
            if len(opt.src_simname) > 0:
                # scan source checkpoint files until we find the one that
                # stores the requested iteration (a missing file raises,
                # terminating the scan)
                source_cp = 0
                src_file = 'not_a_file'
                while True:
                    src_file = os.path.join(
                        os.path.realpath(opt.src_work_dir),
                        opt.src_simname + '_checkpoint_{0}.h5'.format(source_cp))
                    with h5py.File(src_file, 'r') as f0:
                        found = ('{0}'.format(opt.src_iteration)
                                 in f0['vorticity/complex'].keys())
                    if found:
                        break
                    source_cp += 1
                self.copy_complex_field(
                        src_file,
                        'vorticity/complex/{0}'.format(opt.src_iteration),
                        f,
                        'vorticity/complex/{0}'.format(0))
            else:
                data = self.generate_vector_field(
                       write_to_file = False,
                       spectra_slope = 2.0,
                       amplitude = 0.05)
                f['vorticity/complex/{0}'.format(0)] = data
            f.close()
        # now take care of particles' initial condition
        if self.dns_type in ['kraichnan_field', 'static_field', 'NSVEparticles', 'NSVEcomplex_particles', 'NSVEparticles_no_output', 'NSVEp_extra_sampling']:
            self.generate_particle_data(opt = opt)
        return None
    def launch_jobs(
            self,
            opt = None):
        """Create initial condition/parameter file if needed, then run."""
        if not os.path.exists(self.get_data_file_name()):
            # fresh simulation: build the initial condition and write
            # the parameter file before submitting anything
            self.generate_initial_condition(opt = opt)
            self.write_par()
        # the command line supplies a total time budget in minutes
        hours, minutes = divmod(opt.minutes, 60)
        self.run(
                nb_processes = opt.nb_processes,
                nb_threads_per_process = opt.nb_threads_per_process,
                njobs = opt.njobs,
                hours = hours,
                minutes = minutes,
                no_submit = opt.no_submit,
                no_debug = opt.no_debug)
        return None