Skip to content
Snippets Groups Projects
kspace.hpp 12.81 KiB
/**********************************************************************
*                                                                     *
*  Copyright 2015 Max Planck Institute                                *
*                 for Dynamics and Self-Organization                  *
*                                                                     *
*  This file is part of TurTLE.                                       *
*                                                                     *
*  TurTLE is free software: you can redistribute it and/or modify     *
*  it under the terms of the GNU General Public License as published  *
*  by the Free Software Foundation, either version 3 of the License,  *
*  or (at your option) any later version.                             *
*                                                                     *
*  TurTLE is distributed in the hope that it will be useful,          *
*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
*  GNU General Public License for more details.                       *
*                                                                     *
*  You should have received a copy of the GNU General Public License  *
*  along with TurTLE.  If not, see <http://www.gnu.org/licenses/>     *
*                                                                     *
* Contact: Cristian.Lalescu@ds.mpg.de                                 *
*                                                                     *
**********************************************************************/



#include <hdf5.h>
#include <vector>
#include <string>
#include "omputils.hpp"
#include "fftw_interface.hpp"
#include "field_layout.hpp"

#ifndef KSPACE_HPP

#define KSPACE_HPP

/** \brief Tag selecting the transform backend; FFTW is the only value declared here. */
enum field_backend
{
    FFTW
};

/** \brief Tag selecting the dealiasing variant used as template argument of kspace.
 *
 *  NOTE(review): the precise cutoff associated with each value is defined by
 *  the kspace implementation, which is not part of this header.
 */
enum kspace_dealias_type
{
    ONE_HALF,
    TWO_THIRDS,
    SMOOTH
};

/** \brief A class for handling Fourier representation tasks.
 *
 *      - contains wavenumber information (specific to each MPI process).
 *      This includes values of kx, ky, kz, including lowest modes dkx etc,
 *      as well as number of modes within spherical shells and mean wavenumber
 *      within shells.
 *      - has methods for spectrum computation and similar.
 *      - has methods for filtering.
 *      - has CLOOP methods, useful for computing arbitrary formulas over the
 *      Fourier space grid (i.e. use lambda expressions).
 */

template <field_backend be,
          kspace_dealias_type dt>
class kspace
{
    public:
        /* layout of the locally held part of the Fourier grid; all CLOOP
         * variants take their loop bounds from layout->subsizes */
        field_layout<ONE> *layout;

        /* lowest modes along each axis, plus derived values
         * (NOTE(review): dk/dk2 are set in the constructor, not visible here) */
        double dkx, dky, dkz, dk, dk2;

        /* maximum-wavenumber information and per-axis wavenumber tables;
         * kx, ky, kz are indexed with the local x, y, z loop indices */
        double kMx, kMy, kMz, kM, kM2;
        std::vector<double> kx, ky, kz;
        /* per-shell data: presumably mean wavenumber (kshell) and number of
         * modes (nshell) for each of the nshells shells -- confirm in kspace.cpp */
        std::vector<double> kshell;
        std::vector<int64_t> nshell;
        int nshells;

        /* methods */

        /** \brief Build the wavenumber information from an existing layout.
         *
         *  \param source_layout layout of the field this object is associated with.
         *  \param DKX, DKY, DKZ values for the lowest modes (default 1.0 each).
         */
        template <field_components fc>
        kspace(
                const field_layout<fc> *source_layout,
                const double DKX = 1.0,
                const double DKY = 1.0,
                const double DKZ = 1.0);
        ~kspace();

        /** \brief Store information in the HDF5 file `stat_file`.
         *
         *  NOTE(review): what is written and the meaning of the returned int
         *  are fixed by the out-of-line definition.
         */
        int store(hid_t stat_file);

        /** \brief Low-pass filter applied in place to `a`.
         *
         *  \param a    complex field data with `fc` components.
         *  \param kmax presumably the cutoff wavenumber -- confirm in definition.
         */
        template <typename rnumber, field_components fc>
        void low_pass(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double kmax);

        /** \brief Gauss filter applied in place to `a`; `sigma` sets the filter scale. */
        template <typename rnumber, field_components fc>
        void Gauss_filter(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double sigma);

        /** \brief Ball filter applied in place to `a`; `sigma` sets the filter scale. */
        template <typename rnumber, field_components fc>
        void ball_filter(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double sigma);

        /** \brief "general M" filter applied in place to `a`; `sigma` sets the filter scale.
         *
         *  NOTE(review): the kernel itself is only visible in the definition.
         */
        template <typename rnumber, field_components fc>
        void general_M_filter(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double sigma);

        /** \brief Apply the filter named by `filter_type` (default "Gauss") to `a`.
         *
         *  \param wavenumber  filter parameter handed to the selected filter.
         *  \param filter_type name of the filter to use.
         *  \return an int status -- presumably 0 on success, confirm in definition.
         */
        template <typename rnumber, field_components fc>
        int filter(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double wavenumber,
                std::string filter_type = std::string("Gauss"));

