Skip to content
Snippets Groups Projects
Commit 1238a169 authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactor

parent 41b3caae
Branches
Tags
1 merge request!9Mpi adder
Pipeline #93660 failed
...@@ -37,6 +37,13 @@ mytest: ...@@ -37,6 +37,13 @@ mytest:
- pytest-3 -q --cov=resolve test.py - pytest-3 -q --cov=resolve test.py
coverage: '/^TOTAL.+?(\d+\%)$/' coverage: '/^TOTAL.+?(\d+\%)$/'
test_mpi:
stage: testing
variables:
OMPI_MCA_btl_vader_single_copy_mechanism: none
script:
- mpiexec -n 2 --bind-to none pytest-3 -q test/test_mpi
# staticchecks: # staticchecks:
# stage: testing # stage: testing
# script: # script:
......
...@@ -5,7 +5,7 @@ ENV DEBIAN_FRONTEND noninteractive ...@@ -5,7 +5,7 @@ ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update -qq && apt-get install -qq git RUN apt-get update -qq && apt-get install -qq git
# Actual dependencies # Actual dependencies
RUN apt-get update -qq && apt-get install -qq python3-pip casacore-dev python3-matplotlib RUN apt-get update -qq && apt-get install -qq python3-pip casacore-dev python3-matplotlib python3-mpi4py
RUN pip3 install scipy git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_7 RUN pip3 install scipy git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_7
RUN pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/ducc.git@ducc0 RUN pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/ducc.git@ducc0
# Optional dependencies # Optional dependencies
......
...@@ -5,6 +5,7 @@ from .global_config import * ...@@ -5,6 +5,7 @@ from .global_config import *
from .likelihood import * from .likelihood import *
from .minimization import Minimization, MinimizationState, simple_minimize from .minimization import Minimization, MinimizationState, simple_minimize
from .mpi import onlymaster from .mpi import onlymaster
from .mpi_operators import *
from .ms_import import ms2observations, ms_n_spectral_windows from .ms_import import ms2observations, ms_n_spectral_windows
from .multi_frequency.irg_space import IRGSpace from .multi_frequency.irg_space import IRGSpace
from .multi_frequency.operators import ( from .multi_frequency.operators import (
...@@ -17,7 +18,7 @@ from .plotter import MfPlotter, Plotter ...@@ -17,7 +18,7 @@ from .plotter import MfPlotter, Plotter
from .points import PointInserter from .points import PointInserter
from .polarization import polarization_matrix_exponential from .polarization import polarization_matrix_exponential
from .primary_beam import vla_beam from .primary_beam import vla_beam
from .response import MfResponse, StokesIResponse, ResponseDistributor from .response import MfResponse, ResponseDistributor, StokesIResponse
from .simple_operators import * from .simple_operators import *
from .util import ( from .util import (
Reshaper, Reshaper,
......
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import nifty7 as ift
class AllreduceSum(ift.Operator):
def __init__(self, oplist, comm, nwork=None):
"""nwork only needed if samples need to be drawn and oplist are EnergyOperators."""
self._oplist, self._comm = oplist, comm
self._domain = self._oplist[0].domain
self._target = self._oplist[0].target
self._nwork = nwork
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist]
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if opx[0].metric is None:
return x.new(val, jac)
met = AllreduceSumLinear([lin.metric for lin in opx], self._comm, self._nwork)
return x.new(val, jac, met)
class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None, nwork=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
self._domain = ift.makeDomain(oplist[0].domain)
self._target = ift.makeDomain(oplist[0].target)
cap = oplist[0]._capability
assert all(oo.domain == self._domain for oo in oplist)
assert all(oo.target == self._target for oo in oplist)
assert all(oo._capability == cap for oo in oplist)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._nwork = nwork
def apply(self, x, mode):
self._check_input(x, mode)
return ift.utilities.allreduce_sum(
[op.apply(x, mode) for op in self._oplist], self._comm
)
def draw_sample(self, from_inverse=False):
size, rank, _ = ift.utilities.get_MPI_params_from_comm(self._comm)
lo, _ = ift.utilities.shareRange(self._nwork, size, rank)
sseq = ift.random.spawn_sseq(self._nwork)
local_samples = []
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[lo + ii]):
local_samples.append(op.draw_sample(from_inverse))
return ift.utilities.allreduce_sum(local_samples, self._comm)
File moved
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
# Copyright(C) 2021 Max-Planck-Society # Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
from functools import reduce from functools import reduce
from operator import add from operator import add
...@@ -8,6 +9,7 @@ from types import GeneratorType ...@@ -8,6 +9,7 @@ from types import GeneratorType
import numpy as np import numpy as np
import nifty7 as ift import nifty7 as ift
import resolve as rve
def getop(comm): def getop(comm):
...@@ -38,73 +40,19 @@ def getop(comm): ...@@ -38,73 +40,19 @@ def getop(comm):
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii])) iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]))
ee = ift.GaussianEnergy(dd, iicc) ee = ift.GaussianEnergy(dd, iicc)
lst.append(ee) lst.append(ee)
op = AllreduceSum(lst, comm, nwork) op = rve.AllreduceSum(lst, comm, nwork)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
sky = ift.FieldAdapter(op.domain, "sky") sky = ift.FieldAdapter(op.domain, "sky")
return op @ sky.exp() return op @ sky.exp()
class AllreduceSum(ift.Operator):
def __init__(self, oplist, comm, nwork=None):
"""nwork only needed if samples need to be drawn and oplist are EnergyOperators."""
self._oplist, self._comm = oplist, comm
self._domain = self._oplist[0].domain
self._target = self._oplist[0].target
self._nwork = nwork
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist]
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if opx[0].metric is None:
return x.new(val, jac)
met = AllreduceSumLinear([lin.metric for lin in opx], self._comm, self._nwork)
return x.new(val, jac, met)
class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None, nwork=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
self._domain = ift.makeDomain(oplist[0].domain)
self._target = ift.makeDomain(oplist[0].target)
cap = oplist[0]._capability
assert all(oo.domain == self._domain for oo in oplist)
assert all(oo.target == self._target for oo in oplist)
assert all(oo._capability == cap for oo in oplist)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._nwork = nwork
def apply(self, x, mode):
self._check_input(x, mode)
return ift.utilities.allreduce_sum(
[op.apply(x, mode) for op in self._oplist], self._comm
)
def draw_sample(self, from_inverse=False):
size, rank, _ = ift.utilities.get_MPI_params_from_comm(self._comm)
lo, _ = ift.utilities.shareRange(self._nwork, size, rank)
sseq = ift.random.spawn_sseq(self._nwork)
local_samples = []
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[lo + ii]):
local_samples.append(op.draw_sample(from_inverse))
return ift.utilities.allreduce_sum(local_samples, self._comm)
def allclose(gen): def allclose(gen):
ref = next(gen) if isinstance(gen, GeneratorType) else gen[0] ref = next(gen) if isinstance(gen, GeneratorType) else gen[0]
for aa in gen: for aa in gen:
ift.extra.assert_allclose(ref, aa) ift.extra.assert_allclose(ref, aa)
def main(): def test_mpi_adder():
ddomain = ift.UnstructuredDomain(4), ift.UnstructuredDomain(1) ddomain = ift.UnstructuredDomain(4), ift.UnstructuredDomain(1)
comm, size, rank, master = ift.utilities.get_MPI_params() comm, size, rank, master = ift.utilities.get_MPI_params()
data = ift.from_random(ddomain) data = ift.from_random(ddomain)
...@@ -168,7 +116,3 @@ def main(): ...@@ -168,7 +116,3 @@ def main():
mini(ift.MetricGaussianKL.make(pos, ham, 3, True))[0].position mini(ift.MetricGaussianKL.make(pos, ham, 3, True))[0].position
) )
allclose(mini_results) allclose(mini_results)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment