From 98c3eb2f2e9b12ab4f1dddef90de3bae61065ae1 Mon Sep 17 00:00:00 2001
From: Cristian C Lalescu <Cristian.Lalescu@ds.mpg.de>
Date: Mon, 14 Aug 2017 16:54:48 +0200
Subject: [PATCH] make the DNS class more general

---
 bfps/DNS.py                          | 33 ++++++++++++++++++----------
 bfps/cpp/full_code/NSVEparticles.cpp |  2 +-
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/bfps/DNS.py b/bfps/DNS.py
index 4f26b86c..57dc879a 100644
--- a/bfps/DNS.py
+++ b/bfps/DNS.py
@@ -78,7 +78,8 @@ class DNS(_code):
             self.C_field_dtype = 'double'
             self.fluid_precision = 'double'
         return None
-    def write_src(self):
+    def write_src(
+            self):
         self.version_message = (
                 '/***********************************************************************\n' +
                 '* this code automatically generated by bfps\n' +
@@ -623,7 +624,8 @@ class DNS(_code):
         return None
     def prepare_launch(
             self,
-            args = []):
+            args = [],
+            extra_parameters = None):
         """Set up reasonable parameters.
 
         With the default Lundgren forcing applied in the band [2, 4],
@@ -657,6 +659,10 @@ class DNS(_code):
         if self.dns_type in ['NSVEparticles', 'NSVEparticles_no_output']:
             for k in self.NSVEp_extra_parameters.keys():
                 self.parameters[k] = self.NSVEp_extra_parameters[k]
+        if type(extra_parameters) != type(None):
+            if self.dns_type in extra_parameters.keys():
+                for k in extra_parameters[self.dns_type].keys():
+                    self.parameters[k] = extra_parameters[self.dns_type][k]
         self.parameters['nu'] = (opt.kMeta * 2 / opt.n)**(4./3)
         self.parameters['dt'] = (opt.dtfactor / opt.n)
         # custom famplitude for 288 and 576
@@ -837,12 +843,10 @@ class DNS(_code):
                 for kz in range(src_file[src_dset_name].shape[0]):
                     dst_file[dst_dset_name][kz] = src_file[src_dset_name][kz]
         else:
-            print('aloha')
             min_shape = (min(dst_shape[0], src_file[src_dset_name].shape[0]),
                          min(dst_shape[1], src_file[src_dset_name].shape[1]),
                          min(dst_shape[2], src_file[src_dset_name].shape[2]),
                          3)
-            print(self.ctype)
             dst_file.create_dataset(
                     dst_dset_name,
                     shape = dst_shape,
@@ -852,6 +856,18 @@ class DNS(_code):
                 dst_file[dst_dset_name][kz,:min_shape[1], :min_shape[2]] = \
                         src_file[src_dset_name][kz, :min_shape[1], :min_shape[2]]
         return None
+    def generate_particle_data(
+            self,
+            opt = None):
+        if self.parameters['nparticles'] > 0:
+            self.generate_tracer_state(
+                    species = 0,
+                    rseed = opt.particle_rand_seed)
+            if not os.path.exists(self.get_particle_file_name()):
+                with h5py.File(self.get_particle_file_name(), 'w') as particle_file:
+                    particle_file.create_group('tracers0/velocity')
+                    particle_file.create_group('tracers0/acceleration')
+        return None
     def launch_jobs(
             self,
             opt = None,
@@ -911,14 +927,7 @@ class DNS(_code):
             self.write_par(
                     particle_ic = None)
             if self.dns_type in ['NSVEparticles', 'NSVEparticles_no_output']:
-                if self.parameters['nparticles'] > 0:
-                    self.generate_tracer_state(
-                            species = 0,
-                            rseed = opt.particle_rand_seed)
-                    if not os.path.exists(self.get_particle_file_name()):
-                        with h5py.File(self.get_particle_file_name(), 'w') as particle_file:
-                            particle_file.create_group('tracers0/velocity')
-                            particle_file.create_group('tracers0/acceleration')
+                self.generate_particle_data(opt = opt)
         self.run(
                 nb_processes = opt.nb_processes,
                 nb_threads_per_process = opt.nb_threads_per_process,
diff --git a/bfps/cpp/full_code/NSVEparticles.cpp b/bfps/cpp/full_code/NSVEparticles.cpp
index ba84b394..90b948b6 100644
--- a/bfps/cpp/full_code/NSVEparticles.cpp
+++ b/bfps/cpp/full_code/NSVEparticles.cpp
@@ -58,9 +58,9 @@ int NSVEparticles<rnumber>::write_checkpoint(void)
 template <typename rnumber>
 int NSVEparticles<rnumber>::finalize(void)
 {
-    this->NSVE<rnumber>::finalize();
     this->ps.release();
     delete this->particles_output_writer_mpi;
+    this->NSVE<rnumber>::finalize();
     return EXIT_SUCCESS;
 }
 
-- 
GitLab