        /** \brief Variant of filter() with a "calibrated ell" convention.
         *
         *  Same parameters as filter(); NOTE(review): the calibration applied
         *  to `wavenumber` is defined out of line.
         */
        template <typename rnumber, field_components fc>
        int filter_calibrated_ell(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const double wavenumber,
                std::string filter_type = std::string("Gauss"));

        /** \brief Dealias `a` in place (scheme presumably selected by `dt`). */
        template <typename rnumber, field_components fc>
        void dealias(typename fftw_interface<rnumber>::complex *__restrict__ a);

        /** \brief Cospectrum of two fields `a` and `b`.
         *
         *  \param group, dset_name, toffset HDF5 destination -- presumably the
         *         result is written to dataset `dset_name` of `group` at index
         *         `toffset`; confirm in definition.
         *  \param wavenumber_exp exponent of an optional wavenumber weight (default 0).
         */
        template <typename rnumber, field_components fc>
        void cospectrum(
                const rnumber(* __restrict__ a)[2],
                const rnumber(* __restrict__ b)[2],
                const hid_t group,
                const std::string dset_name,
                const hsize_t toffset,
                const double wavenumber_exp = 0);

        /** \brief Single-field overload of cospectrum() with HDF5 destination. */
        template <typename rnumber, field_components fc>
        void cospectrum(
                const rnumber(* __restrict__ a)[2],
                const hid_t group,
                const std::string dset_name,
                const hsize_t toffset,
                const double wavenumber_exp = 0);

        /** \brief Single-field overload of cospectrum() using the caller-provided vector `spec`. */
        template <typename rnumber, field_components fc>
        void cospectrum(
                const rnumber(* __restrict__ a)[2],
                std::vector<double> &spec,
                const double wavenumber_exp = 0);

        /** \brief L2 norm of the field `a`. */
        template <typename rnumber, field_components fc>
        double L2norm(
                const rnumber(* __restrict__ a)[2]);

        /** \brief Call `expression` for every point of the local Fourier grid.
         *
         *  The loop nest runs over layout->subsizes[0] (index named y),
         *  the per-thread interval of layout->subsizes[1] (index named z)
         *  given by OmpUtils::ForIntervalStart/End, and layout->subsizes[2]
         *  (index named x). `expression` is called as
         *  `expression(cindex, xindex, yindex, zindex)` where `cindex` is the
         *  linear index of the point in the local array.
         */
        template <class func_type>
        void CLOOP(func_type expression)
        {
            start_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
            #pragma omp parallel
            {
                field_layout<ONE> *const grid = this->layout;
                // each thread handles its own slice of the second dimension
                const hsize_t zfirst = OmpUtils::ForIntervalStart(grid->subsizes[1]);
                const hsize_t zlast = OmpUtils::ForIntervalEnd(grid->subsizes[1]);

                for (hsize_t iy = 0; iy < grid->subsizes[0]; ++iy)
                {
                    for (hsize_t iz = zfirst; iz < zlast; ++iz)
                    {
                        // linear index of the first x entry of this (y, z) line
                        const ptrdiff_t line_start =
                            iy*grid->subsizes[1]*grid->subsizes[2] + iz*grid->subsizes[2];
                        for (hsize_t ix = 0; ix < grid->subsizes[2]; ++ix)
                            expression(line_start + ix, ix, iy, iz);
                    }
                }
            }
            finish_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
        }

        /** \brief Same traversal as CLOOP, with `omp simd` requested on the z loop. */
        template <class func_type>
        void CLOOP_simd(func_type expression)
        {
            start_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
            #pragma omp parallel
            {
                field_layout<ONE> *const grid = this->layout;
                const hsize_t zfirst = OmpUtils::ForIntervalStart(grid->subsizes[1]);
                const hsize_t zlast = OmpUtils::ForIntervalEnd(grid->subsizes[1]);

                for (hsize_t iy = 0; iy < grid->subsizes[0]; ++iy)
                {
                    #pragma omp simd
                    for (hsize_t iz = zfirst; iz < zlast; ++iz)
                    {
                        const ptrdiff_t line_start =
                            iy*grid->subsizes[1]*grid->subsizes[2] + iz*grid->subsizes[2];
                        for (hsize_t ix = 0; ix < grid->subsizes[2]; ++ix)
                            expression(line_start + ix, ix, iy, iz);
                    }
                }
            }
            finish_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
        }

