From a7bc5e41013a61c872f1ef5728ee46a1e90faa60 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sat, 19 May 2018 21:13:15 +0200 Subject: [PATCH] step 1 --- demos/critical_filtering.py | 4 +- demos/krylov_sampling.py | 4 +- demos/nonlinear_critical_filter.py | 4 +- demos/nonlinear_wiener_filter.py | 4 +- demos/poisson_demo.py | 2 +- demos/wiener_filter_via_hamiltonian.py | 2 +- nifty4/domain_tuple.py | 6 +- nifty4/domains/domain.py | 12 ++-- nifty4/extra/operator_tests.py | 10 +-- nifty4/field.py | 52 ---------------- nifty4/library/nonlinearities.py | 7 ++- nifty4/multi/multi_domain.py | 75 ++++++++++++++++++++++- nifty4/multi/multi_field.py | 15 ++--- nifty4/operators/linear_operator.py | 4 -- nifty4/operators/scaling_operator.py | 2 +- nifty4/sugar.py | 46 ++++++++++++-- test/test_energies/test_power.py | 2 +- test/test_field.py | 8 +-- test/test_minimization/test_minimizers.py | 2 +- 19 files changed, 157 insertions(+), 104 deletions(-) diff --git a/demos/critical_filtering.py b/demos/critical_filtering.py index 0011e23f0..d18f94af7 100644 --- a/demos/critical_filtering.py +++ b/demos/critical_filtering.py @@ -69,8 +69,8 @@ if __name__ == "__main__": # Creating the mock data d = noiseless_data + n - m0 = ift.Field.full(h_space, 1e-7) - t0 = ift.Field.full(p_space, -4.) + m0 = ift.full(h_space, 1e-7) + t0 = ift.full(p_space, -4.) power0 = Distributor.times(ift.exp(0.5 * t0)) plotdict = {"colormap": "Planck-like"} diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py index 7b316e4db..f3275f98f 100644 --- a/demos/krylov_sampling.py +++ b/demos/krylov_sampling.py @@ -67,8 +67,8 @@ plt.legend() plt.savefig('Krylov_samples_residuals.png') plt.close() -D_hat_old = ift.Field.zeros(x_space).to_global_data() -D_hat_new = ift.Field.zeros(x_space).to_global_data() +D_hat_old = ift.full(x_space, 0.).to_global_data() +D_hat_new = ift.full(x_space, 0.).to_global_data() for i in range(N_samps): D_hat_old += sky(samps_old[i]).to_global_data()**2 D_hat_new += sky(samps[i]).to_global_data()**2 diff --git a/demos/nonlinear_critical_filter.py b/demos/nonlinear_critical_filter.py index f1b84ab5c..8477acd9e 100644 --- a/demos/nonlinear_critical_filter.py +++ b/demos/nonlinear_critical_filter.py @@ -69,8 +69,8 @@ if __name__ == "__main__": # Creating the mock data d = noiseless_data + n - m0 = ift.Field.full(h_space, 1e-7) - t0 = ift.Field.full(p_space, -4.) + m0 = ift.full(h_space, 1e-7) + t0 = ift.full(p_space, -4.) power0 = Distributor.times(ift.exp(0.5 * t0)) IC1 = ift.GradientNormController(name="IC1", iteration_limit=100, diff --git a/demos/nonlinear_wiener_filter.py b/demos/nonlinear_wiener_filter.py index be5877157..d19e739ac 100644 --- a/demos/nonlinear_wiener_filter.py +++ b/demos/nonlinear_wiener_filter.py @@ -36,7 +36,7 @@ if __name__ == "__main__": d_space = R.target p_op = ift.create_power_operator(h_space, p_spec) - power = ift.sqrt(p_op(ift.Field.full(h_space, 1.))) + power = ift.sqrt(p_op(ift.full(h_space, 1.))) # Creating the mock data true_sky = nonlinearity(HT(power*sh)) @@ -57,7 +57,7 @@ if __name__ == "__main__": inverter = ift.ConjugateGradient(controller=ICI) # initial guess - m = ift.Field.full(h_space, 1e-7) + m = ift.full(h_space, 1e-7) map_energy = ift.library.NonlinearWienerFilterEnergy( m, d, R, nonlinearity, HT, power, N, S, inverter=inverter) diff --git a/demos/poisson_demo.py b/demos/poisson_demo.py index 6c62fa645..a065f589f 100644 --- a/demos/poisson_demo.py +++ b/demos/poisson_demo.py @@ -113,7 +113,7 @@ if __name__ == "__main__": d_domain, np.random.poisson(lam.local_data).astype(np.float64)) # initial guess - psi0 = ift.Field.full(h_domain, 1e-7) + psi0 = ift.full(h_domain, 1e-7) energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, inverter) IC1 = ift.GradientNormController(name="IC1", iteration_limit=200, diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py index 124a625a2..e83de5c74 100644 --- a/demos/wiener_filter_via_hamiltonian.py +++ b/demos/wiener_filter_via_hamiltonian.py @@ -50,7 +50,7 @@ if __name__ == "__main__": inverter = ift.ConjugateGradient(controller=ctrl) controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1) minimizer = ift.RelaxedNewton(controller=controller) - m0 = ift.Field.zeros(h_space) + m0 = ift.full(h_space, 0.) # Initialize Wiener filter energy energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S, diff --git a/nifty4/domain_tuple.py b/nifty4/domain_tuple.py index 21779f547..39eafa885 100644 --- a/nifty4/domain_tuple.py +++ b/nifty4/domain_tuple.py @@ -34,7 +34,9 @@ class DomainTuple(object): """ _tupleCache = {} - def __init__(self, domain): + def __init__(self, domain, _callingfrommake=False): + if not _callingfrommake: + raise NotImplementedError self._dom = self._parse_domain(domain) self._axtuple = self._get_axes_tuple() shape_tuple = tuple(sp.shape for sp in self._dom) @@ -72,7 +74,7 @@ class DomainTuple(object): obj = DomainTuple._tupleCache.get(domain) if obj is not None: return obj - obj = DomainTuple(domain) + obj = DomainTuple(domain, _callingfrommake=True) DomainTuple._tupleCache[domain] = obj return obj diff --git a/nifty4/domains/domain.py b/nifty4/domains/domain.py index 9f22a5baa..38db023bd 100644 --- a/nifty4/domains/domain.py +++ b/nifty4/domains/domain.py @@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase class Domain(NiftyMetaBase()): """The abstract class repesenting a (structured or unstructured) domain. """ + def __init__(self): + self._hash = None @abc.abstractmethod def __repr__(self): @@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()): Only members that are explicitly added to :attr:`._needed_for_hash` will be used for hashing. """ - result_hash = 0 - for key in self._needed_for_hash: - result_hash ^= hash(vars(self)[key]) - return result_hash + if self._hash is None: + h = 0 + for key in self._needed_for_hash: + h ^= hash(vars(self)[key]) + self._hash = h + return self._hash def __eq__(self, x): """Checks whether two domains are equal. diff --git a/nifty4/extra/operator_tests.py b/nifty4/extra/operator_tests.py index 258c2a877..15d2b9109 100644 --- a/nifty4/extra/operator_tests.py +++ b/nifty4/extra/operator_tests.py @@ -17,7 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. import numpy as np -from ..field import Field +from ..sugar import from_random __all__ = ["consistency_check"] @@ -26,8 +26,8 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): needed_cap = op.TIMES | op.ADJOINT_TIMES if (op.capability & needed_cap) != needed_cap: return - f1 = Field.from_random("normal", op.domain, dtype=domain_dtype).lock() - f2 = Field.from_random("normal", op.target, dtype=target_dtype).lock() + f1 = from_random("normal", op.domain, dtype=domain_dtype).lock() + f2 = from_random("normal", op.target, dtype=target_dtype).lock() res1 = f1.vdot(op.adjoint_times(f2).lock()) res2 = op.times(f1).vdot(f2) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) @@ -37,12 +37,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): needed_cap = op.TIMES | op.INVERSE_TIMES if (op.capability & needed_cap) != needed_cap: return - foo = Field.from_random("normal", op.target, dtype=target_dtype).lock() + foo = from_random("normal", op.target, dtype=target_dtype).lock() res = op(op.inverse_times(foo).lock()) np.testing.assert_allclose(res.to_global_data(), res.to_global_data(), atol=atol, rtol=rtol) - foo = Field.from_random("normal", op.domain, dtype=domain_dtype).lock() + foo = from_random("normal", op.domain, dtype=domain_dtype).lock() res = op.inverse_times(op(foo).lock()) np.testing.assert_allclose(res.to_global_data(), foo.to_global_data(), atol=atol, rtol=rtol) diff --git a/nifty4/field.py b/nifty4/field.py index 289faf237..c30e01a13 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -106,62 +106,10 @@ class Field(object): raise TypeError("val must be a scalar") return Field(DomainTuple.make(domain), val, dtype) - @staticmethod - def ones(domain, dtype=None): - return Field(DomainTuple.make(domain), 1., dtype) - - @staticmethod - def zeros(domain, dtype=None): - return Field(DomainTuple.make(domain), 0., dtype) - @staticmethod def empty(domain, dtype=None): return Field(DomainTuple.make(domain), None, dtype) - @staticmethod - def full_like(field, val, dtype=None): - """Creates a Field from a template, filled with a constant value. - - Parameters - ---------- - field : Field - the template field, from which the domain is inferred - val : float/complex/int scalar - fill value. Data type of the field is inferred from val. - - Returns - ------- - Field - the newly created field - """ - if not isinstance(field, Field): - raise TypeError("field must be of Field type") - return Field.full(field._domain, val, dtype) - - @staticmethod - def zeros_like(field, dtype=None): - if not isinstance(field, Field): - raise TypeError("field must be of Field type") - if dtype is None: - dtype = field.dtype - return Field.zeros(field._domain, dtype) - - @staticmethod - def ones_like(field, dtype=None): - if not isinstance(field, Field): - raise TypeError("field must be of Field type") - if dtype is None: - dtype = field.dtype - return Field.ones(field._domain, dtype) - - @staticmethod - def empty_like(field, dtype=None): - if not isinstance(field, Field): - raise TypeError("field must be of Field type") - if dtype is None: - dtype = field.dtype - return Field.empty(field._domain, dtype) - @staticmethod def from_global_data(domain, arr, sum_up=False): """Returns a Field constructed from `domain` and `arr`. diff --git a/nifty4/library/nonlinearities.py b/nifty4/library/nonlinearities.py index 810054d0e..ab5a707a7 100644 --- a/nifty4/library/nonlinearities.py +++ b/nifty4/library/nonlinearities.py @@ -16,7 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..field import Field, exp, tanh +from ..field import exp, tanh +from ..sugar import full class Linear(object): @@ -24,10 +25,10 @@ class Linear(object): return x def derivative(self, x): - return Field.ones_like(x) + return full(x.domain, 1.) def hessian(self, x): - return Field.zeros_like(x) + return full(x.domain, 0.) class Exponential(object): diff --git a/nifty4/multi/multi_domain.py b/nifty4/multi/multi_domain.py index 27c5f841d..c77b642af 100644 --- a/nifty4/multi/multi_domain.py +++ b/nifty4/multi/multi_domain.py @@ -1,2 +1,73 @@ -class MultiDomain(dict): - pass +import collections +from ..domain_tuple import DomainTuple + +__all = ["MultiDomain"] + + +class frozendict(collections.Mapping): + """ + An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping` + interface. It can be used as a drop-in replacement for dictionaries where immutability is desired. + """ + + dict_cls = dict + + def __init__(self, *args, **kwargs): + self._dict = self.dict_cls(*args, **kwargs) + self._hash = None + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def copy(self, **add_or_replace): + return self.__class__(self, **add_or_replace) + + def __iter__(self): + return iter(self._dict) + + def __len__(self): + return len(self._dict) + + def __repr__(self): + return '<%s %r>' % (self.__class__.__name__, self._dict) + + def __hash__(self): + if self._hash is None: + h = 0 + for key, value in self._dict.items(): + h ^= hash((key, value)) + self._hash = h + return self._hash + + +class MultiDomain(frozendict): + _domainCache = {} + + def __init__(domain, _callingfrommake=False): + if not _callingfrommake: + raise NotImplementedError + super(MultiDomain, self).__init__(domain) + + @staticmethod + def make(domain): + if isinstance(domain, MultiDomain): + return domain + print type(domain) + if not isinstance(domain, dict): + raise TypeError("dict expected") + tmp = {} + for key, value in domain.items(): + if not isinstance(key, str): + raise TypeError("keys must be strings") + tmp[key] = DomainTuple.make(value) + domain = frozendict(tmp) + print tmp + obj = MultiDomain._domainCache.get(domain) + if obj is not None: + return obj + obj = MultiDomain(domain, _callingfrommake=True) + MultiDomain._domainCache[domain] = obj + return obj diff --git a/nifty4/multi/multi_field.py b/nifty4/multi/multi_field.py index 7bcf821bf..454e5c424 100644 --- a/nifty4/multi/multi_field.py +++ b/nifty4/multi/multi_field.py @@ -85,21 +85,14 @@ class MultiField(object): return {key: dtype for key in domain.keys()} @staticmethod - def zeros(domain, dtype=None): - dtype = MultiField.build_dtype(dtype, domain) - return MultiField({key: Field.zeros(dom, dtype=dtype[key]) - for key, dom in domain.items()}) - - @staticmethod - def ones(domain, dtype=None): + def empty(domain, dtype=None): dtype = MultiField.build_dtype(dtype, domain) - return MultiField({key: Field.ones(dom, dtype=dtype[key]) + return MultiField({key: Field.empty(dom, dtype=dtype[key]) for key, dom in domain.items()}) @staticmethod - def empty(domain, dtype=None): - dtype = MultiField.build_dtype(dtype, domain) - return MultiField({key: Field.empty(dom, dtype=dtype[key]) + def full(domain, val): + return MultiField({key: Field.full(dom, val) for key, dom in domain.items()}) def norm(self): diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index c7d5f2296..185c5049b 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()): raise ValueError("requested operator mode is not supported") def _check_input(self, x, mode): - # MR FIXME: temporary fix for working with MultiFields - #if not isinstance(x, Field): - # raise ValueError("supplied object is not a `Field`.") - self._check_mode(mode) if x.domain != self._dom(mode): raise ValueError("The operator's and field's domains don't match.") diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py index b9a26bfaf..20066cb7d 100644 --- a/nifty4/operators/scaling_operator.py +++ b/nifty4/operators/scaling_operator.py @@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator): if self._factor == 1.: return x.copy() if self._factor == 0.: - return x.zeros_like(x) + return x*0. if mode == self.TIMES: return x*self._factor diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 3e829e3ed..3e90bfd2b 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -19,16 +19,18 @@ import numpy as np from .domains.power_space import PowerSpace from .field import Field +from multi.multi_field import MultiField +from multi.multi_domain import MultiDomain from .operators.diagonal_operator import DiagonalOperator from .operators.power_distributor import PowerDistributor from .domain_tuple import DomainTuple from . import dobj, utilities from .logger import logger -__all__ = ['PS_field', - 'power_analyze', - 'create_power_operator', - 'create_harmonic_smoothing_operator'] +__all__ = ['PS_field', 'power_analyze', 'create_power_operator', + 'create_harmonic_smoothing_operator', 'from_random', + 'full', 'empty', 'from_global_data', 'from_local_data', + 'makeDomain'] def PS_field(pspace, func): @@ -161,3 +163,39 @@ def create_harmonic_smoothing_operator(domain, space, sigma): kfunc = domain[space].get_fft_smoothing_kernel_function(sigma) return DiagonalOperator(kfunc(domain[space].get_k_length_array()), domain, space) + + +def full(domain, val): + if isinstance(domain, (dict, MultiDomain)): + return MultiField.full(domain, val) + return Field.full(domain, val) + + +def empty(domain, dtype): + if isinstance(domain, (dict, MultiDomain)): + return MultiField.empty(domain, dtype) + return Field.empty(domain, dtype) + + +def from_random(random_type, domain, dtype=np.float64, **kwargs): + if isinstance(domain, (dict, MultiDomain)): + return MultiField.from_random(random_type, domain, dtype, **kwargs) + return Field.from_random(random_type, domain, dtype, **kwargs) + + +def from_global_data(domain, arr, sum_up=False): + if isinstance(domain, (dict, MultiDomain)): + return MultiField.from_global_data(domain, arr, sum_up) + return Field.from_global_data(domain, arr, sum_up) + + +def from_local_data(domain, arr): + if isinstance(domain, (dict, MultiDomain)): + return MultiField.from_local_data(domain, arr) + return Field.from_local_data(domain, arr) + + +def makeDomain(domain): + if isinstance(domain, dict): + return MultiDomain.make(domain) + return DomainTuple.make(domain) diff --git a/test/test_energies/test_power.py b/test/test_energies/test_power.py index 1062f65fc..a8c01f04b 100644 --- a/test/test_energies/test_power.py +++ b/test/test_energies/test_power.py @@ -50,7 +50,7 @@ class Energy_Tests(unittest.TestCase): n = ift.Field.from_random(domain=space, random_type='normal') s = ht(xi * A) R = ift.ScalingOperator(10., space) - diag = ift.Field.ones(space) + diag = ift.full(space, 1.) N = ift.DiagonalOperator(diag) d = R(f(s)) + n diff --git a/test/test_field.py b/test/test_field.py index 5515a1d41..1539b317b 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -130,18 +130,18 @@ class Test_Functionality(unittest.TestCase): assert_equal(f.local_data, 27) assert_equal(f.shape, (200,)) assert_equal(f.dtype, np.int) - fx = ift.Field.empty_like(f) + fx = ift.empty(f.domain, f.dtype) assert_equal(f.dtype, fx.dtype) assert_equal(f.shape, fx.shape) - fx = ift.Field.zeros_like(f) + fx = ift.full(f.domain, 0) assert_equal(f.dtype, fx.dtype) assert_equal(f.shape, fx.shape) assert_equal(fx.local_data, 0) - fx = ift.Field.ones_like(f) + fx = ift.full(f.domain, 1) assert_equal(f.dtype, fx.dtype) assert_equal(f.shape, fx.shape) assert_equal(fx.local_data, 1) - fx = ift.Field.full_like(f, 67.) + fx = ift.full(f.domain, 67.) assert_equal(f.shape, fx.shape) assert_equal(fx.local_data, 67.) f = ift.Field.from_random("normal", s) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 140479ae5..7e8d3ef31 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -53,7 +53,7 @@ class Test_Minimizers(unittest.TestCase): covariance_diagonal = ift.Field.from_random( 'uniform', domain=space) + 0.5 covariance = ift.DiagonalOperator(covariance_diagonal) - required_result = ift.Field.ones(space, dtype=np.float64) + required_result = ift.full(space, 1.) try: minimizer = eval(minimizer) -- GitLab