Commit 1c0333b2 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'NIFTy_4' into new_los

parents d697bffc 67660e26
Pipeline #29559 passed with stages
in 4 minutes
......@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.)
m0 = ift.full(h_space, 1e-7)
t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0))
plotdict = {"colormap": "Planck-like"}
......
......@@ -31,7 +31,7 @@ d = R(s_x) + n
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200
......@@ -67,8 +67,8 @@ plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.Field.zeros(x_space).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data()
D_hat_old = ift.full(x_space, 0.).to_global_data()
D_hat_new = ift.full(x_space, 0.).to_global_data()
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2
......
......@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d = noiseless_data + n
m0 = ift.Field.full(h_space, 1e-7)
t0 = ift.Field.full(p_space, -4.)
m0 = ift.full(h_space, 1e-7)
t0 = ift.full(p_space, -4.)
power0 = Distributor.times(ift.exp(0.5 * t0))
IC1 = ift.GradientNormController(name="IC1", iteration_limit=100,
......
......@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space = R.target
p_op = ift.create_power_operator(h_space, p_spec)
power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
power = ift.sqrt(p_op(ift.full(h_space, 1.)))
# Creating the mock data
true_sky = nonlinearity(HT(power*sh))
......@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ICI)
# initial guess
m = ift.Field.full(h_space, 1e-7)
m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
......
......@@ -80,12 +80,12 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Phi_h.inverse).inverse
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
D = ift.SandwichOperator(aHT, D) # real space propagator
D = ift.SandwichOperator.make(aHT, D) # real space propagator
Dhat = ift.probe_with_posterior_samples(D.inverse, None,
nprobes=nprobes)[1]
sig = ift.sqrt(Dhat)
......@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain, np.random.poisson(lam.local_data).astype(np.float64))
# initial guess
psi0 = ift.Field.full(h_domain, 1e-7)
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
......
......@@ -51,7 +51,7 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Sh.inverse).inverse
D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh)
m = D(j)
......
......@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.Field.zeros(h_space)
m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
......
......@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from .operators import *
from .field import Field, sqrt, exp, log
from .field import Field
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator
......
......@@ -20,6 +20,7 @@ import numpy as np
from .random import Random
from mpi4py import MPI
import sys
from functools import reduce
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -145,20 +146,29 @@ class data_object(object):
def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis)
def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis)
def mean(self):
return self.sum()/self.size
def mean(self, axis=None):
if axis is None:
sz = self.size
else:
sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis])
return self.sum(axis)/sz
def std(self):
return np.sqrt(self.var())
def std(self, axis=None):
return np.sqrt(self.var(axis))
# FIXME: to be improved!
def var(self):
def var(self, axis=None):
if axis is not None and len(axis) != len(self.shape):
raise ValueError("functionality not yet supported")
return (abs(self-self.mean())**2).mean()
def _binary_helper(self, other, op):
......
......@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache = {}
def __init__(self, domain):
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
self._dom = self._parse_domain(domain)
self._axtuple = self._get_axes_tuple()
shape_tuple = tuple(sp.shape for sp in self._dom)
......@@ -72,7 +74,7 @@ class DomainTuple(object):
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
return obj
obj = DomainTuple(domain)
obj = DomainTuple(domain, _callingfrommake=True)
DomainTuple._tupleCache[domain] = obj
return obj
......
......@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def __init__(self):
self._hash = None
@abc.abstractmethod
def __repr__(self):
......@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash = 0
for key in self._needed_for_hash:
result_hash ^= hash(vars(self)[key])
return result_hash
if self._hash is None:
h = 0
for key in self._needed_for_hash:
h ^= hash(vars(self)[key])
self._hash = h
return self._hash
def __eq__(self, x):
"""Checks whether two domains are equal.
......
......@@ -19,7 +19,7 @@
from __future__ import division
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field, exp
from ..field import Field
class LMSpace(StructuredDomain):
......@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
from ..sugar import exp
res = x+1.
res *= x
res *= -0.5*sigma*sigma
......
......@@ -21,7 +21,7 @@ from builtins import range
from functools import reduce
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field, exp
from ..field import Field
from .. import dobj
......@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@staticmethod
def _kernel(x, sigma):
from ..sugar import exp
tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp)
......
......@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..sugar import from_random
from ..field import Field
__all__ = ["consistency_check"]
def _assert_allclose(f1, f2, atol, rtol):
if isinstance(f1, Field):
return np.testing.assert_allclose(f1.local_data, f2.local_data,
atol=atol, rtol=rtol)
for key, val in f1.items():
_assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap:
return
f1 = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
f2 = Field.from_random("normal", op.target, dtype=target_dtype).lock()
f1 = from_random("normal", op.domain, dtype=domain_dtype).lock()
f2 = from_random("normal", op.target, dtype=target_dtype).lock()
res1 = f1.vdot(op.adjoint_times(f2).lock())
res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
......@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap:
return
foo = Field.from_random("normal", op.target, dtype=target_dtype).lock()
foo = from_random("normal", op.target, dtype=target_dtype).lock()
res = op(op.inverse_times(foo).lock())
np.testing.assert_allclose(res.to_global_data(), res.to_global_data(),
atol=atol, rtol=rtol)
_assert_allclose(res, foo, atol=atol, rtol=rtol)
foo = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
foo = from_random("normal", op.domain, dtype=domain_dtype).lock()
res = op.inverse_times(op(foo).lock())
np.testing.assert_allclose(res.to_global_data(), foo.to_global_data(),
atol=atol, rtol=rtol)
_assert_allclose(res, foo, atol=atol, rtol=rtol)
def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
......
......@@ -106,62 +106,10 @@ class Field(object):
raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val, dtype)
@staticmethod
def ones(domain, dtype=None):
return Field(DomainTuple.make(domain), 1., dtype)
@staticmethod
def zeros(domain, dtype=None):
return Field(DomainTuple.make(domain), 0., dtype)
@staticmethod
def empty(domain, dtype=None):
return Field(DomainTuple.make(domain), None, dtype)
@staticmethod
def full_like(field, val, dtype=None):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
return Field.full(field._domain, val, dtype)
@staticmethod
def zeros_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.zeros(field._domain, dtype)
@staticmethod
def ones_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.ones(field._domain, dtype)
@staticmethod
def empty_like(field, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.empty(field._domain, dtype)
@staticmethod
def from_global_data(domain, arr, sum_up=False):
"""Returns a Field constructed from `domain` and `arr`.
......@@ -287,6 +235,7 @@ class Field(object):
The value to fill the field with.
"""
self._val.fill(fill_value)
return self
def lock(self):
"""Write-protect the data content of `self`.
......@@ -370,6 +319,17 @@ class Field(object):
"""
return Field(val=self, copy=True)
def empty_copy(self):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return Field(self._domain, dtype=self.dtype)
def locked_copy(self):
""" Returns a read-only version of the Field.
......@@ -503,8 +463,8 @@ class Field(object):
or Field (for partial dot products)
"""
if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class")
raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
......@@ -694,7 +654,8 @@ class Field(object):
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('mean', spaces)
# MR FIXME: not very efficient
tmp = self.weight(1)
# MR FIXME: do we need "spaces" here?
tmp = self.weight(1, spaces)
return tmp.sum(spaces)*(1./tmp.total_volume(spaces))
def var(self, spaces=None):
......@@ -717,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate
m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating):
sq = abs(self)**2
m1 = abs(m1)**2
sq = abs(self-m1)**2
else:
sq = self**2
m1 **= 2
return sq.mean(spaces) - m1
sq = (self-m1)**2
return sq.mean(spaces)
def std(self, spaces=None):
"""Determines the standard deviation over the sub-domains given by
......@@ -742,6 +701,7 @@ class Field(object):
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
from .sugar import sqrt
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces))
......@@ -785,24 +745,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return NotImplemented
return func2
setattr(Field, op, func(op))
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
......@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
# RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j*0., D_inv, j)
energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp
from ..field import Field
from ..sugar import exp
from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
import numpy as np
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import exp
from ..sugar import exp
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp, tanh
from ..sugar import full, exp, tanh
class Linear(object):
......@@ -24,10 +24,10 @@ class Linear(object):
return x
def derivative(self, x):
return Field.ones_like(x)
return full(x.domain, 1.)
def hessian(self, x):
return Field.zeros_like(x)
return full(x.domain, 0.)
class Exponential(object):
......
......@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
from ..field import log
from ..sugar import log
class PoissonEnergy(Energy):
......@@ -46,7 +46,7 @@ class PoissonEnergy(Energy):
R1 = Instrument*Rho*ht
self._grad = (phipos + R1.adjoint_times((lam-d)/(lam+eps))).lock()
self._curv = Phi_h.inverse + SandwichOperator(R1, W)
self._curv = Phi_h.inverse + SandwichOperator.make(R1, W)
def at(self, position):
return self.__class__(position, self._d, self._Instrument,
......
......@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter):
inverter : Minimizer
The minimizer to use during numerical inversion
"""
op = SandwichOperator(R, N.inverse) + S.inverse
op = SandwichOperator.make(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.inverse)
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def _logger_init():
import logging
from . import dobj
......
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
__all__ = ["MultiDomain", "MultiField"]
__all__ = ["MultiDomain", "MultiField", "BlockDiagonalOperator"]
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain.make(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._cap
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return BlockDiagonalOperator(res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
return BlockDiagonalOperator(res)
class MultiDomain(dict):
pass
import collections
from ..domain_tuple import DomainTuple
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
class MultiDomain(frozendict):
_domainCache = {}
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiDomain, self).__init__(domain)
@staticmethod
def make(domain):
if isinstance(domain, MultiDomain):
return domain
if not isinstance(domain, dict):
raise TypeError("dict expected")
tmp = {}
for key, value in domain.items():
if not isinstance(key, str):
raise TypeError("keys must be strings")
tmp[key] = DomainTuple.make(value)
domain = frozendict(tmp)
obj = MultiDomain._domainCache.get(domain)
if obj is not None:
return obj
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
......@@ -44,7 +44,8 @@ class MultiField(object):
@property
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
return MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
@property
def dtype(self):
......@@ -57,6 +58,18 @@ class MultiField(object):
dtype[key], **kwargs)
for key in domain.keys()})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for val in self._val.values():
val.fill(fill_value)
return self
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
......@@ -73,9 +86,22 @@ class MultiField(object):
v.lock()
return self
@property
def locked(self):