        /** \brief Like CLOOP, additionally passing the squared wavenumber.
         *
         *  `expression` is called as
         *  `expression(cindex, xindex, yindex, zindex, k2)` with
         *  `k2 = kx[xindex]^2 + ky[yindex]^2 + kz[zindex]^2`.
         */
        template <class func_type>
        void CLOOP_K2(func_type expression)
        {
            start_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
            #pragma omp parallel
            {
                field_layout<ONE> *const grid = this->layout;
                const hsize_t zfirst = OmpUtils::ForIntervalStart(grid->subsizes[1]);
                const hsize_t zlast = OmpUtils::ForIntervalEnd(grid->subsizes[1]);

                for (hsize_t iy = 0; iy < grid->subsizes[0]; ++iy)
                {
                    for (hsize_t iz = zfirst; iz < zlast; ++iz)
                    {
                        const ptrdiff_t line_start =
                            iy*grid->subsizes[1]*grid->subsizes[2] + iz*grid->subsizes[2];
                        for (hsize_t ix = 0; ix < grid->subsizes[2]; ++ix)
                        {
                            expression(
                                    line_start + ix,
                                    ix,
                                    iy,
                                    iz,
                                    (this->kx[ix]*this->kx[ix] +
                                     this->ky[iy]*this->ky[iy] +
                                     this->kz[iz]*this->kz[iz]));
                        }
                    }
                }
            }
            finish_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
        }

        /** \brief Same traversal as CLOOP_K2, with `omp simd` requested on the z loop. */
        template <class func_type>
        void CLOOP_K2_simd(func_type expression)
        {
            start_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
            #pragma omp parallel
            {
                field_layout<ONE> *const grid = this->layout;
                const hsize_t zfirst = OmpUtils::ForIntervalStart(grid->subsizes[1]);
                const hsize_t zlast = OmpUtils::ForIntervalEnd(grid->subsizes[1]);

                for (hsize_t iy = 0; iy < grid->subsizes[0]; ++iy)
                {
                    #pragma omp simd
                    for (hsize_t iz = zfirst; iz < zlast; ++iz)
                    {
                        const ptrdiff_t line_start =
                            iy*grid->subsizes[1]*grid->subsizes[2] + iz*grid->subsizes[2];
                        for (hsize_t ix = 0; ix < grid->subsizes[2]; ++ix)
                        {
                            expression(
                                    line_start + ix,
                                    ix,
                                    iy,
                                    iz,
                                    (this->kx[ix]*this->kx[ix] +
                                     this->ky[iy]*this->ky[iy] +
                                     this->kz[iz]*this->kz[iz]));
                        }
                    }
                }
            }
            finish_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
        }

        /** \brief Like CLOOP_K2, with an extra integer argument per x index.
         *
         *  `expression` is called as
         *  `expression(cindex, xindex, yindex, zindex, k2, nxmodes)` where
         *  `nxmodes` is 1 for `xindex == 0` and 2 for every other x index.
         *  For `xindex == 0` the value of `k2` is `ky^2 + kz^2` only (the
         *  kx contribution is not added).
         *  NOTE(review): the 1/2 factor presumably counts modes represented
         *  by each stored coefficient of a real-to-complex transform -- confirm.
         */
        template <class func_type>
        void CLOOP_K2_NXMODES(func_type expression)
        {
            start_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
            #pragma omp parallel
            {
                field_layout<ONE> *const grid = this->layout;
                const hsize_t zfirst = OmpUtils::ForIntervalStart(grid->subsizes[1]);
                const hsize_t zlast = OmpUtils::ForIntervalEnd(grid->subsizes[1]);

                for (hsize_t iy = 0; iy < grid->subsizes[0]; ++iy)
                {
                    for (hsize_t iz = zfirst; iz < zlast; ++iz)
                    {
                        const ptrdiff_t line_start =
                            iy*grid->subsizes[1]*grid->subsizes[2] + iz*grid->subsizes[2];
                        // part of k^2 that is constant along the x line
                        const double k2 = (
                                this->ky[iy]*this->ky[iy] +
                                this->kz[iz]*this->kz[iz]);
                        // first x entry is handled separately, with factor 1
                        expression(line_start, 0, iy, iz, k2, 1);
                        for (hsize_t ix = 1; ix < grid->subsizes[2]; ++ix)
                        {
                            expression(line_start + ix, ix, iy, iz, k2 + this->kx[ix]*this->kx[ix], 2);
                        }
                    }
                }
            }
            finish_mpi_profiling_zone(turtle_mpi_pcontrol::FIELD);
        }

        /** \brief Project the field `a` onto its divergence-free part, in place.
         *
         *  \param maintain_energy optional flag (default false);
         *         NOTE(review): its exact effect is defined out of line.
         */
        template <typename rnumber>
        void project_divfree(
                typename fftw_interface<rnumber>::complex *__restrict__ a,
                const bool maintain_energy = false);

        // TODO: can the following be done in a cleaner way?
        /** \brief Shorthand for `project_divfree<rnumber>(a, false)`. */
        template <typename rnumber>
        void force_divfree(typename fftw_interface<rnumber>::complex *__restrict__ a)
        {
            this->template project_divfree<rnumber>(a, false);
        }

        /** \brief Declared here, defined out of line; operates in place on `a`. */
        template <typename rnumber>
        void rotate_divfree(typename fftw_interface<rnumber>::complex *__restrict__ a);
};

#endif//KSPACE_HPP