Commit 67660e26 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'static_restructure' into 'NIFTy_4'

Static restructure

See merge request ift/NIFTy!259
parents 096f619e 36640cc0
...@@ -69,8 +69,8 @@ if __name__ == "__main__": ...@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data # Creating the mock data
d = noiseless_data + n d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7) m0 = ift.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.) t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0)) power0 = Distributor.times(ift.exp(0.5 * t0))
plotdict = {"colormap": "Planck-like"} plotdict = {"colormap": "Planck-like"}
......
...@@ -67,8 +67,8 @@ plt.legend() ...@@ -67,8 +67,8 @@ plt.legend()
plt.savefig('Krylov_samples_residuals.png') plt.savefig('Krylov_samples_residuals.png')
plt.close() plt.close()
D_hat_old = ift.Field.zeros(x_space).to_global_data() D_hat_old = ift.full(x_space, 0.).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data() D_hat_new = ift.full(x_space, 0.).to_global_data()
for i in range(N_samps): for i in range(N_samps):
D_hat_old += sky(samps_old[i]).to_global_data()**2 D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2 D_hat_new += sky(samps[i]).to_global_data()**2
......
...@@ -69,8 +69,8 @@ if __name__ == "__main__": ...@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data # Creating the mock data
d = noiseless_data + n d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7) m0 = ift.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.) t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0)) power0 = Distributor.times(ift.exp(0.5 * t0))
IC1 = ift.GradientNormController(name="IC1", iteration_limit=100, IC1 = ift.GradientNormController(name="IC1", iteration_limit=100,
......
...@@ -36,7 +36,7 @@ if __name__ == "__main__": ...@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space = R.target d_space = R.target
p_op = ift.create_power_operator(h_space, p_spec) p_op = ift.create_power_operator(h_space, p_spec)
power = ift.sqrt(p_op(ift.Field.full(h_space, 1.))) power = ift.sqrt(p_op(ift.full(h_space, 1.)))
# Creating the mock data # Creating the mock data
true_sky = nonlinearity(HT(power*sh)) true_sky = nonlinearity(HT(power*sh))
...@@ -57,7 +57,7 @@ if __name__ == "__main__": ...@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ICI) inverter = ift.ConjugateGradient(controller=ICI)
# initial guess # initial guess
m = ift.Field.full(h_space, 1e-7) m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy( map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter) m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
......
...@@ -113,7 +113,7 @@ if __name__ == "__main__": ...@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain, np.random.poisson(lam.local_data).astype(np.float64)) d_domain, np.random.poisson(lam.local_data).astype(np.float64))
# initial guess # initial guess
psi0 = ift.Field.full(h_domain, 1e-7) psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter) inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200, IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
......
...@@ -50,7 +50,7 @@ if __name__ == "__main__": ...@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1) controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.Field.zeros(h_space) m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy # Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S, energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
......
...@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple ...@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from .operators import * from .operators import *
from .field import Field, sqrt, exp, log from .field import Field
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from .random import Random from .random import Random
from mpi4py import MPI from mpi4py import MPI
import sys import sys
from functools import reduce
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -145,20 +146,29 @@ class data_object(object): ...@@ -145,20 +146,29 @@ class data_object(object):
def sum(self, axis=None): def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis) return self._contraction_helper("sum", MPI.SUM, axis)
def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None): def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis) return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None): def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis) return self._contraction_helper("max", MPI.MAX, axis)
def mean(self): def mean(self, axis=None):
return self.sum()/self.size if axis is None:
sz = self.size
else:
sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis])
return self.sum(axis)/sz
def std(self): def std(self, axis=None):
return np.sqrt(self.var()) return np.sqrt(self.var(axis))
# FIXME: to be improved! # FIXME: to be improved!
def var(self): def var(self, axis=None):
if axis is not None and len(axis) != len(self.shape):
raise ValueError("functionality not yet supported")
return (abs(self-self.mean())**2).mean() return (abs(self-self.mean())**2).mean()
def _binary_helper(self, other, op): def _binary_helper(self, other, op):
......
...@@ -34,7 +34,9 @@ class DomainTuple(object): ...@@ -34,7 +34,9 @@ class DomainTuple(object):
""" """
_tupleCache = {} _tupleCache = {}
def __init__(self, domain): def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
self._dom = self._parse_domain(domain) self._dom = self._parse_domain(domain)
self._axtuple = self._get_axes_tuple() self._axtuple = self._get_axes_tuple()
shape_tuple = tuple(sp.shape for sp in self._dom) shape_tuple = tuple(sp.shape for sp in self._dom)
...@@ -72,7 +74,7 @@ class DomainTuple(object): ...@@ -72,7 +74,7 @@ class DomainTuple(object):
obj = DomainTuple._tupleCache.get(domain) obj = DomainTuple._tupleCache.get(domain)
if obj is not None: if obj is not None:
return obj return obj
obj = DomainTuple(domain) obj = DomainTuple(domain, _callingfrommake=True)
DomainTuple._tupleCache[domain] = obj DomainTuple._tupleCache[domain] = obj
return obj return obj
......
...@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase ...@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()): class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain. """The abstract class repesenting a (structured or unstructured) domain.
""" """
def __init__(self):
self._hash = None
@abc.abstractmethod @abc.abstractmethod
def __repr__(self): def __repr__(self):
...@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()): ...@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing. :attr:`._needed_for_hash` will be used for hashing.
""" """
result_hash = 0 if self._hash is None:
for key in self._needed_for_hash: h = 0
result_hash ^= hash(vars(self)[key]) for key in self._needed_for_hash:
return result_hash h ^= hash(vars(self)[key])
self._hash = h
return self._hash
def __eq__(self, x): def __eq__(self, x):
"""Checks whether two domains are equal. """Checks whether two domains are equal.
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
from ..field import Field, exp from ..field import Field
class LMSpace(StructuredDomain): class LMSpace(StructuredDomain):
...@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain): ...@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments" # cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al. # by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228 # http://arxiv.org/abs/astro-ph/0008228
from ..sugar import exp
res = x+1. res = x+1.
res *= x res *= x
res *= -0.5*sigma*sigma res *= -0.5*sigma*sigma
......
...@@ -21,7 +21,7 @@ from builtins import range ...@@ -21,7 +21,7 @@ from builtins import range
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
from ..field import Field, exp from ..field import Field
from .. import dobj from .. import dobj
...@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain): ...@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@staticmethod @staticmethod
def _kernel(x, sigma): def _kernel(x, sigma):
from ..sugar import exp
tmp = x*x tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp) exp(tmp, out=tmp)
......
...@@ -17,17 +17,26 @@ ...@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np import numpy as np
from ..sugar import from_random
from ..field import Field from ..field import Field
__all__ = ["consistency_check"] __all__ = ["consistency_check"]
def _assert_allclose(f1, f2, atol, rtol):
if isinstance(f1, Field):
return np.testing.assert_allclose(f1.local_data, f2.local_data,
atol=atol, rtol=rtol)
for key, val in f1.items():
_assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
f1 = Field.from_random("normal", op.domain, dtype=domain_dtype).lock() f1 = from_random("normal", op.domain, dtype=domain_dtype).lock()
f2 = Field.from_random("normal", op.target, dtype=target_dtype).lock() f2 = from_random("normal", op.target, dtype=target_dtype).lock()
res1 = f1.vdot(op.adjoint_times(f2).lock()) res1 = f1.vdot(op.adjoint_times(f2).lock())
res2 = op.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
...@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
foo = Field.from_random("normal", op.target, dtype=target_dtype).lock() foo = from_random("normal", op.target, dtype=target_dtype).lock()
res = op(op.inverse_times(foo).lock()) res = op(op.inverse_times(foo).lock())
np.testing.assert_allclose(res.to_global_data(), res.to_global_data(), _assert_allclose(res, foo, atol=atol, rtol=rtol)
atol=atol, rtol=rtol)
foo = Field.from_random("normal", op.domain, dtype=domain_dtype).lock() foo = from_random("normal", op.domain, dtype=domain_dtype).lock()
res = op.inverse_times(op(foo).lock()) res = op.inverse_times(op(foo).lock())
np.testing.assert_allclose(res.to_global_data(), foo.to_global_data(), _assert_allclose(res, foo, atol=atol, rtol=rtol)
atol=atol, rtol=rtol)
def full_implementation(op, domain_dtype, target_dtype, atol, rtol): def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
......
...@@ -106,62 +106,10 @@ class Field(object): ...@@ -106,62 +106,10 @@ class Field(object):
raise TypeError("val must be a scalar") raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val, dtype) return Field(DomainTuple.make(domain), val, dtype)
@staticmethod
def ones(domain, dtype=None):
return Field(DomainTuple.make(domain), 1., dtype)
@staticmethod
def zeros(domain, dtype=None):
return Field(DomainTuple.make(domain), 0., dtype)
@staticmethod @staticmethod
def empty(domain, dtype=None): def empty(domain, dtype=None):
return Field(DomainTuple.make(domain), None, dtype) return Field(DomainTuple.make(domain), None, dtype)
@staticmethod
def full_like(field, val, dtype=None):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
return Field.full(field._domain, val, dtype)
@staticmethod
def zeros_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.zeros(field._domain, dtype)
@staticmethod
def ones_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.ones(field._domain, dtype)
@staticmethod
def empty_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.empty(field._domain, dtype)
@staticmethod @staticmethod
def from_global_data(domain, arr, sum_up=False): def from_global_data(domain, arr, sum_up=False):
"""Returns a Field constructed from `domain` and `arr`. """Returns a Field constructed from `domain` and `arr`.
...@@ -287,6 +235,7 @@ class Field(object): ...@@ -287,6 +235,7 @@ class Field(object):
The value to fill the field with. The value to fill the field with.
""" """
self._val.fill(fill_value) self._val.fill(fill_value)
return self
def lock(self): def lock(self):
"""Write-protect the data content of `self`. """Write-protect the data content of `self`.
...@@ -370,6 +319,17 @@ class Field(object): ...@@ -370,6 +319,17 @@ class Field(object):
""" """
return Field(val=self, copy=True) return Field(val=self, copy=True)
def empty_copy(self):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return Field(self._domain, dtype=self.dtype)
def locked_copy(self): def locked_copy(self):
""" Returns a read-only version of the Field. """ Returns a read-only version of the Field.
...@@ -503,8 +463,8 @@ class Field(object): ...@@ -503,8 +463,8 @@ class Field(object):
or Field (for partial dot products) or Field (for partial dot products)
""" """
if not isinstance(x, Field): if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " + raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class") "the NIFTy field class")
if x._domain != self._domain: if x._domain != self._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
...@@ -694,7 +654,8 @@ class Field(object): ...@@ -694,7 +654,8 @@ class Field(object):
if self.scalar_weight(spaces) is not None: if self.scalar_weight(spaces) is not None:
return self._contraction_helper('mean', spaces) return self._contraction_helper('mean', spaces)
# MR FIXME: not very efficient # MR FIXME: not very efficient
tmp = self.weight(1) # MR FIXME: do we need "spaces" here?
tmp = self.weight(1, spaces)
return tmp.sum(spaces)*(1./tmp.total_volume(spaces)) return tmp.sum(spaces)*(1./tmp.total_volume(spaces))
def var(self, spaces=None): def var(self, spaces=None):
...@@ -717,12 +678,10 @@ class Field(object): ...@@ -717,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate # MR FIXME: not very efficient or accurate
m1 = self.mean(spaces) m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating): if np.issubdtype(self.dtype, np.complexfloating):
sq = abs(self)**2 sq = abs(self-m1)**2
m1 = abs(m1)**2
else: else:
sq = self**2 sq = (self-m1)**2
m1 **= 2 return sq.mean(spaces)
return sq.mean(spaces) - m1
def std(self, spaces=None): def std(self, spaces=None):
"""Determines the standard deviation over the sub-domains given by """Determines the standard deviation over the sub-domains given by
...@@ -742,6 +701,7 @@ class Field(object): ...@@ -742,6 +701,7 @@ class Field(object):
The result of the operation. If it is carried out over the entire The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field. domain, this is a scalar, otherwise a Field.
""" """
from .sugar import sqrt
if self.scalar_weight(spaces) is not None: if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces) return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces)) return sqrt(self.var(spaces))
...@@ -785,24 +745,3 @@ for op in ["__add__", "__radd__", "__iadd__", ...@@ -785,24 +745,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return NotImplemented return NotImplemented
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
...@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller): ...@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
""" """
# RL FIXME: make consistent with complex numbers # RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j*0., D_inv, j) energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)] y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy) status = controller.start(energy)
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp from ..field import Field
from ..sugar import exp
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator from ..operators.diagonal_operator import DiagonalOperator
import numpy as np import numpy as np
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.</