Commit 331df8e6 authored by Cristian Lalescu's avatar Cristian Lalescu
Browse files

slight simplification of add_particles usage

parent 118779bc
......@@ -413,17 +413,44 @@ class NavierStokes(bfps.fluid_base.fluid_particle_base):
return None
def add_particles(
self,
integration_steps = [2],
integration_steps = 2,
kcut = None,
interpolator = 'field_interpolator',
frozen_particles = False,
interpolator = ['field_interpolator'],
acc_name = None):
"""
:type integration_steps: int, list of int
:type kcut: None (default), str, list of str
:type interpolator: str, list of str
:type frozen_particles: bool
:type acc_name: str
"""
if self.dtype == np.float32:
FFTW = 'fftwf'
elif self.dtype == np.float64:
FFTW = 'fftw'
s0 = self.particle_species
for s in range(len(integration_steps)):
if type(integration_steps) == int:
integration_steps = [integration_steps]
if type(kcut) == str:
kcut = [kcut]
if type(interpolator) == str:
interpolator = [interpolator]
nspecies = max(len(integration_steps), len(interpolator))
if type(kcut) == list:
nspecies = max(nspecies, len(kcut))
if len(integration_steps) == 1:
integration_steps = [integration_steps[0] for s in range(nspecies)]
if len(interpolator) == 1:
interpolator = [interpolator[0] for s in range(nspecies)]
if type(kcut) == list:
if len(kcut) == 1:
kcut = [kcut[0] for s in range(nspecies)]
assert(len(integration_steps) == nspecies)
assert(len(interpolator) == nspecies)
if type(kcut) == list:
assert(len(kcut) == nspecies)
for s in range(nspecies):
neighbours = self.parameters[interpolator[s] + '_neighbours']
self.parameters['tracers{0}_interpolator'.format(s0 + s)] = interpolator[s]
self.parameters['tracers{0}_integration_steps'.format(s0 + s)] = integration_steps[s]
......@@ -450,7 +477,7 @@ class NavierStokes(bfps.fluid_base.fluid_particle_base):
# must compute acceleration
output_vel_acc += 'double *acceleration = new double[3*nparticles];\n'
output_vel_acc += 'fs->compute_Lagrangian_acceleration({0});\n'.format(acc_name)
for s in range(len(integration_steps)):
for s in range(nspecies):
output_vel_acc += """
{0}->field = fs->rvelocity;
ps{1}->sample_vec_field({0}, velocity);
......@@ -509,7 +536,7 @@ class NavierStokes(bfps.fluid_base.fluid_particle_base):
self.particle_stat_src += (
'if (ps0->iteration % niter_part == 0)\n' +
'{\n')
for s in range(len(integration_steps)):
for s in range(nspecies):
self.particle_start += 'sprintf(fname, "tracers{0}");\n'.format(s0 + s)
self.particle_end += ('ps{0}->write(stat_file);\n' +
'delete ps{0};\n').format(s0 + s)
......@@ -535,7 +562,7 @@ class NavierStokes(bfps.fluid_base.fluid_particle_base):
self.particle_start += output_vel_acc
self.particle_stat_src += output_vel_acc
self.particle_stat_src += '}\n'
self.particle_species += len(integration_steps)
self.particle_species += nspecies
return None
def get_data_file(self):
return h5py.File(os.path.join(self.work_dir, self.simname + '.h5'), 'r')
......
......@@ -107,10 +107,9 @@ def launch(
name = 'spline',
neighbours = opt.neighbours,
smoothness = opt.smoothness)
intsteps = [2, 3, 4, 6]
c.add_particles(
integration_steps = intsteps,
interpolator = ['spline' for i in range(len(intsteps))],
integration_steps = [2, 3, 4, 6],
interpolator = 'spline',
acc_name = 'rFFTW_acc')
#c.add_particle_fields(kcut = 'fs->kM/2', name = 'filtered', neighbours = opt.neighbours)
#c.add_particles(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment