From dea6ea3c2260d855a4b833c9044e5f4875d13cb5 Mon Sep 17 00:00:00 2001
From: Berenger Bramas <bbramas@mpcdf.mpg.de>
Date: Wed, 18 Jan 2017 17:11:17 +0100
Subject: [PATCH] add option to force fftw estimate

---
 bfps/FluidConvert.py           | 2 +-
 bfps/_code.py                  | 1 +
 bfps/cpp/fftw_interface.hpp    | 7 +++++++
 bfps/cpp/field.hpp             | 2 +-
 bfps/cpp/field_descriptor.cpp  | 8 ++++----
 bfps/cpp/fluid_solver_base.hpp | 2 +-
 setup.py                       | 8 ++++++++
 7 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/bfps/FluidConvert.py b/bfps/FluidConvert.py
index 8d31784a..d924f2a1 100644
--- a/bfps/FluidConvert.py
+++ b/bfps/FluidConvert.py
@@ -98,7 +98,7 @@ class FluidConvert(_fluid_particle_base):
                         nx, ny, nz,
                         dkx, dky, dkz,
                         dealias_type,
-                        FFTW_PATIENT);
+                        DEFAULT_FFTW_FLAG);
                 //endcpp
                 """.format(self.C_dtype)
         self.fluid_loop += """
diff --git a/bfps/_code.py b/bfps/_code.py
index 2e320625..71e869a0 100644
--- a/bfps/_code.py
+++ b/bfps/_code.py
@@ -54,6 +54,7 @@ class _code(_base):
                 #include "base.hpp"
                 #include "fluid_solver.hpp"
                 #include "scope_timer.hpp"
+                #include "fftw_interface.hpp"
                 #include <iostream>
                 #include <hdf5.h>
                 #include <string>
diff --git a/bfps/cpp/fftw_interface.hpp b/bfps/cpp/fftw_interface.hpp
index ec318b3b..2b2e5074 100644
--- a/bfps/cpp/fftw_interface.hpp
+++ b/bfps/cpp/fftw_interface.hpp
@@ -27,6 +27,13 @@
 
 #include <fftw3-mpi.h>
 
+#ifdef USE_FFTWESTIMATE
+#define DEFAULT_FFTW_FLAG FFTW_ESTIMATE
+#warning You are using FFTW estimate
+#else
+#define DEFAULT_FFTW_FLAG FFTW_PATIENT
+#endif
+
 template <class realtype>
 class fftw_interface;
 
diff --git a/bfps/cpp/field.hpp b/bfps/cpp/field.hpp
index 3cb1424e..ca341a33 100644
--- a/bfps/cpp/field.hpp
+++ b/bfps/cpp/field.hpp
@@ -71,7 +71,7 @@ class field
                 const int ny,
                 const int nz,
                 const MPI_Comm COMM_TO_USE,
-                const unsigned FFTW_PLAN_RIGOR = FFTW_PATIENT);
+                const unsigned FFTW_PLAN_RIGOR = DEFAULT_FFTW_FLAG);
         ~field();
 
         int io(
diff --git a/bfps/cpp/field_descriptor.cpp b/bfps/cpp/field_descriptor.cpp
index a2d0209a..20c63426 100644
--- a/bfps/cpp/field_descriptor.cpp
+++ b/bfps/cpp/field_descriptor.cpp
@@ -361,7 +361,7 @@ int field_descriptor<rnumber>::transpose(
                 this->sizes[0], this->slice_size,
             input, output,
             this->comm,
-            FFTW_PATIENT);
+            DEFAULT_FFTW_FLAG);
     fftw_interface<rnumber>::execute(tplan);
     fftw_interface<rnumber>::destroy_plan(tplan);
     return EXIT_SUCCESS;
