diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 14d132d36ef0df546dfb09a2f9d8fbef28377ddf..393e22a94cf68318cf7749a9558c8f514856462e 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -211,6 +211,7 @@ class data_object(object): def fill(self, value): self._data.fill(value) + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/field.py b/nifty4/field.py index 9ff37ff9997f79d6bad20c41035a8a840afdc8d8..6438f31fa675381ca842b0602152336a54b2b155 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -730,6 +730,7 @@ class Field(object): return False return (self._val == other._val).all() + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/library/nonlinear_wiener_filter_energy.py b/nifty4/library/nonlinear_wiener_filter_energy.py index 93786f7df4de32ec24d85225f25c8a01e7388dff..c0ecf3ef5288e43ba62a1f58d4a530a117718b75 100644 --- a/nifty4/library/nonlinear_wiener_filter_energy.py +++ b/nifty4/library/nonlinear_wiener_filter_energy.py @@ -38,7 +38,7 @@ class NonlinearWienerFilterEnergy(Energy): self.N = N self.S = S self.inverter = inverter - if sampling_inverter==None: + if sampling_inverter is None: sampling_inverter = inverter self.sampling_inverter = sampling_inverter t1 = S.inverse_times(position) @@ -64,4 +64,5 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def curvature(self): - return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, self.sampling_inverter) + return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, + self.sampling_inverter) diff --git a/nifty4/library/wiener_filter_curvature.py b/nifty4/library/wiener_filter_curvature.py index e7e2a725cf6e6ec9f6a6c8b570ad19721542e2eb..2b9ab4c3bcab183ebe3bac13b32d1e621e14b533 100644 --- a/nifty4/library/wiener_filter_curvature.py +++ b/nifty4/library/wiener_filter_curvature.py @@ -45,9 +45,8 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None): default: None """ M = SandwichOperator.make(R, N.inverse) - if sampling_inverter != None: + if sampling_inverter is not None: op = SamplingEnabler(M, S.inverse, sampling_inverter) else: op = M + S.inverse return InversionEnabler(op, inverter, S.inverse) - diff --git a/nifty4/library/wiener_filter_energy.py b/nifty4/library/wiener_filter_energy.py index 5f25953d92605ba6bb4011fbf18a37555a61d449..276aa71e22953e908045491cf15fce2f2ec71137 100644 --- a/nifty4/library/wiener_filter_energy.py +++ b/nifty4/library/wiener_filter_energy.py @@ -20,7 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy from .wiener_filter_curvature import WienerFilterCurvature -def WienerFilterEnergy(position, d, R, N, S, inverter=None, sampling_inverter=None): +def WienerFilterEnergy(position, d, R, N, S, inverter=None, + sampling_inverter=None): """The Energy for the Wiener filter. It covers the case of linear measurement with diff --git a/nifty4/logger.py b/nifty4/logger.py index d49dd42ee6a510622691a5db582b56d7189c3d01..258f81ea97c1b5e466279af4276c2eb6fea9a201 100644 --- a/nifty4/logger.py +++ b/nifty4/logger.py @@ -30,4 +30,5 @@ def _logger_init(): res.addHandler(logging.NullHandler()) return res + logger = _logger_init() diff --git a/nifty4/minimization/line_search.py b/nifty4/minimization/line_search.py index c2933416c94e39f075e8e4ed2a08768c80d3b661..49871dc0974ca09eff0db7e3358568c8272257ed 100644 --- a/nifty4/minimization/line_search.py +++ b/nifty4/minimization/line_search.py @@ -21,7 +21,8 @@ from ..utilities import NiftyMetaBase class LineSearch(NiftyMetaBase()): - """Class for determining the optimal step size along some descent direction. + """Class for determining the optimal step size along some descent + direction. Parameters ---------- diff --git a/nifty4/minimization/line_search_strong_wolfe.py b/nifty4/minimization/line_search_strong_wolfe.py index 284455d67d910136fe6e58ed62ea33e3d3834dda..b51cb102eea38a677678df43247d3507960895e6 100644 --- a/nifty4/minimization/line_search_strong_wolfe.py +++ b/nifty4/minimization/line_search_strong_wolfe.py @@ -25,7 +25,8 @@ from ..logger import logger class LineSearchStrongWolfe(LineSearch): - """Class for finding a step size that satisfies the strong Wolfe conditions. + """Class for finding a step size that satisfies the strong Wolfe + conditions. Algorithm contains two stages. It begins with a trial step length and keeps increasing it until it finds an acceptable step length or an diff --git a/nifty4/multi/multi_field.py b/nifty4/multi/multi_field.py index 9d52cc8612314f2383ae561a7fb0f10461152b8e..cce6c79a1171986071caee5a1ef88c969a45d49d 100644 --- a/nifty4/multi/multi_field.py +++ b/nifty4/multi/multi_field.py @@ -160,6 +160,7 @@ class MultiField(object): return False return True + for op in ["__add__", "__radd__", "__iadd__", "__sub__", "__rsub__", "__isub__", "__mul__", "__rmul__", "__imul__", diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 3775baa922819786455667c09e05cafd253f4293..4afa3a98dcc0c0b21530dd2efd3b4105a16cd856 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -243,6 +243,7 @@ def makeOp(input): # Arithmetic functions working on Fields + _current_module = sys.modules[__name__] for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: diff --git a/nifty4/utilities.py b/nifty4/utilities.py index b52975bed5d56fadf724d058087c4524fbcf0a36..8068c4ebef9e024068eeab1642552db531d24171 100644 --- a/nifty4/utilities.py +++ b/nifty4/utilities.py @@ -166,6 +166,8 @@ def nthreads(): import os nthreads._val = int(os.getenv("OMP_NUM_THREADS", "1")) return nthreads._val + + nthreads._val = None # Optional extra arguments for the FFT calls diff --git a/test/test_field.py b/test/test_field.py index cab8f64c51529d6a682f749632571e6293b4b443..dc0d6adfa89bec136974050cf4a0d11849106f97 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -155,9 +155,9 @@ class Test_Functionality(unittest.TestCase): def test_empty_domain(self): f = ift.Field((), 5) - assert_equal(f.to_global_data(), 5) + assert_equal(f.local_data, 5) f = ift.Field(None, 5) - assert_equal(f.to_global_data(), 5) + assert_equal(f.local_data, 5) assert_equal(f.empty_copy().domain, f.domain) assert_equal(f.empty_copy().dtype, f.dtype) assert_equal(f.copy().domain, f.domain) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 378048e1f50109ec605d2b88294852f36ab6687e..8fac6a339682b2258f6e1e2af6b807f2f8e85645 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -65,8 +65,8 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), - 1./covariance_diagonal.to_global_data(), + assert_allclose(energy.position.local_data, + 1./covariance_diagonal.local_data, rtol=1e-3, atol=1e-3) @expand(product(minimizers+newton_minimizers)) @@ -129,7 +129,7 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 1., + assert_allclose(energy.position.local_data, 1., rtol=1e-3, atol=1e-3) @expand(product(minimizers+slow_minimizers)) @@ -167,7 +167,7 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 0., + assert_allclose(energy.position.local_data, 0., atol=1e-3) @expand(product(minimizers+newton_minimizers+slow_minimizers)) @@ -205,5 +205,4 @@ class Test_Minimizers(unittest.TestCase): raise SkipTest assert_equal(convergence, IC.CONVERGED) - assert_allclose(energy.position.to_global_data(), 0., - atol=1e-3) + assert_allclose(energy.position.local_data, 0., atol=1e-3) diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index 12b803ad22cc7d1004b0c88d056e2843c29d2d13..8d4271737a0076383972660d716f2df30b130f2d 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -58,34 +58,34 @@ class ComposedOperator_Tests(unittest.TestCase): rand1 = ift.Field.from_random('normal', domain=(space1, space2)) tt1 = op.inverse_times(op.times(rand1)) - assert_allclose(tt1.to_global_data(), rand1.to_global_data()) + assert_allclose(tt1.local_data, rand1.local_data) @expand(product(spaces)) def test_sum(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 11.) + assert_allclose(res.local_data, 11.) @expand(product(spaces)) def test_chain(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 432.) + assert_allclose(res.local_data, 432.) @expand(product(spaces)) def test_mix(self, space): - op1 = ift.DiagonalOperator(ift.Field.full(space, 2.)) - op2 = ift.ScalingOperator(3., space) + op1 = ift.makeOp(ift.Field.full(space, 2.)) + op2 = 3. full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2 x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) - assert_allclose(res.to_global_data(), 42.) + assert_allclose(res.local_data, 42.) diff --git a/test/test_operators/test_diagonal_operator.py b/test/test_operators/test_diagonal_operator.py index 5569c761702383def26f8e37ba06020704c9d3de..6e7210c369d9200480238b2d02dac6f7e8ca852e 100644 --- a/test/test_operators/test_diagonal_operator.py +++ b/test/test_operators/test_diagonal_operator.py @@ -52,7 +52,7 @@ class DiagonalOperator_Tests(unittest.TestCase): diag = ift.Field.from_random('normal', domain=space) D = ift.DiagonalOperator(diag) tt1 = D.times(D.inverse_times(rand1)) - assert_allclose(rand1.to_global_data(), tt1.to_global_data()) + assert_allclose(rand1.local_data, tt1.local_data) @expand(product(spaces)) def test_times(self, space): @@ -91,4 +91,4 @@ class DiagonalOperator_Tests(unittest.TestCase): diag = ift.Field.from_random('normal', domain=space) D = ift.DiagonalOperator(diag) diag_op = D(ift.Field.full(space, 1.)) - assert_allclose(diag.to_global_data(), diag_op.to_global_data()) + assert_allclose(diag.local_data, diag_op.local_data)