From adb19b3561bf836034e68cfba91d934995b3a644 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 13 Jun 2018 19:33:41 +0200 Subject: [PATCH] some cosmetic fixes and unrelated tweaks --- nifty4/data_objects/distributed_do.py | 1 + nifty4/field.py | 1 + .../library/nonlinear_wiener_filter_energy.py | 5 +++-- nifty4/library/wiener_filter_curvature.py | 3 +-- nifty4/library/wiener_filter_energy.py | 3 ++- nifty4/logger.py | 1 + nifty4/minimization/line_search.py | 3 ++- .../minimization/line_search_strong_wolfe.py | 3 ++- nifty4/multi/multi_field.py | 1 + nifty4/sugar.py | 1 + nifty4/utilities.py | 2 ++ test/test_field.py | 4 ++-- test/test_minimization/test_minimizers.py | 11 +++++----- test/test_operators/test_composed_operator.py | 20 +++++++++---------- test/test_operators/test_diagonal_operator.py | 4 ++-- 15 files changed, 36 insertions(+), 27 deletions(-) diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 14d132d36..393e22a94 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -211,6 +211,7 @@ class data_object(object): def fill(self, value): self._data.fill(value) + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/field.py b/nifty4/field.py index 9ff37ff99..6438f31fa 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -730,6 +730,7 @@ class Field(object): return False return (self._val == other._val).all() + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/library/nonlinear_wiener_filter_energy.py b/nifty4/library/nonlinear_wiener_filter_energy.py index 93786f7df..c0ecf3ef5 100644 --- a/nifty4/library/nonlinear_wiener_filter_energy.py +++ b/nifty4/library/nonlinear_wiener_filter_energy.py @@ -38,7 +38,7 @@ class NonlinearWienerFilterEnergy(Energy): self.N = N self.S = S self.inverter = inverter - if sampling_inverter==None: + if sampling_inverter is None: sampling_inverter = inverter self.sampling_inverter = sampling_inverter t1 = S.inverse_times(position) @@ -64,4 +64,5 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def curvature(self): - return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, self.sampling_inverter) + return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, + self.sampling_inverter) diff --git a/nifty4/library/wiener_filter_curvature.py b/nifty4/library/wiener_filter_curvature.py index e7e2a725c..2b9ab4c3b 100644 --- a/nifty4/library/wiener_filter_curvature.py +++ b/nifty4/library/wiener_filter_curvature.py @@ -45,9 +45,8 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None): default: None """ M = SandwichOperator.make(R, N.inverse) - if sampling_inverter != None: + if sampling_inverter is not None: op = SamplingEnabler(M, S.inverse, sampling_inverter) else: op = M + S.inverse return InversionEnabler(op, inverter, S.inverse) - diff --git a/nifty4/library/wiener_filter_energy.py b/nifty4/library/wiener_filter_energy.py index 5f25953d9..276aa71e2 100644 --- a/nifty4/library/wiener_filter_energy.py +++ b/nifty4/library/wiener_filter_energy.py @@ -20,7 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy from .wiener_filter_curvature import WienerFilterCurvature -def WienerFilterEnergy(position, d, R, N, S, inverter=None, sampling_inverter=None): +def WienerFilterEnergy(position, d, R, N, S, inverter=None, + sampling_inverter=None): """The Energy for the Wiener filter. It covers the case of linear measurement with diff --git a/nifty4/logger.py b/nifty4/logger.py index d49dd42ee..258f81ea9 100644 --- a/nifty4/logger.py +++ b/nifty4/logger.py @@ -30,4 +30,5 @@ def _logger_init(): res.addHandler(logging.NullHandler()) return res + logger = _logger_init() diff --git a/nifty4/minimization/line_search.py b/nifty4/minimization/line_search.py index c2933416c..49871dc09 100644 --- a/nifty4/minimization/line_search.py +++ b/nifty4/minimization/line_search.py @@ -21,7 +21,8 @@ from ..utilities import NiftyMetaBase class LineSearch(NiftyMetaBase()): - """Class for determining the optimal step size along some descent direction. + """Class for determining the optimal step size along some descent + direction. Parameters ---------- diff --git a/nifty4/minimization/line_search_strong_wolfe.py b/nifty4/minimization/line_search_strong_wolfe.py index 284455d67..b51cb102e 100644 --- a/nifty4/minimization/line_search_strong_wolfe.py +++ b/nifty4/minimization/line_search_strong_wolfe.py @@ -25,7 +25,8 @@ from ..logger import logger class LineSearchStrongWolfe(LineSearch): - """Class for finding a step size that satisfies the strong Wolfe conditions. + """Class for finding a step size that satisfies the strong Wolfe + conditions. Algorithm contains two stages. It begins with a trial step length and keeps increasing it until it finds an acceptable step length or an diff --git a/nifty4/multi/multi_field.py b/nifty4/multi/multi_field.py index 9d52cc861..cce6c79a1 100644 --- a/nifty4/multi/multi_field.py +++ b/nifty4/multi/multi_field.py @@ -160,6 +160,7 @@ class MultiField(object): return False return True + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 3775baa92..4afa3a98d 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -243,6 +243,7 @@ def makeOp(input): # Arithmetic functions working on Fields + _current_module = sys.modules[__name__] for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: diff --git a/nifty4/utilities.py b/nifty4/utilities.py index b52975bed..8068c4ebe 100644 --- a/nifty4/utilities.py +++ b/nifty4/utilities.py @@ -166,6 +166,8 @@ def nthreads(): import os nthreads._val = int(os.getenv("OMP_NUM_THREADS", "1")) return nthreads._val + + nthreads._val = None # Optional extra arguments for the FFT calls diff --git a/test/test_field.py b/test/test_field.py index cab8f64c5..dc0d6adfa 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -155,9 +155,9 @@ class Test_Functionality(unittest.TestCase): def test_empty_domain(self): f = ift.Field((), 5) - assert_equal(f.to_global_data(), 5) + assert_equal(f.local_data, 5) f = ift.Field(None, 5) - assert_equal(f.to_global_data(), 5) + assert_equal(f.local_data, 5) assert_equal(f.empty_copy().domain, f.domain) assert_equal(f.empty_copy().dtype, f.dtype) assert_equal(f.copy().domain, f.domain) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 378048e1f..8fac6a339 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -65,8 +65,8 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), - 1./covariance_diagonal.to_global_data(), + assert_allclose(energy.position.local_data, + 1./covariance_diagonal.local_data, rtol=1e-3, atol=1e-3) @expand(product(minimizers+newton_minimizers)) @@ -129,7 +129,7 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 1., + assert_allclose(energy.position.local_data, 1., rtol=1e-3, atol=1e-3) @expand(product(minimizers+slow_minimizers)) @@ -167,7 +167,7 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 0., + assert_allclose(energy.position.local_data, 0., atol=1e-3) @expand(product(minimizers+newton_minimizers+slow_minimizers)) @@ -205,5 +205,4 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 0., - atol=1e-3) + assert_allclose(energy.position.local_data, 0., atol=1e-3) diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index 12b803ad2..8d4271737 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -58,34 +58,34 @@ class ComposedOperator_Tests(unittest.TestCase): rand1 = ift.Field.from_random('normal', domain=(space1, space2)) tt1 = op.inverse_times(op.times(rand1)) - assert_allclose(tt1.to_global_data(), rand1.to_global_data()) + assert_allclose(tt1.local_data, rand1.local_data) @expand(product(spaces)) def test_sum(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 11.) + assert_allclose(res.local_data, 11.) @expand(product(spaces)) def test_chain(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 432.) + assert_allclose(res.local_data, 432.) @expand(product(spaces)) def test_mix(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 42.) + assert_allclose(res.local_data, 42.) diff --git a/test/test_operators/test_diagonal_operator.py b/test/test_operators/test_diagonal_operator.py index 5569c7617..6e7210c36 100644 --- a/test/test_operators/test_diagonal_operator.py +++ b/test/test_operators/test_diagonal_operator.py @@ -52,7 +52,7 @@ class DiagonalOperator_Tests(unittest.TestCase): diag = ift.Field.from_random('normal', domain=space) D = ift.DiagonalOperator(diag) tt1 = D.times(D.inverse_times(rand1)) - assert_allclose(rand1.to_global_data(), tt1.to_global_data()) + assert_allclose(rand1.local_data, tt1.local_data) @expand(product(spaces)) def test_times(self, space): @@ -91,4 +91,4 @@ class DiagonalOperator_Tests(unittest.TestCase): diag = ift.Field.from_random('normal', domain=space) D = ift.DiagonalOperator(diag) diag_op = D(ift.Field.full(space, 1.)) - assert_allclose(diag.to_global_data(), diag_op.to_global_data()) + assert_allclose(diag.local_data, diag_op.local_data) -- GitLab