From a2b474e05ae5444dc11567f78b548effc136b259 Mon Sep 17 00:00:00 2001
From: Cristian C Lalescu <Cristian.Lalescu@ds.mpg.de>
Date: Mon, 4 Feb 2019 09:44:55 +0100
Subject: [PATCH] checkpoint

---
 bfps/DNS.py                       | 107 +++++++++++++++---------------
 bfps/test/test_particle_clouds.py |  66 ++++++++++++++++++
 2 files changed, 120 insertions(+), 53 deletions(-)
 create mode 100644 bfps/test/test_particle_clouds.py

diff --git a/bfps/DNS.py b/bfps/DNS.py
index c1835062..3aa47fd4 100644
--- a/bfps/DNS.py
+++ b/bfps/DNS.py
@@ -430,9 +430,7 @@ class DNS(_code):
         return None
     def write_par(
             self,
-            iter0 = 0,
-            particle_ic = None,
-            particles_off = False):
+            iter0 = 0):
         assert (self.parameters['niter_todo'] % self.parameters['niter_stat'] == 0)
         assert (self.parameters['niter_todo'] % self.parameters['niter_out']  == 0)
         assert (self.parameters['niter_out']  % self.parameters['niter_stat'] == 0)
@@ -479,7 +477,7 @@ class DNS(_code):
                                                  4),
                                      dtype = np.int64)
             ofile['checkpoint'] = int(0)
-        if (self.dns_type in ['NSVE', 'NSVE_no_output']) or particles_off:
+        if (self.dns_type in ['NSVE', 'NSVE_no_output']):
             return None
 
         if type(particle_ic) == type(None):
@@ -995,58 +993,61 @@ class DNS(_code):
                         particle_file.create_group('tracers0/pressure_gradient')
                         particle_file.create_group('tracers0/pressure_Hessian')
         return None
-    def launch_jobs(
+    def generate_initial_condition(
             self,
-            opt = None,
-            particle_initial_condition = None):
-        if not os.path.exists(os.path.join(self.work_dir, self.simname + '.h5')):
-            # take care of fields' initial condition
-            # first, check if initial field exists
-            need_field = False
-            if not os.path.exists(self.get_checkpoint_0_fname()):
+            opt = None):
+        # take care of fields' initial condition
+        # first, check if initial field exists
+        need_field = False
+        if not os.path.exists(self.get_checkpoint_0_fname()):
+            need_field = True
+        else:
+            f = h5py.File(self.get_checkpoint_0_fname(), 'r')
+            try:
+                dset = f['vorticity/complex/0']
+                need_field = (dset.shape == (self.parameters['ny'],
+                                             self.parameters['nz'],
+                                             self.parameters['nx']//2+1,
+                                             3))
+            except:
                 need_field = True
+            f.close()
+        if need_field:
+            f = h5py.File(self.get_checkpoint_0_fname(), 'a')
+            if len(opt.src_simname) > 0:
+                source_cp = 0
+                src_file = 'not_a_file'
+                while True:
+                    src_file = os.path.join(
+                        os.path.realpath(opt.src_work_dir),
+                        opt.src_simname + '_checkpoint_{0}.h5'.format(source_cp))
+                    f0 = h5py.File(src_file, 'r')
+                    if '{0}'.format(opt.src_iteration) in f0['vorticity/complex'].keys():
+                        f0.close()
+                        break
+                    source_cp += 1
+                self.copy_complex_field(
+                        src_file,
+                        'vorticity/complex/{0}'.format(opt.src_iteration),
+                        f,
+                        'vorticity/complex/{0}'.format(0))
             else:
-                f = h5py.File(self.get_checkpoint_0_fname(), 'r')
-                try:
-                    dset = f['vorticity/complex/0']
-                    need_field = (dset.shape == (self.parameters['ny'],
-                                                 self.parameters['nz'],
-                                                 self.parameters['nx']//2+1,
-                                                 3))
-                except:
-                    need_field = True
-                f.close()
-            if need_field:
-                f = h5py.File(self.get_checkpoint_0_fname(), 'a')
-                if len(opt.src_simname) > 0:
-                    source_cp = 0
-                    src_file = 'not_a_file'
-                    while True:
-                        src_file = os.path.join(
-                            os.path.realpath(opt.src_work_dir),
-                            opt.src_simname + '_checkpoint_{0}.h5'.format(source_cp))
-                        f0 = h5py.File(src_file, 'r')
-                        if '{0}'.format(opt.src_iteration) in f0['vorticity/complex'].keys():
-                            f0.close()
-                            break
-                        source_cp += 1
-                    self.copy_complex_field(
-                            src_file,
-                            'vorticity/complex/{0}'.format(opt.src_iteration),
-                            f,
-                            'vorticity/complex/{0}'.format(0))
-                else:
-                    data = self.generate_vector_field(
-                           write_to_file = False,
-                           spectra_slope = 2.0,
-                           amplitude = 0.05)
-                    f['vorticity/complex/{0}'.format(0)] = data
-                f.close()
-            self.write_par(
-                    particle_ic = None)
-            # take care of particles' initial condition
-            if self.dns_type in ['NSVEparticles', 'NSVEcomplex_particles', 'NSVEparticles_no_output', 'NSVEp_extra_sampling']:
-                self.generate_particle_data(opt = opt)
+                data = self.generate_vector_field(
+                       write_to_file = False,
+                       spectra_slope = 2.0,
+                       amplitude = 0.05)
+                f['vorticity/complex/{0}'.format(0)] = data
+            f.close()
+        # now take care of particles' initial condition
+        if self.dns_type in ['NSVEparticles', 'NSVEcomplex_particles', 'NSVEparticles_no_output', 'NSVEp_extra_sampling']:
+            self.generate_particle_data(opt = opt)
+        return None
+    def launch_jobs(
+            self,
+            opt = None):
+        if not os.path.exists(self.get_data_file_name()):
+            self.generate_initial_condition()
+        self.write_par()
         self.run(
                 nb_processes = opt.nb_processes,
                 nb_threads_per_process = opt.nb_threads_per_process,
diff --git a/bfps/test/test_particle_clouds.py b/bfps/test/test_particle_clouds.py
new file mode 100644
index 00000000..1a890495
--- /dev/null
+++ b/bfps/test/test_particle_clouds.py
@@ -0,0 +1,66 @@
+#! /usr/bin/env python
+#######################################################################
+#                                                                     #
+#  Copyright 2019 Max Planck Institute                                #
+#                 for Dynamics and Self-Organization                  #
+#                                                                     #
+#  This file is part of bfps.                                         #
+#                                                                     #
+#  bfps is free software: you can redistribute it and/or modify       #
+#  it under the terms of the GNU General Public License as published  #
+#  by the Free Software Foundation, either version 3 of the License,  #
+#  or (at your option) any later version.                             #
+#                                                                     #
+#  bfps is distributed in the hope that it will be useful,            #
+#  but WITHOUT ANY WARRANTY; without even the implied warranty of     #
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      #
+#  GNU General Public License for more details.                       #
+#                                                                     #
+#  You should have received a copy of the GNU General Public License  #
+#  along with bfps.  If not, see <http://www.gnu.org/licenses/>       #
+#                                                                     #
+# Contact: Cristian.Lalescu@ds.mpg.de                                 #
+#                                                                     #
+#######################################################################
+
+
+
+import os
+import numpy as np
+import h5py
+import sys
+
+import bfps
+from bfps import DNS
+
+def main():
+    nclouds = 4
+    nparticles_per_cloud = 3
+    nparticles = nclouds*nparticles_per_cloud
+    niterations = 32
+    c = DNS()
+    ic_file = h5py.File(c.get_checkpoint_0_fname(), 'a')
+    ic_file['tracers0/state/0'] = np.random.random((nclouds, nparticles_per_cloud, 3))
+    ic_file['tracers0/rhs/0'] = np.zeros((2, nclouds, nparticles_per_cloud, 3))
+    ic_file.close()
+    c.launch(
+            ['NSVEparticles',
+             '-n', '32',
+             '--src-simname', 'B32p1e4',
+             '--forcing_type', 'linear',
+             '--src-wd', bfps.lib_dir + '/test',
+             '--src-iteration', '0',
+             '--np', '4',
+             '--ntpp', '1',
+             '--fftw_plan_rigor', 'FFTW_PATIENT',
+             '--niter_todo', '{0}'.format(niterations),
+             '--niter_out', '{0}'.format(niterations),
+             '--niter_stat', '1',
+             '--nparticles', '{0}'.format(nparticles),
+             '--tracers0_integration_steps', '2',
+             '--wd', './'])
+    return None
+
+if __name__ == '__main__':
+    main()
+
-- 
GitLab