@@ -389,7 +389,7 @@ int field_descriptor<rnumber>::transpose(
                 FFTW_MPI_DEFAULT_BLOCK,
                 (rnumber*)input, (rnumber*)output,
                 this->comm,
-                FFTW_PATIENT);
+                DEFAULT_FFTW_FLAG);
         fftw_interface<rnumber>::execute(tplan);
         fftw_interface<rnumber>::destroy_plan(tplan);
         break;
@@ -449,7 +449,7 @@ int field_descriptor<rnumber>::interleave(
                 a,
                 a,
                 /*kind*/nullptr,
-                FFTW_PATIENT);
+                DEFAULT_FFTW_FLAG);
     fftw_interface<rnumber>::execute(tmp);
     fftw_interface<rnumber>::destroy_plan(tmp);
     return EXIT_SUCCESS;
@@ -478,7 +478,7 @@ int field_descriptor<rnumber>::interleave(
                 a,
                 a,
                 +1,
-                FFTW_PATIENT);
+                DEFAULT_FFTW_FLAG);
     fftw_interface<rnumber>::execute(tmp);
     fftw_interface<rnumber>::destroy_plan(tmp);
     return EXIT_SUCCESS;
diff --git a/bfps/cpp/fluid_solver_base.hpp b/bfps/cpp/fluid_solver_base.hpp
index 97ac5a37..02fd8173 100644
--- a/bfps/cpp/fluid_solver_base.hpp
+++ b/bfps/cpp/fluid_solver_base.hpp
@@ -82,7 +82,7 @@ class fluid_solver_base
                 double DKY = 1.0,
                 double DKZ = 1.0,
                 int DEALIAS_TYPE = 0,
-                unsigned FFTW_PLAN_RIGOR = FFTW_PATIENT);
+                unsigned FFTW_PLAN_RIGOR = DEFAULT_FFTW_FLAG);
         ~fluid_solver_base();
 
         void low_pass_Fourier(cnumber *__restrict__ a, int howmany, double kmax);
diff --git a/setup.py b/setup.py
index 8d4ae0f9..f5d1c1c6 100644
--- a/setup.py
+++ b/setup.py
@@ -137,6 +137,7 @@ libraries += extra_libraries
 
 def compile_bfps_library(
         use_timingoutput = False,
+        use_fftwestimate = False,
         extra_compile_args = None):
     """
         use_timingoutput sets the USE_TIMINGOUTPUT definition,
@@ -157,6 +158,8 @@ def compile_bfps_library(
         need_to_compile = (latest > libtime)
     if use_timingoutput:
         extra_compile_args += ['-DUSE_TIMINGOUTPUT']
+    if use_fftwestimate:
+        extra_compile_args += ['-DUSE_FFTWESTIMATE']
     for fname in src_file_list:
         ifile = 'bfps/cpp/' + fname + '.cpp'
         ofile = 'obj/' + fname + '.o'
@@ -200,12 +203,15 @@ class CompileLibCommand(distutils.cmd.Command):
     description = 'Compile bfps library.'
     user_options = [
             ('timing-output=', None, 'Toggle timing output.'),
+            ('fftw-estimate=', None, 'Use FFTW ESTIMATE.'),
             ]
     def initialize_options(self):
         self.timing_output = 0
+        self.fftw_estimate = 0
         return None
     def finalize_options(self):
         self.timing_output = (int(self.timing_output) == 1)
+        self.fftw_estimate = (int(self.fftw_estimate) == 1)
         return None
     def run(self):
         if not os.path.isdir('obj'):
@@ -224,6 +230,8 @@ class CompileLibCommand(distutils.cmd.Command):
         eca = extra_compile_args
         if self.timing_output:
             eca += ['-DUSE_TIMINGOUTPUT']
+        if self.fftw_estimate:
+            eca += ['-DUSE_FFTWESTIMATE']
         for fname in src_file_list:
             ifile = 'bfps/cpp/' + fname + '.cpp'
             ofile = 'obj/' + fname + '.o'
-- 
GitLab