Skip to content
Snippets Groups Projects
Commit 02eb292a authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics

parent b5c5eab4
No related branches found
No related tags found
No related merge requests found
...@@ -15,25 +15,39 @@ _comm = MPI.COMM_WORLD ...@@ -15,25 +15,39 @@ _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
rank = _comm.Get_rank() rank = _comm.Get_rank()
master = (rank == 0) master = (rank == 0)
def _shareRange(nwork, nshares, myshare): def _shareRange(nwork, nshares, myshare):
nbase = nwork//nshares nbase = nwork//nshares
additional = nwork % nshares additional = nwork % nshares
lo = myshare*nbase + min(myshare, additional) lo = myshare*nbase + min(myshare, additional)
hi = lo + nbase + int(myshare < additional) hi = lo + nbase + int(myshare < additional)
return lo, hi return lo, hi
def np_allreduce_sum(arr): def np_allreduce_sum(arr):
res = np.empty_like(arr) res = np.empty_like(arr)
_comm.Allreduce(arr, res, MPI.SUM) _comm.Allreduce(arr, res, MPI.SUM)
return res return res
def allreduce_sum_field(fld): def allreduce_sum_field(fld):
if isinstance(fld, Field): if isinstance(fld, Field):
return Field.from_local_data(fld.domain, np_allreduce_sum(fld.local_data)) return Field.from_local_data(fld.domain,
res = tuple(Field.from_local_data(f.domain, np_allreduce_sum(f.local_data)) for f in fld.values()) np_allreduce_sum(fld.local_data))
res = tuple(
Field.from_local_data(f.domain, np_allreduce_sum(f.local_data))
for f in fld.values())
return MultiField(fld.domain, res) return MultiField(fld.domain, res)
class KL_Energy(Energy): class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], _samples=None, def __init__(self,
position,
h,
nsamp,
constants=[],
_samples=None,
want_metric=False): want_metric=False):
super(KL_Energy, self).__init__(position) super(KL_Energy, self).__init__(position)
self._h = h self._h = h
...@@ -51,12 +65,14 @@ class KL_Energy(Energy): ...@@ -51,12 +65,14 @@ class KL_Energy(Energy):
if len(constants) == 0: if len(constants) == 0:
tmp = Linearization.make_var(position, want_metric) tmp = Linearization.make_var(position, want_metric)
else: else:
ops = [ScalingOperator(0. if key in constants else 1., dom) ops = [
for key, dom in position.domain.items()] ScalingOperator(0. if key in constants else 1., dom)
for key, dom in position.domain.items()
]
bdop = BlockDiagonalOperator(position.domain, tuple(ops)) bdop = BlockDiagonalOperator(position.domain, tuple(ops))
tmp = Linearization(position, bdop, want_metric=want_metric) tmp = Linearization(position, bdop, want_metric=want_metric)
mymap = map(lambda v: self._h(tmp+v), self._samples) mymap = map(lambda v: self._h(tmp + v), self._samples)
tmp = utilities.my_sum(mymap) * (1./self._nsamp) tmp = utilities.my_sum(mymap)*(1./self._nsamp)
self._val = np_allreduce_sum(tmp.val.local_data)[()] self._val = np_allreduce_sum(tmp.val.local_data)[()]
self._grad = allreduce_sum_field(tmp.gradient) self._grad = allreduce_sum_field(tmp.gradient)
self._metric = tmp.metric self._metric = tmp.metric
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment