Commit 74fdeb6b by Martin Reinecke

### simplify

parent 1a01cba8
 ... ... @@ -89,7 +89,7 @@ if __name__ == '__main__': plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") # number of samples used to estimate the KL N_samples = 20 N_samples = 1 for i in range(2): # KL = ift.KL_Energy(position, H, N_samples) KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True) ... ...
 ... ... @@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy): self._nsamp = nsamp self._constants = constants self._want_metric = want_metric if nsamp < ntask: # FIXME We need a better solution here. It is probably not good if # the script just dies. Can we proceed anyways? print('Number of samples: {}, number of MPI tasks: {}'.format( nsamp, ntask)) raise RuntimeError('Cannot use more tasks than samples.') if _samples is None: lo, hi = _shareRange(nsamp, ntask, rank) met = h(Linearization.make_var(position, True)).metric ... ... @@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy): self._samples = tuple(_samples) self._lin = Linearization.make_partial_var(position, constants, want_metric) mymap = map(lambda v: self._h(self._lin + v), self._samples) if len(self._samples) == 0: # hack if there are too many MPI tasks mymap = map(lambda v: 0*self._h(v), (self._lin,)) else: mymap = map(lambda v: self._h(self._lin + v), self._samples) tmp = utilities.my_sum(mymap)*(1./self._nsamp) self._val = np_allreduce_sum(tmp.val.local_data)[()] self._grad = allreduce_sum_field(tmp.gradient) ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!