Commit adb19b35 authored by Martin Reinecke's avatar Martin Reinecke

some cosmetic fixes and unrelated tweaks

parent ecdcd3a6
Pipeline #30835 canceled with stages
......@@ -211,6 +211,7 @@ class data_object(object):
def fill(self, value):
self._data.fill(value)
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
......
......@@ -730,6 +730,7 @@ class Field(object):
return False
return (self._val == other._val).all()
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
......
......@@ -38,7 +38,7 @@ class NonlinearWienerFilterEnergy(Energy):
self.N = N
self.S = S
self.inverter = inverter
if sampling_inverter==None:
if sampling_inverter is None:
sampling_inverter = inverter
self.sampling_inverter = sampling_inverter
t1 = S.inverse_times(position)
......@@ -64,4 +64,5 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, self.sampling_inverter)
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter,
self.sampling_inverter)
......@@ -45,9 +45,8 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None):
default: None
"""
M = SandwichOperator.make(R, N.inverse)
if sampling_inverter != None:
if sampling_inverter is not None:
op = SamplingEnabler(M, S.inverse, sampling_inverter)
else:
op = M + S.inverse
return InversionEnabler(op, inverter, S.inverse)
......@@ -20,7 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature
def WienerFilterEnergy(position, d, R, N, S, inverter=None, sampling_inverter=None):
def WienerFilterEnergy(position, d, R, N, S, inverter=None,
sampling_inverter=None):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
......
......@@ -30,4 +30,5 @@ def _logger_init():
res.addHandler(logging.NullHandler())
return res
logger = _logger_init()
......@@ -21,7 +21,8 @@ from ..utilities import NiftyMetaBase
class LineSearch(NiftyMetaBase()):
"""Class for determining the optimal step size along some descent direction.
"""Class for determining the optimal step size along some descent
direction.
Parameters
----------
......
......@@ -25,7 +25,8 @@ from ..logger import logger
class LineSearchStrongWolfe(LineSearch):
"""Class for finding a step size that satisfies the strong Wolfe conditions.
"""Class for finding a step size that satisfies the strong Wolfe
conditions.
Algorithm contains two stages. It begins with a trial step length and
keeps increasing it until it finds an acceptable step length or an
......
......@@ -160,6 +160,7 @@ class MultiField(object):
return False
return True
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
......
......@@ -243,6 +243,7 @@ def makeOp(input):
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
......
......@@ -166,6 +166,8 @@ def nthreads():
import os
nthreads._val = int(os.getenv("OMP_NUM_THREADS", "1"))
return nthreads._val
nthreads._val = None
# Optional extra arguments for the FFT calls
......
......@@ -155,9 +155,9 @@ class Test_Functionality(unittest.TestCase):
def test_empty_domain(self):
f = ift.Field((), 5)
assert_equal(f.to_global_data(), 5)
assert_equal(f.local_data, 5)
f = ift.Field(None, 5)
assert_equal(f.to_global_data(), 5)
assert_equal(f.local_data, 5)
assert_equal(f.empty_copy().domain, f.domain)
assert_equal(f.empty_copy().dtype, f.dtype)
assert_equal(f.copy().domain, f.domain)
......
......@@ -65,8 +65,8 @@ class Test_Minimizers(unittest.TestCase):
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(),
1./covariance_diagonal.to_global_data(),
assert_allclose(energy.position.local_data,
1./covariance_diagonal.local_data,
rtol=1e-3, atol=1e-3)
@expand(product(minimizers+newton_minimizers))
......@@ -129,7 +129,7 @@ class Test_Minimizers(unittest.TestCase):
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(), 1.,
assert_allclose(energy.position.local_data, 1.,
rtol=1e-3, atol=1e-3)
@expand(product(minimizers+slow_minimizers))
......@@ -167,7 +167,7 @@ class Test_Minimizers(unittest.TestCase):
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(), 0.,
assert_allclose(energy.position.local_data, 0.,
atol=1e-3)
@expand(product(minimizers+newton_minimizers+slow_minimizers))
......@@ -205,5 +205,4 @@ class Test_Minimizers(unittest.TestCase):
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(), 0.,
atol=1e-3)
assert_allclose(energy.position.local_data, 0., atol=1e-3)
......@@ -58,34 +58,34 @@ class ComposedOperator_Tests(unittest.TestCase):
rand1 = ift.Field.from_random('normal', domain=(space1, space2))
tt1 = op.inverse_times(op.times(rand1))
assert_allclose(tt1.to_global_data(), rand1.to_global_data())
assert_allclose(tt1.local_data, rand1.local_data)
@expand(product(spaces))
def test_sum(self, space):
op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
op2 = ift.ScalingOperator(3., space)
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
x = ift.Field.full(space, 1.)
res = full_op(x)
assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
assert_allclose(res.to_global_data(), 11.)
assert_allclose(res.local_data, 11.)
@expand(product(spaces))
def test_chain(self, space):
op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
op2 = ift.ScalingOperator(3., space)
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
x = ift.Field.full(space, 1.)
res = full_op(x)
assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
assert_allclose(res.to_global_data(), 432.)
assert_allclose(res.local_data, 432.)
@expand(product(spaces))
def test_mix(self, space):
op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
op2 = ift.ScalingOperator(3., space)
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2
x = ift.Field.full(space, 1.)
res = full_op(x)
assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
assert_allclose(res.to_global_data(), 42.)
assert_allclose(res.local_data, 42.)
......@@ -52,7 +52,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag)
tt1 = D.times(D.inverse_times(rand1))
assert_allclose(rand1.to_global_data(), tt1.to_global_data())
assert_allclose(rand1.local_data, tt1.local_data)
@expand(product(spaces))
def test_times(self, space):
......@@ -91,4 +91,4 @@ class DiagonalOperator_Tests(unittest.TestCase):
diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag)
diag_op = D(ift.Field.full(space, 1.))
assert_allclose(diag.to_global_data(), diag_op.to_global_data())
assert_allclose(diag.local_data, diag_op.local_data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment