Commit 087530b0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge NIFTy_4

parents 5429bb64 ec50fcc0
Pipeline #29620 passed with stages
in 4 minutes and 21 seconds
......@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
from ..field import log
from ..sugar import log
class PoissonEnergy(Energy):
......@@ -46,7 +46,7 @@ class PoissonEnergy(Energy):
R1 = Instrument*Rho*ht
self._grad = (phipos + R1.adjoint_times((lam-d)/(lam+eps))).lock()
self._curv = Phi_h.inverse + SandwichOperator(R1, W)
self._curv = Phi_h.inverse + SandwichOperator.make(R1, W)
def at(self, position):
return self.__class__(position, self._d, self._Instrument,
......
......@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter):
inverter : Minimizer
The minimizer to use during numerical inversion
"""
op = SandwichOperator(R, N.inverse) + S.inverse
op = SandwichOperator.make(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.inverse)
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def _logger_init():
import logging
from . import dobj
......
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
__all__ = ["MultiDomain", "MultiField"]
__all__ = ["MultiDomain", "MultiField", "BlockDiagonalOperator"]
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain.make(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._cap
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return BlockDiagonalOperator(res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
return BlockDiagonalOperator(res)
class MultiDomain(dict):
pass
import collections
from ..domain_tuple import DomainTuple
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
class MultiDomain(frozendict):
_domainCache = {}
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiDomain, self).__init__(domain)
@staticmethod
def make(domain):
if isinstance(domain, MultiDomain):
return domain
if not isinstance(domain, dict):
raise TypeError("dict expected")
tmp = {}
for key, value in domain.items():
if not isinstance(key, str):
raise TypeError("keys must be strings")
tmp[key] = DomainTuple.make(value)
domain = frozendict(tmp)
obj = MultiDomain._domainCache.get(domain)
if obj is not None:
return obj
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
......@@ -44,7 +44,8 @@ class MultiField(object):
@property
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
return MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
@property
def dtype(self):
......@@ -57,6 +58,18 @@ class MultiField(object):
dtype[key], **kwargs)
for key in domain.keys()})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for val in self._val.values():
val.fill(fill_value)
return self
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
......@@ -73,9 +86,22 @@ class MultiField(object):
v.lock()
return self
@property
def locked(self):
return all(v.locked for v in self.values())
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
def locked_copy(self):
if self.locked:
return self
return MultiField({key: val.locked_copy()
for key, val in self.items()})
def empty_copy(self):
return MultiField({key: val.empty_copy() for key, val in self.items()})
@staticmethod
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
......@@ -85,22 +111,24 @@ class MultiField(object):
return {key: dtype for key in domain.keys()}
@staticmethod
def zeros(domain, dtype=None):
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.zeros(dom, dtype=dtype[key])
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.ones(dom, dtype=dtype[key])
def full(domain, val):
return MultiField({key: Field.full(dom, val)
for key, dom in domain.items()})
def to_global_data(self):
return {key: val.to_global_data() for key, val in self._val.items()}
@staticmethod
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
def from_global_data(domain, arr, sum_up=False):
return MultiField({key: Field.from_global_data(domain[key],
val, sum_up)
for key, val in arr.items()})
def norm(self):
""" Computes the L2-norm of the field values.
......
......@@ -64,7 +64,7 @@ class ChainOperator(LinearOperator):
opsnew[i] = opsnew[i]._scale(fct)
fct = 1.
break
if fct != 1:
if fct != 1 or len(opsnew) == 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(fct, lastdom))
ops = opsnew
......@@ -78,11 +78,24 @@ class ChainOperator(LinearOperator):
else:
opsnew.append(op)
ops = opsnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], BlockDiagonalOperator) and
isinstance(op, BlockDiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_chain(op)
else:
opsnew.append(op)
ops = opsnew
return ops
@staticmethod
def make(ops):
ops = tuple(ops)
if len(ops) == 0:
raise ValueError("ops is empty")
ops = ChainOperator.simplify(ops)
if len(ops) == 1:
return ops[0]
......
......@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom.check_codomain(target)
target.check_codomain(adom)
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
utilities.fft_prep()
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return self._apply_cartesian(x, mode)
def _apply_cartesian(self, x, mode):
from pyfftw.interfaces.numpy_fft import fftn
axes = x.domain.axes[self._space]
tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val)
......@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fftn(ldat2, axes=(1,))
ldat2 = utilities.my_fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
......
......@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
x0 = x*0.
x0 = x.empty_copy().fill(0.)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation
......
......@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
raise ValueError("requested operator mode is not supported")
def _check_input(self, x, mode):
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self._check_mode(mode)
if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.")
......@@ -16,31 +16,46 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from .diagonal_operator import DiagonalOperator
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
import numpy as np
class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
def __init__(self, bun, cheese=None):
def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SandwichOperator, self).__init__()
self._bun = bun
self._cheese = cheese
self._op = op
@staticmethod
def make(bun, cheese=None):
"""Build a SandwichOperator (or something simpler if possible)
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
if cheese is None:
self._cheese = ScalingOperator(1., bun.target)
self._op = bun.adjoint*bun
cheese = ScalingOperator(1., bun.target)
op = bun.adjoint*bun
else:
self._cheese = cheese
self._op = bun.adjoint*cheese*bun
op = bun.adjoint*cheese*bun
# if our sandwich is diagonal, we can return immediately
if isinstance(op, (ScalingOperator, DiagonalOperator)):
return op
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
@property
def domain(self):
......@@ -54,8 +69,11 @@ class SandwichOperator(EndomorphicOperator):
return self._op.apply(x, mode)
def draw_sample(self, from_inverse=False, dtype=np.float64):
# Inverse samples from general sandwiches is not possible
if from_inverse:
raise NotImplementedError(
"cannot draw from inverse of this operator")
# Samples from general sandwiches
return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse, dtype))
......@@ -20,8 +20,8 @@ from __future__ import division
import numpy as np
from ..field import Field
from ..multi.multi_field import MultiField
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..domain_tuple import DomainTuple
class ScalingOperator(EndomorphicOperator):
......@@ -49,12 +49,13 @@ class ScalingOperator(EndomorphicOperator):
"""
def __init__(self, factor, domain):
from ..sugar import makeDomain
super(ScalingOperator, self).__init__()
if not np.isscalar(factor):
raise TypeError("Scalar required")
self._factor = factor
self._domain = DomainTuple.make(domain)
self._domain = makeDomain(domain)
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -62,7 +63,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return x.zeros_like(x)
return x.empty_copy().fill(0.)
if mode == self.TIMES:
return x*self._factor
......
......@@ -102,12 +102,36 @@ class SumOperator(LinearOperator):
negnew.append(neg[i])
ops = opsnew
neg = negnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg
@staticmethod
def make(ops, neg):
ops = tuple(ops)
neg = tuple(neg)
if len(ops) == 0:
raise ValueError("ops is empty")
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg = SumOperator.simplify(ops, neg)
......
......@@ -16,23 +16,25 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import sys
import numpy as np
from .domains.power_space import PowerSpace
from .field import Field
from .multi.multi_field import MultiField
from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
from .domain_tuple import DomainTuple
from . import dobj, utilities
from .logger import logger
__all__ = ['PS_field',
'power_analyze',
'create_power_operator',
'create_harmonic_smoothing_operator',
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance']
def PS_field(pspace, func):
if not isinstance(pspace, PowerSpace):
raise TypeError
......@@ -53,15 +55,16 @@ def get_signal_variance(spec, space):
a method that takes one k-value and returns the power spectrum at that
location
space: PowerSpace or any harmonic Domain
If this function is given a harmonic domain, it creates the naturally binned
PowerSpace to that domain.
The field, for which the signal variance is then computed, is assumed to have
this PowerSpace as naturally binned PowerSpace
If this function is given a harmonic domain, it creates the naturally
binned PowerSpace to that domain.
The field, for which the signal variance is then computed, is assumed
to have this PowerSpace as naturally binned PowerSpace
"""
if space.harmonic:
space = PowerSpace(space)
if not isinstance(space, PowerSpace):
raise ValueError("space must be either a harmonic space or Power space.")
raise ValueError(
"space must be either a harmonic space or Power space.")
field = PS_field(space, spec)
dist = PowerDistributor(space.harmonic_partner, space)
k_field = dist(field)
......@@ -190,3 +193,70 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
kfunc = domain[space].get_fft_smoothing_kernel_function(sigma)
return DiagonalOperator(kfunc(domain[space].get_k_length_array()), domain,
space)
def full(domain, val):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.full(domain, val)
return Field.full(domain, val)
def empty(domain, dtype):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.empty(domain, dtype)
return Field.empty(domain, dtype)
def from_random(random_type, domain, dtype=np.float64, **kwargs):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_random(random_type, domain, dtype, **kwargs)
return Field.from_random(random_type, domain, dtype, **kwargs)
def from_global_data(domain, arr, sum_up=False):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_global_data(domain, arr, sum_up)
return Field.from_global_data(domain, arr, sum_up)
def from_local_data(domain, arr):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_local_data(domain, arr)
return Field.from_local_data(domain, arr)
def makeDomain(domain):
if isinstance(domain, (MultiDomain, dict)):
return MultiDomain.make(domain)
return DomainTuple.make(domain)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
if isinstance(x, MultiField):
if out is not None:
if (not isinstance(out, MultiField) or
x._domain != out._domain):
raise ValueError("Bad 'out' argument")
for key, value in x.items():
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()})
elif isinstance(x, Field):
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain: