Commit 74fdeb6b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

simplify

parent 1a01cba8
......@@ -89,7 +89,7 @@ if __name__ == '__main__':
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
# number of samples used to estimate the KL
N_samples = 20
N_samples = 1
for i in range(2):
# KL = ift.KL_Energy(position, H, N_samples)
KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True)
......
......@@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy):
self._nsamp = nsamp
self._constants = constants
self._want_metric = want_metric
if nsamp < ntask:
# FIXME We need a better solution here. It is probably not good if
# the script just dies. Can we proceed anyways?
print('Number of samples: {}, number of MPI tasks: {}'.format(
nsamp, ntask))
raise RuntimeError('Cannot use more tasks than samples.')
if _samples is None:
lo, hi = _shareRange(nsamp, ntask, rank)
met = h(Linearization.make_var(position, True)).metric
......@@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy):
self._samples = tuple(_samples)
self._lin = Linearization.make_partial_var(position, constants,
want_metric)
mymap = map(lambda v: self._h(self._lin + v), self._samples)
if len(self._samples) == 0: # hack if there are too many MPI tasks
mymap = map(lambda v: 0*self._h(v), (self._lin,))
else:
mymap = map(lambda v: self._h(self._lin + v), self._samples)
tmp = utilities.my_sum(mymap)*(1./self._nsamp)
self._val = np_allreduce_sum(tmp.val.local_data)[()]
self._grad = allreduce_sum_field(tmp.gradient)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment