From 80f8280b9be03eedcbad6c4642f14227a8d6eac3 Mon Sep 17 00:00:00 2001
From: Chichi Lalescu <chichilalescu@gmail.com>
Date: Wed, 3 May 2017 22:06:28 +0200
Subject: [PATCH] improve modularity of DNS class

---
 bfps/DNS.py | 105 +++++++++++++++++++++++++++-------------------------
 1 file changed, 54 insertions(+), 51 deletions(-)

diff --git a/bfps/DNS.py b/bfps/DNS.py
index e19bcc79..b17d2611 100644
--- a/bfps/DNS.py
+++ b/bfps/DNS.py
@@ -60,84 +60,87 @@ class DNS(_code):
         self.rtype = self.fluid_dtype
         if self.rtype == np.float32:
             self.ctype = np.dtype(np.complex64)
-            self.C_dtype = 'float'
+            self.C_field_dtype = 'float'
         elif self.rtype == np.float64:
             self.ctype = np.dtype(np.complex128)
-            self.C_dtype = 'double'
-        if self.dns_type == 'NSVE':
-            self.parameters['dealias_type'] = 1
-            self.parameters['dkx'] = 1.0
-            self.parameters['dky'] = 1.0
-            self.parameters['dkz'] = 1.0
-            self.parameters['niter_todo'] = 8
-            self.parameters['niter_part'] = 1
-            self.parameters['niter_stat'] = 1
-            self.parameters['niter_out'] = 1024
-            self.parameters['nparticles'] = 0
-            self.parameters['dt'] = 0.01
-            self.parameters['nu'] = float(0.1)
-            self.parameters['fmode'] = 1
-            self.parameters['famplitude'] = float(0.5)
-            self.parameters['fk0'] = float(2.0)
-            self.parameters['fk1'] = float(4.0)
-            self.parameters['forcing_type'] = 'linear'
-            self.parameters['histogram_bins'] = int(256)
-            self.parameters['max_velocity_estimate'] = float(1)
-            self.parameters['max_vorticity_estimate'] = float(1)
-            self.parameters['checkpoints_per_file'] = int(1)
+            self.C_field_dtype = 'double'
         self.version_message = (
                 '/***********************************************************************\n' +
                 '* this code automatically generated by bfps\n' +
                 '* version {0}\n'.format(bfps.__version__) +
                 '***********************************************************************/\n\n\n')
-        self.includes = """
-                //begincpp
-                #include "base.hpp"
-                #include "scope_timer.hpp"
-                #include "fftw_interface.hpp"
-                #include "full_code/main_code.hpp"
-                #include <iostream>
-                #include <hdf5.h>
-                #include <string>
-                #include <cstring>
-                #include <fftw3-mpi.h>
-                #include <omp.h>
-                #include <cfenv>
-                #include <cstdlib>
-                //endcpp
-                """
-        self.includes += '#include "full_code/{0}.hpp"\n'.format(self.dns_type)
-        self.variables = ''
-        self.definitions = ''
+        self.include_list = [
+                '"base.hpp"',
+                '"scope_timer.hpp"',
+                '"fftw_interface.hpp"',
+                '"full_code/main_code.hpp"',
+                '<iostream>',
+                '<hdf5.h>',
+                '<string>',
+                '<cstring>',
+                '<fftw3-mpi.h>',
+                '<omp.h>',
+                '<cfenv>',
+                '<cstdlib>',
+                '"full_code/{0}.hpp"\n'.format(self.dns_type)]
         self.main = """
                 int main(int argc, char *argv[])
                 {{
                     bool fpe = (
                         (getenv("BFPS_FPE_OFF") == nullptr) ||
                         (getenv("BFPS_FPE_OFF") != std::string("TRUE")));
-                    return main_code<{0}>(argc, argv, fpe);
+                    return main_code< {0} >(argc, argv, fpe);
                 }}
-                """.format(self.dns_type + '<{0}>'.format(self.C_dtype))
+                """.format(self.dns_type + '<{0}>'.format(self.C_field_dtype))
         self.host_info = {'type'        : 'cluster',
                           'environment' : None,
                           'deltanprocs' : 1,
                           'queue'       : '',
                           'mail_address': '',
                           'mail_events' : None}
+        self.generate_default_parameters()
+        return None
+    def generate_default_parameters(self):
+        # these parameters are relevant for all DNS classes
+        self.parameters['dealias_type'] = 1
+        self.parameters['dkx'] = 1.0
+        self.parameters['dky'] = 1.0
+        self.parameters['dkz'] = 1.0
+        self.parameters['niter_todo'] = 8
+        self.parameters['niter_stat'] = 1
+        self.parameters['niter_out'] = 8
+        self.parameters['checkpoints_per_file'] = int(1)
+        self.parameters['dt'] = 0.01
+        self.parameters['nu'] = float(0.1)
+        self.parameters['fmode'] = 1
+        self.parameters['famplitude'] = float(0.5)
+        self.parameters['fk0'] = float(2.0)
+        self.parameters['fk1'] = float(4.0)
+        self.parameters['forcing_type'] = 'linear'
+        self.parameters['histogram_bins'] = int(256)
+        self.parameters['max_velocity_estimate'] = float(1)
+        self.parameters['max_vorticity_estimate'] = float(1)
+        # parameters specific to particle version
+        if self.dns_type == 'NSVEp':
+            self.parameters['niter_part'] = 1
+            self.parameters['nparticles'] = 0
         return None
     def write_src(self):
+        self.includes = '\n'.join(
+                ['#include ' + hh
+                 for hh in self.include_list])
         with open(self.name + '.cpp', 'w') as outfile:
-            outfile.write(self.version_message)
-            outfile.write(self.includes)
+            outfile.write(self.version_message + '\n\n')
+            outfile.write(self.includes + '\n\n')
             outfile.write(self.cread_pars(
-                template_class = 'NSVE<rnumber>::',
+                template_class = '{0}<rnumber>::'.format(self.dns_type),
                 template_prefix = 'template <typename rnumber> ',
-                simname_variable = 'simname.c_str()'))
+                simname_variable = 'simname.c_str()') + '\n\n')
             for rnumber in ['float', 'double']:
                 outfile.write(self.cread_pars(
-                    template_class = 'NSVE<{0}>::'.format(rnumber),
+                    template_class = '{0}<{1}>::'.format(self.dns_type, rnumber),
                     template_prefix = 'template '.format(rnumber),
-                    just_declaration = True))
-            outfile.write(self.main)
+                    just_declaration = True) + '\n\n')
+            outfile.write(self.main + '\n')
         return None
 
-- 
GitLab