From b39b430f232220cf81b806f0e3d602ec354cdde5 Mon Sep 17 00:00:00 2001
From: Cristian C Lalescu <Cristian.Lalescu@mpcdf.mpg.de>
Date: Tue, 18 Feb 2025 16:32:04 +0100
Subject: [PATCH] adds mechanism for control of initial condition behavior

---
 TurTLE/DNS.py | 76 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 46 insertions(+), 30 deletions(-)

diff --git a/TurTLE/DNS.py b/TurTLE/DNS.py
index 94a2091f..28a8ca37 100644
--- a/TurTLE/DNS.py
+++ b/TurTLE/DNS.py
@@ -929,40 +929,56 @@ class DNS(_code):
         return None
     def generate_initial_condition(
             self,
-            opt = None):
+            opt = None,
+            need_field = None,
+            checkpoint_field = None):
+        """Creates the "checkpoint 0" file.
+
+        Keyword arguments:
+        opt              --- options object, this method may access any of
+                                * opt.src_simname,
+                                * opt.src_work_dir,
+                                * opt.src_iteration
+        need_field       --- boolean: does the DNS need the "0" iteration field
+                             to be present in the "checkpoint 0" file?
+        checkpoint_field --- string: name of field used for checkpoints, one of
+                             'velocity' or 'vorticity'.
+        """
         # take care of fields' initial condition
+        if type(checkpoint_field) == type(None):
+            if self.dns_type in ['NSE', 'NSE_alt_dealias']:
+                checkpoint_field = 'velocity'
+            else:
+                checkpoint_field = 'vorticity'
         # first, check if initial field exists
-        need_field = False
-        if self.check_current_vorticity_exists:
-            need_field = True
-        if self.dns_type in ['NSE', 'NSE_alt_dealias']:
-            checkpoint_field = 'velocity'
-        else:
-            checkpoint_field = 'vorticity'
-        if self.dns_type in [
-                'NSE',
-                'NSE_alt_dealias',
-                'NSVE',
-                'NSVE_no_output',
-                'static_field',
-                'NSVEparticles',
-                'NSVEcomplex_particles',
-                'NSVE_Stokes_particles',
-                'NSVEparticles_no_output',
-                'NSVEp_extra_sampling']:
-            if not os.path.exists(self.get_checkpoint_0_fname()):
+        if type(need_field) == type(None):
+            need_field = False
+            if self.check_current_vorticity_exists:
                 need_field = True
-            else:
-                f = h5py.File(self.get_checkpoint_0_fname(), 'r')
-                try:
-                    dset = f[checkpoint_field + '/complex/0']
-                    need_field = (dset.shape != (self.parameters['ny'],
-                                                 self.parameters['nz'],
-                                                 self.parameters['nx']//2+1,
-                                                 3))
-                except:
+            if self.dns_type in [
+                    'NSE',
+                    'NSE_alt_dealias',
+                    'NSVE',
+                    'NSVE_no_output',
+                    'static_field',
+                    'NSVEparticles',
+                    'NSVEcomplex_particles',
+                    'NSVE_Stokes_particles',
+                    'NSVEparticles_no_output',
+                    'NSVEp_extra_sampling']:
+                if not os.path.exists(self.get_checkpoint_0_fname()):
                     need_field = True
-                f.close()
+                else:
+                    f = h5py.File(self.get_checkpoint_0_fname(), 'r')
+                    try:
+                        dset = f[checkpoint_field + '/complex/0']
+                        need_field = (dset.shape != (self.parameters['ny'],
+                                                     self.parameters['nz'],
+                                                     self.parameters['nx']//2+1,
+                                                     3))
+                    except:
+                        need_field = True
+                    f.close()
         if need_field:
             # sanity check. User cannot demand a random initial condition and a source simulation at the same time.
             assert ((len(opt.src_simname) == 0) or (self.parameters['field_random_seed'] == 0)), \
-- 
GitLab