Commit a7bc5e41 authored by Martin Reinecke's avatar Martin Reinecke

step 1

parent 6fb90ba4
Pipeline #29487 failed with stages
in 3 minutes and 54 seconds
......@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.)
m0 = ift.full(h_space, 1e-7)
t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0))
plotdict = {"colormap": "Planck-like"}
......
......@@ -67,8 +67,8 @@ plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.Field.zeros(x_space).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data()
D_hat_old = ift.full(x_space, 0.).to_global_data()
D_hat_new = ift.full(x_space, 0.).to_global_data()
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2
......
......@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.)
m0 = ift.full(h_space, 1e-7)
t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0))
IC1 = ift.GradientNormController(name="IC1", iteration_limit=100,
......
......@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space = R.target
p_op = ift.create_power_operator(h_space, p_spec)
power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
power = ift.sqrt(p_op(ift.full(h_space, 1.)))
# Creating the mock data
true_sky = nonlinearity(HT(power*sh))
......@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ICI)
# initial guess
m = ift.Field.full(h_space, 1e-7)
m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
......
......@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain, np.random.poisson(lam.local_data).astype(np.float64))
# initial guess
psi0 = ift.Field.full(h_domain, 1e-7)
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
......
......@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.Field.zeros(h_space)
m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
......
......@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache = {}
def __init__(self, domain):
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
self._dom = self._parse_domain(domain)
self._axtuple = self._get_axes_tuple()
shape_tuple = tuple(sp.shape for sp in self._dom)
......@@ -72,7 +74,7 @@ class DomainTuple(object):
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
return obj
obj = DomainTuple(domain)
obj = DomainTuple(domain, _callingfrommake=True)
DomainTuple._tupleCache[domain] = obj
return obj
......
......@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def __init__(self):
self._hash = None
@abc.abstractmethod
def __repr__(self):
......@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash = 0
for key in self._needed_for_hash:
result_hash ^= hash(vars(self)[key])
return result_hash
if self._hash is None:
h = 0
for key in self._needed_for_hash:
h ^= hash(vars(self)[key])
self._hash = h
return self._hash
def __eq__(self, x):
"""Checks whether two domains are equal.
......
......@@ -17,7 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
from ..sugar import from_random
__all__ = ["consistency_check"]
......@@ -26,8 +26,8 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap:
return
f1 = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
f2 = Field.from_random("normal", op.target, dtype=target_dtype).lock()
f1 = from_random("normal", op.domain, dtype=domain_dtype).lock()
f2 = from_random("normal", op.target, dtype=target_dtype).lock()
res1 = f1.vdot(op.adjoint_times(f2).lock())
res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
......@@ -37,12 +37,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap:
return
foo = Field.from_random("normal", op.target, dtype=target_dtype).lock()
foo = from_random("normal", op.target, dtype=target_dtype).lock()
res = op(op.inverse_times(foo).lock())
np.testing.assert_allclose(res.to_global_data(), res.to_global_data(),
atol=atol, rtol=rtol)
foo = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
foo = from_random("normal", op.domain, dtype=domain_dtype).lock()
res = op.inverse_times(op(foo).lock())
np.testing.assert_allclose(res.to_global_data(), foo.to_global_data(),
atol=atol, rtol=rtol)
......
......@@ -106,62 +106,10 @@ class Field(object):
raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val, dtype)
@staticmethod
def ones(domain, dtype=None):
return Field(DomainTuple.make(domain), 1., dtype)
@staticmethod
def zeros(domain, dtype=None):
return Field(DomainTuple.make(domain), 0., dtype)
@staticmethod
def empty(domain, dtype=None):
return Field(DomainTuple.make(domain), None, dtype)
@staticmethod
def full_like(field, val, dtype=None):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
return Field.full(field._domain, val, dtype)
@staticmethod
def zeros_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.zeros(field._domain, dtype)
@staticmethod
def ones_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.ones(field._domain, dtype)
@staticmethod
def empty_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.empty(field._domain, dtype)
@staticmethod
def from_global_data(domain, arr, sum_up=False):
"""Returns a Field constructed from `domain` and `arr`.
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp, tanh
from ..field import exp, tanh
from ..sugar import full
class Linear(object):
......@@ -24,10 +25,10 @@ class Linear(object):
return x
def derivative(self, x):
return Field.ones_like(x)
return full(x.domain, 1.)
def hessian(self, x):
return Field.zeros_like(x)
return full(x.domain, 0.)
class Exponential(object):
......
class MultiDomain(dict):
pass
import collections
from ..domain_tuple import DomainTuple
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
interface. It can be used as a drop-in replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
class MultiDomain(frozendict):
_domainCache = {}
def __init__(domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiDomain, self).__init__(domain)
@staticmethod
def make(domain):
if isinstance(domain, MultiDomain):
return domain
print type(domain)
if not isinstance(domain, dict):
raise TypeError("dict expected")
tmp = {}
for key, value in domain.items():
if not isinstance(key, str):
raise TypeError("keys must be strings")
tmp[key] = DomainTuple.make(value)
domain = frozendict(tmp)
print tmp
obj = MultiDomain._domainCache.get(domain)
if obj is not None:
return obj
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
......@@ -85,21 +85,14 @@ class MultiField(object):
return {key: dtype for key in domain.keys()}
@staticmethod
def zeros(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.zeros(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.ones(dom, dtype=dtype[key])
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
def full(domain, val):
return MultiField({key: Field.full(dom, val)
for key, dom in domain.items()})
def norm(self):
......
......@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
raise ValueError("requested operator mode is not supported")
def _check_input(self, x, mode):
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self._check_mode(mode)
if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.")
......@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return x.zeros_like(x)
return x*0.
if mode == self.TIMES:
return x*self._factor
......
......@@ -19,16 +19,18 @@
import numpy as np
from .domains.power_space import PowerSpace
from .field import Field
from multi.multi_field import MultiField
from multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
from .domain_tuple import DomainTuple
from . import dobj, utilities
from .logger import logger
__all__ = ['PS_field',
'power_analyze',
'create_power_operator',
'create_harmonic_smoothing_operator']
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain']
def PS_field(pspace, func):
......@@ -161,3 +163,39 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
kfunc = domain[space].get_fft_smoothing_kernel_function(sigma)
return DiagonalOperator(kfunc(domain[space].get_k_length_array()), domain,
space)
def full(domain, val):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.full(domain, val)
return Field.full(domain, val)
def empty(domain, dtype):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.empty(domain, dtype)
return Field.empty(domain, dtype)
def from_random(random_type, domain, dtype=np.float64, **kwargs):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_random(random_type, domain, dtype, **kwargs)
return Field.from_random(random_type, domain, dtype, **kwargs)
def from_global_data(domain, arr, sum_up=False):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_global_data(domain, arr, sum_up)
return Field.from_global_data(domain, arr, sum_up)
def from_local_data(domain, arr):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_local_data(domain, arr)
return Field.from_local_data(domain, arr)
def makeDomain(domain):
if isinstance(domain, dict):
return MultiDomain.make(domain)
return DomainTuple.make(domain)
......@@ -50,7 +50,7 @@ class Energy_Tests(unittest.TestCase):
n = ift.Field.from_random(domain=space, random_type='normal')
s = ht(xi * A)
R = ift.ScalingOperator(10., space)
diag = ift.Field.ones(space)
diag = ift.full(space, 1.)
N = ift.DiagonalOperator(diag)
d = R(f(s)) + n
......
......@@ -130,18 +130,18 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.local_data, 27)
assert_equal(f.shape, (200,))
assert_equal(f.dtype, np.int)
fx = ift.Field.empty_like(f)
fx = ift.empty(f.domain, f.dtype)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
fx = ift.Field.zeros_like(f)
fx = ift.full(f.domain, 0)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 0)
fx = ift.Field.ones_like(f)
fx = ift.full(f.domain, 1)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 1)
fx = ift.Field.full_like(f, 67.)
fx = ift.full(f.domain, 67.)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 67.)
f = ift.Field.from_random("normal", s)
......
......@@ -53,7 +53,7 @@ class Test_Minimizers(unittest.TestCase):
covariance_diagonal = ift.Field.from_random(
'uniform', domain=space) + 0.5
covariance = ift.DiagonalOperator(covariance_diagonal)
required_result = ift.Field.ones(space, dtype=np.float64)
required_result = ift.full(space, 1.)
try:
minimizer = eval(minimizer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment