Commit 00691967 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge remote-tracking branch 'origin/NIFTy_5' into improving_consistency_checks

parents b82dc0dc 156c9d79
Pipeline #62109 passed with stages
in 7 minutes and 3 seconds
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
git_version.py git_version.py
# custom # custom
*.txt
setup.cfg setup.cfg
.idea .idea
.DS_Store .DS_Store
......
...@@ -10,11 +10,11 @@ RUN apt-get update && apt-get install -y \ ...@@ -10,11 +10,11 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies # Testing dependencies
python3-pytest-cov jupyter \ python3-pytest-cov jupyter \
# Optional NIFTy dependencies # Optional NIFTy dependencies
libfftw3-dev python3-mpi4py python3-matplotlib python3-pynfft \ python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies # more optional NIFTy dependencies
&& pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft.git \
&& pip3 install jupyter \ && pip3 install jupyter \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
......
...@@ -47,9 +47,9 @@ Installation ...@@ -47,9 +47,9 @@ Installation
- [Python 3](https://www.python.org/) (3.5.x or later) - [Python 3](https://www.python.org/) (3.5.x or later)
- [SciPy](https://www.scipy.org/) - [SciPy](https://www.scipy.org/)
- [pypocketfft](https://gitlab.mpcdf.mpg.de/mtr/pypocketfft)
Optional dependencies: Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic - [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere) transforms involving domains on the sphere)
- [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio - [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio
...@@ -73,35 +73,19 @@ NIFTy5 and its mandatory dependencies can be installed via: ...@@ -73,35 +73,19 @@ NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5
pip3 install --user git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft
Plotting support is added via: Plotting support is added via:
sudo apt-get install python3-matplotlib sudo apt-get install python3-matplotlib
NIFTy uses Numpy's FFT implementation by default. For large problems FFTW may be
used because of its higher performance. It can be installed via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via: Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
Support for the radio interferometry gridder is added via: Support for the radio interferometry gridder is added via:
pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git
MPI support is added via: MPI support is added via:
......
...@@ -5,13 +5,11 @@ import numpy as np ...@@ -5,13 +5,11 @@ import numpy as np
import nifty5 as ift import nifty5 as ift
ift.fft.enable_fftw()
np.random.seed(40) np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], [] N0s, a0s, b0s, c0s = [], [], [], []
N1s, a1s, b1s, c1s = [], [], [], []
for ii in range(10, 23): for ii in range(10, 26):
nu = 1024 nu = 1024
nv = 1024 nv = 1024
N = int(2**ii) N = int(2**ii)
...@@ -29,34 +27,20 @@ for ii in range(10, 23): ...@@ -29,34 +27,20 @@ for ii in range(10, 23):
img = ift.from_global_data(uvspace, img) img = ift.from_global_data(uvspace, img)
t0 = time() t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7) GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
idx = GM.getReordering(uv)
uv = uv[idx]
vis = vis[idx]
vis = ift.from_global_data(visspace, vis) vis = ift.from_global_data(visspace, vis)
op = GM.getFull(uv).adjoint op = GM.getFull().adjoint
t1 = time() t1 = time()
op(img).to_global_data() op(img).to_global_data()
t2 = time() t2 = time()
op.adjoint(vis).to_global_data() op.adjoint(vis).to_global_data()
t3 = time() t3 = time()
print(t2-t1, t3-t2)
N0s.append(N) N0s.append(N)
a0s.append(t1 - t0) a0s.append(t1 - t0)
b0s.append(t2 - t1) b0s.append(t2 - t1)
c0s.append(t3 - t2) c0s.append(t3 - t2)
t0 = time()
op = ift.NFFT(uvspace, uv)
t1 = time()
op(img).to_global_data()
t2 = time()
op.adjoint(vis).to_global_data()
t3 = time()
N1s.append(N)
a1s.append(t1 - t0)
b1s.append(t2 - t1)
c1s.append(t3 - t2)
print('Measure rest operator') print('Measure rest operator')
sc = ift.StatCalculator() sc = ift.StatCalculator()
op = GM.getRest().adjoint op = GM.getRest().adjoint
...@@ -68,10 +52,9 @@ t_fft = sc.mean ...@@ -68,10 +52,9 @@ t_fft = sc.mean
print('FFT shape', res.shape) print('FFT shape', res.shape)
plt.scatter(N0s, a0s, label='Gridder mr') plt.scatter(N0s, a0s, label='Gridder mr')
plt.scatter(N1s, a1s, marker='^', label='NFFT')
plt.legend() plt.legend()
# no idea why this is necessary, but if it is omitted, the range is wrong # no idea why this is necessary, but if it is omitted, the range is wrong
plt.ylim(min(a0s+a1s), max(a0s+a1s)) plt.ylim(min(a0s), max(a0s))
plt.ylabel('time [s]') plt.ylabel('time [s]')
plt.title('Initialization') plt.title('Initialization')
plt.loglog() plt.loglog()
...@@ -79,9 +62,7 @@ plt.savefig('bench0.png') ...@@ -79,9 +62,7 @@ plt.savefig('bench0.png')
plt.close() plt.close()
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times') plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N1s, b1s, color='r', marker='^', label='NFFT times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times') plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
plt.scatter(N1s, c1s, color='r', label='NFFT adjoint times')
plt.axhline(sc.mean, label='FFT') plt.axhline(sc.mean, label='FFT')
plt.axhline(sc.mean + np.sqrt(sc.var)) plt.axhline(sc.mean + np.sqrt(sc.var))
plt.axhline(sc.mean - np.sqrt(sc.var)) plt.axhline(sc.mean - np.sqrt(sc.var))
......
...@@ -103,13 +103,15 @@ if __name__ == '__main__': ...@@ -103,13 +103,15 @@ if __name__ == '__main__':
data = signal_response(mock_position) + N.draw_sample() data = signal_response(mock_position) + N.draw_sample()
# Minimization parameters # Minimization parameters
ic_sampling = ift.GradientNormController(iteration_limit=100) ic_sampling = ift.AbsDeltaEnergyController(
ic_newton = ift.GradInfNormController( name='Sampling', deltaE=0.05, iteration_limit=100)
name='Newton', tol=1e-7, iteration_limit=35) ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.5, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# Set up likelihood and information Hamiltonian # Set up likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response) likelihood = ift.GaussianEnergy(mean=data,
inverse_covariance=N.inverse)(signal_response)
H = ift.StandardHamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
......
...@@ -2,6 +2,12 @@ NIFTy-related publications ...@@ -2,6 +2,12 @@ NIFTy-related publications
========================== ==========================
:: ::
@article{asclnifty5,
title={NIFTy5: Numerical Information Field Theory v5},
author={Arras, Philipp and Baltac, Mihai and Ensslin, Torsten A and Frank, Philipp and Hutschenreuter, Sebastian and Knollmueller, Jakob and Leike, Reimar and Newrzella, Max-Niklas and Platz, Lukas and Reinecke, Martin and others},
journal={Astrophysics Source Code Library},
year={2019}
}
@software{nifty, @software{nifty,
author = {{Martin Reinecke, Theo Steininger, Marco Selig}}, author = {{Martin Reinecke, Theo Steininger, Marco Selig}},
...@@ -11,7 +17,7 @@ NIFTy-related publications ...@@ -11,7 +17,7 @@ NIFTy-related publications
date = {2018-04-05}, date = {2018-04-05},
} }
@ARTICLE{2013A&A...554A..26S, @article{2013A&A...554A..26S,
author = {{Selig}, M. and {Bell}, M.~R. and {Junklewitz}, H. and {Oppermann}, N. and {Reinecke}, M. and {Greiner}, M. and {Pachajoa}, C. and {En{\ss}lin}, T.~A.}, author = {{Selig}, M. and {Bell}, M.~R. and {Junklewitz}, H. and {Oppermann}, N. and {Reinecke}, M. and {Greiner}, M. and {Pachajoa}, C. and {En{\ss}lin}, T.~A.},
title = "{NIFTY - Numerical Information Field Theory. A versatile PYTHON library for signal inference}", title = "{NIFTY - Numerical Information Field Theory. A versatile PYTHON library for signal inference}",
journal = {\aap}, journal = {\aap},
...@@ -29,7 +35,7 @@ NIFTy-related publications ...@@ -29,7 +35,7 @@ NIFTy-related publications
adsnote = {Provided by the SAO/NASA Astrophysics Data System} adsnote = {Provided by the SAO/NASA Astrophysics Data System}
} }
@ARTICLE{2017arXiv170801073S, @article{2017arXiv170801073S,
author = {{Steininger}, T. and {Dixit}, J. and {Frank}, P. and {Greiner}, M. and {Hutschenreuter}, S. and {Knollm{\"u}ller}, J. and {Leike}, R. and {Porqueres}, N. and {Pumpe}, D. and {Reinecke}, M. and {{\v S}raml}, M. and {Varady}, C. and {En{\ss}lin}, T.}, author = {{Steininger}, T. and {Dixit}, J. and {Frank}, P. and {Greiner}, M. and {Hutschenreuter}, S. and {Knollm{\"u}ller}, J. and {Leike}, R. and {Porqueres}, N. and {Pumpe}, D. and {Reinecke}, M. and {{\v S}raml}, M. and {Varady}, C. and {En{\ss}lin}, T.},
title = "{NIFTy 3 - Numerical Information Field Theory - A Python framework for multicomponent signal inference on HPC clusters}", title = "{NIFTy 3 - Numerical Information Field Theory - A Python framework for multicomponent signal inference on HPC clusters}",
journal = {ArXiv e-prints}, journal = {ArXiv e-prints},
......
...@@ -9,35 +9,19 @@ NIFTy5 and its mandatory dependencies can be installed via:: ...@@ -9,35 +9,19 @@ NIFTy5 and its mandatory dependencies can be installed via::
sudo apt-get install git python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5
pip3 install --user git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft
Plotting support is added via:: Plotting support is added via::
sudo apt-get install python3-matplotlib sudo apt-get install python3-matplotlib
NIFTy uses Numpy's FFT implementation by default. For large problems FFTW may be
used because of its higher performance. It can be installed via::
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call::
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via:: Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
Support for the radio interferometry gridder is added via: Support for the radio interferometry gridder is added via::
pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git
MPI support is added via:: MPI support is added via::
......
...@@ -45,20 +45,22 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator ...@@ -45,20 +45,22 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import ( from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer, VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator) FieldAdapter, ducktape, GeometryRemover, NullOperator,
MatrixProductOperator, PartialExtractor)
from .operators.value_inserter import ValueInserter from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy) BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator)
from .operators.convolution_operators import FuncConvolutionOperator from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \ from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator, approximation2endo
from .minimization.line_search import LineSearch from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import ( from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController, IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController) GradInfNormController, AbsDeltaEnergyController)
from .minimization.minimizer import Minimizer from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG from .minimization.nonlinear_cg import NonlinearCG
...@@ -86,7 +88,6 @@ from .library.wiener_filter_curvature import WienerFilterCurvature ...@@ -86,7 +88,6 @@ from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField, MfCorrelatedField from .library.correlated_fields import CorrelatedField, MfCorrelatedField
from .library.adjust_variances import (make_adjust_variances_hamiltonian, from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances) do_adjust_variances)
from .library.nfft import NFFT
from .library.gridder import GridderMaker from .library.gridder import GridderMaker
from . import extra from . import extra
...@@ -97,6 +98,8 @@ from .logger import logger ...@@ -97,6 +98,8 @@ from .logger import logger
from .linearization import Linearization from .linearization import Linearization
from .operator_spectrum import operator_spectrum
from . import internal_config from . import internal_config
_scheme = internal_config.parallelization_scheme() _scheme = internal_config.parallelization_scheme()
if _scheme == "Samples": if _scheme == "Samples":
......
...@@ -103,7 +103,9 @@ class GLSpace(StructuredDomain): ...@@ -103,7 +103,9 @@ class GLSpace(StructuredDomain):
The partner domain The partner domain
""" """
from ..domains.lm_space import LMSpace from ..domains.lm_space import LMSpace
return LMSpace(lmax=self._nlat-1, mmax=self._nlon//2) mmax = self._nlon//2
lmax = max(mmax, self._nlat-1)
return LMSpace(lmax=lmax, mmax=mmax)
def check_codomain(self, codomain): def check_codomain(self, codomain):
"""Raises `TypeError` if `codomain` is not a matching partner domain """Raises `TypeError` if `codomain` is not a matching partner domain
......
...@@ -50,6 +50,8 @@ class LogRGSpace(StructuredDomain): ...@@ -50,6 +50,8 @@ class LogRGSpace(StructuredDomain):
self._bindistances = tuple(bindistances) self._bindistances = tuple(bindistances)
self._t_0 = tuple(t_0) self._t_0 = tuple(t_0)
if min(self._bindistances) <= 0:
raise ValueError('Non-positive bindistances encountered')
self._dim = int(reduce(lambda x, y: x*y, self._shape)) self._dim = int(reduce(lambda x, y: x*y, self._shape))
self._dvol = float(reduce(lambda x, y: x*y, self._bindistances)) self._dvol = float(reduce(lambda x, y: x*y, self._bindistances))
...@@ -80,8 +82,8 @@ class LogRGSpace(StructuredDomain): ...@@ -80,8 +82,8 @@ class LogRGSpace(StructuredDomain):
return np.array(self._t_0) return np.array(self._t_0)
def __repr__(self): def __repr__(self):
return ("LogRGSpace(shape={}, harmonic={})".format( return ("LogRGSpace(shape={}, bindistances={}, t_0={}, harmonic={})".format(
self.shape, self.harmonic)) self.shape, self.bindistances, self.t_0, self.harmonic))
def get_default_codomain(self): def get_default_codomain(self):
"""Returns a :class:`LogRGSpace` object representing the (position or """Returns a :class:`LogRGSpace` object representing the (position or
...@@ -91,10 +93,10 @@ class LogRGSpace(StructuredDomain): ...@@ -91,10 +93,10 @@ class LogRGSpace(StructuredDomain):
Returns Returns
------- -------
LogRGSpace LogRGSpace
The parter domain The partner domain
""" """
codomain_bindistances = 1./(self.bindistances*self.shape) codomain_bindistances = 1./(self.bindistances*self.shape)
return LogRGSpace(self.shape, codomain_bindistances, self._t_0, True) return LogRGSpace(self.shape, codomain_bindistances, self._t_0, not self.harmonic)
def get_k_length_array(self): def get_k_length_array(self):
"""Generates array of distances to origin of the space. """Generates array of distances to origin of the space.
......
...@@ -165,6 +165,8 @@ class PowerSpace(StructuredDomain): ...@@ -165,6 +165,8 @@ class PowerSpace(StructuredDomain):
if binbounds is not None: if binbounds is not None:
binbounds = tuple(binbounds) binbounds = tuple(binbounds)
if min(binbounds) < 0:
raise ValueError('Negative binbounds encountered')
key = (harmonic_partner, binbounds) key = (harmonic_partner, binbounds)
if self._powerIndexCache.get(key) is None: if self._powerIndexCache.get(key) is None:
......
...@@ -54,6 +54,8 @@ class RGSpace(StructuredDomain): ...@@ -54,6 +54,8 @@ class RGSpace(StructuredDomain):
if np.isscalar(shape): if np.isscalar(shape):
shape = (shape,) shape = (shape,)
self._shape = tuple(int(i) for i in shape) self._shape = tuple(int(i) for i in shape)
if min(self._shape) < 0:
raise ValueError('Negative number of pixels encountered')
if distances is None: if distances is None:
if self.harmonic: if self.harmonic:
...@@ -66,6 +68,8 @@ class RGSpace(StructuredDomain): ...@@ -66,6 +68,8 @@ class RGSpace(StructuredDomain):
temp = np.empty(len(self.shape), dtype=np.float64) temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances temp[:] = distances
self._distances = tuple(temp) self._distances = tuple(temp)
if min(self._distances) <= 0:
raise ValueError('Non-positive distances encountered')
self._dvol = float(reduce(lambda x, y: x*y, self._distances)) self._dvol = float(reduce(lambda x, y: x*y, self._distances))
self._size = int(reduce(lambda x, y: x*y, self._shape)) self._size = int(reduce(lambda x, y: x*y, self._shape))
...@@ -177,7 +181,7 @@ class RGSpace(StructuredDomain): ...@@ -177,7 +181,7 @@ class RGSpace(StructuredDomain):
Returns Returns
------- -------
RGSpace RGSpace
The parter domain The partner domain
""" """
distances = 1. / (np.array(self.shape)*np.array(self.distances)) distances = 1. / (np.array(self.shape)*np.array(self.distances))
return RGSpace(self.shape, distances, not self.harmonic) return RGSpace(self.shape, distances, not self.harmonic)
......
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
import numpy as np import numpy as np
from .domain_tuple import DomainTuple
from .field import Field from .field import Field
from .linearization import Linearization from .linearization import Linearization
from .multi_domain import MultiDomain
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .sugar import from_random from .sugar import from_random
...@@ -68,14 +70,25 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol, ...@@ -68,14 +70,25 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
def _check_linearity(op, domain_dtype, atol, rtol): def _check_linearity(op, domain_dtype, atol, rtol):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
return
fld1 = from_random("normal", op.domain, dtype=domain_dtype) fld1 = from_random("normal", op.domain, dtype=domain_dtype)
fld2 = from_random("normal", op.domain, dtype=domain_dtype) fld2 = from_random("normal", op.domain, dtype=domain_dtype)
alpha = np.random.random() alpha = np.random.random() # FIXME: this can break badly with MPI!
val1 = op(alpha*fld1+fld2) val1 = op(alpha*fld1+fld2)
val2 = alpha*op(fld1)+op(fld2) val2 = alpha*op(fld1)+op(fld2)
_assert_allclose(val1, val2, atol=atol, rtol=rtol) _assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _domain_check(op):
for dd in [op.domain, op.target]:
if not isinstance(dd, (DomainTuple, MultiDomain)):
raise TypeError(
'The domain and the target of an operator need to',
'be instances of either DomainTuple or MultiDomain.')
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7, only_r_linear=False): atol=0, rtol=1e-7, only_r_linear=False):
""" """
...@@ -109,7 +122,11 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -109,7 +122,11 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
""" """
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.') raise TypeError('This test tests only linear operators.')
_domain_check(op)
_check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
_check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol, _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear) only_r_linear)
_full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol, _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
...@@ -180,6 +197,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100): ...@@ -180,6 +197,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
tol : float tol : float
Tolerance for the check. Tolerance for the check.
""" """
_domain_check(op)
for _ in range(ntries): for _ in range(ntries):
lin = op(Linearization.make_var(loc)) lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin) loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
...@@ -15,152 +15,28 @@ ...@@ -15,152 +15,28 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.