Commit 67660e26 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'static_restructure' into 'NIFTy_4'

Static restructure

See merge request ift/NIFTy!259
parents 096f619e 36640cc0
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
__all__ = ["MultiDomain", "MultiField"]
__all__ = ["MultiDomain", "MultiField", "BlockDiagonalOperator"]
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain.make(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
def domain(self):
return self._domain
def capability(self):
return self._cap
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return BlockDiagonalOperator(res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
[selfneg, opneg])
return BlockDiagonalOperator(res)
class MultiDomain(dict):
import collections
from ..domain_tuple import DomainTuple
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
class MultiDomain(frozendict):
_domainCache = {}
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiDomain, self).__init__(domain)
def make(domain):
if isinstance(domain, MultiDomain):
return domain
if not isinstance(domain, dict):
raise TypeError("dict expected")
tmp = {}
for key, value in domain.items():
if not isinstance(key, str):
raise TypeError("keys must be strings")
tmp[key] = DomainTuple.make(value)
domain = frozendict(tmp)
obj = MultiDomain._domainCache.get(domain)
if obj is not None:
return obj
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
......@@ -44,7 +44,8 @@ class MultiField(object):
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
return MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
def dtype(self):
......@@ -57,6 +58,18 @@ class MultiField(object):
dtype[key], **kwargs)
for key in domain.keys()})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
fill_value: float or complex or int
The value to fill the field with.
for val in self._val.values():
return self
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
......@@ -73,9 +86,22 @@ class MultiField(object):
return self
def locked(self):
return all(v.locked for v in self.values())
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
def locked_copy(self):
if self.locked:
return self
return MultiField({key: val.locked_copy()
for key, val in self.items()})
def empty_copy(self):
return MultiField({key: val.empty_copy() for key, val in self.items()})
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
......@@ -85,22 +111,24 @@ class MultiField(object):
return {key: dtype for key in domain.keys()}
def zeros(domain, dtype=None):
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.zeros(dom, dtype=dtype[key])
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
def ones(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.ones(dom, dtype=dtype[key])
def full(domain, val):
return MultiField({key: Field.full(dom, val)
for key, dom in domain.items()})
def to_global_data(self):
return {key: val.to_global_data() for key, val in self._val.items()}
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
def from_global_data(domain, arr, sum_up=False):
return MultiField({key: Field.from_global_data(domain[key],
val, sum_up)
for key, val in arr.items()})
def norm(self):
""" Computes the L2-norm of the field values.
......@@ -78,6 +78,17 @@ class ChainOperator(LinearOperator):
ops = opsnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], BlockDiagonalOperator) and
isinstance(op, BlockDiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_chain(op)
ops = opsnew
return ops
......@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
x0 = x*0.
x0 = x.empty_copy().fill(0.)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation
......@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
raise ValueError("requested operator mode is not supported")
def _check_input(self, x, mode):
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.")
......@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return x.zeros_like(x)
return x.empty_copy().fill(0.)
if mode == self.TIMES:
return x*self._factor
......@@ -102,6 +102,28 @@ class SumOperator(LinearOperator):
ops = opsnew
neg = negnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
ops = opsnew
neg = negnew
return ops, neg
......@@ -16,19 +16,22 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import sys
import numpy as np
from .domains.power_space import PowerSpace
from .field import Field
from .multi.multi_field import MultiField
from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
from .domain_tuple import DomainTuple
from . import dobj, utilities
from .logger import logger
__all__ = ['PS_field',
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate']
def PS_field(pspace, func):
......@@ -161,3 +164,70 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
kfunc = domain[space].get_fft_smoothing_kernel_function(sigma)
return DiagonalOperator(kfunc(domain[space].get_k_length_array()), domain,
def full(domain, val):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.full(domain, val)
return Field.full(domain, val)
def empty(domain, dtype):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.empty(domain, dtype)
return Field.empty(domain, dtype)
def from_random(random_type, domain, dtype=np.float64, **kwargs):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_random(random_type, domain, dtype, **kwargs)
return Field.from_random(random_type, domain, dtype, **kwargs)
def from_global_data(domain, arr, sum_up=False):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_global_data(domain, arr, sum_up)
return Field.from_global_data(domain, arr, sum_up)
def from_local_data(domain, arr):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_local_data(domain, arr)
return Field.from_local_data(domain, arr)
def makeDomain(domain):
if isinstance(domain, dict):
return MultiDomain.make(domain)
return DomainTuple.make(domain)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
if isinstance(x, MultiField):
if out is not None:
if (not isinstance(out, MultiField) or
x._domain != out._domain):
raise ValueError("Bad 'out' argument")
for key, value in x.items():
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()})
elif isinstance(x, Field):
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
return Field(domain=x._domain, val=fu(x.val))
return getattr(np, f)(x, out)
return func2
setattr(_current_module, f, func(f))
......@@ -50,7 +50,7 @@ class Energy_Tests(unittest.TestCase):
n = ift.Field.from_random(domain=space, random_type='normal')
s = ht(xi * A)
R = ift.ScalingOperator(10., space)
diag = ift.Field.ones(space)
diag = ift.full(space, 1.)
N = ift.DiagonalOperator(diag)
d = R(f(s)) + n
......@@ -18,7 +18,7 @@
import unittest
import numpy as np
from numpy.testing import assert_equal, assert_allclose
from numpy.testing import assert_equal, assert_allclose, assert_raises
from itertools import product
import nifty4 as ift
from test.common import expand
......@@ -124,24 +124,140 @@ class Test_Functionality(unittest.TestCase):
res = m.vdot(m, spaces=1)
assert_allclose(res.local_data, 37.5)
def test_lock(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.locked, False)
assert_equal(f1.locked, True)
with assert_raises(ValueError):
f1 += f1
assert_equal(f1.locked_copy() is f1, True)
def test_fill(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.fill(10).local_data, 10)
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
gd = np.arange(s1.shape[0])
assert_equal(ld, ift.from_local_data(s1, ld).local_data)
assert_equal(gd, ift.from_global_data(s1, gd).to_global_data())
def test_cast_domain(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((10,), distances=20.)
d = np.arange(s1.shape[0])
d2 = ift.from_global_data(s1, d).cast_domain(s2).to_global_data()
assert_equal(d, d2)
def test_empty_domain(self):
f = ift.Field((), 5)
assert_equal(f.to_global_data(), 5)
f = ift.Field(None, 5)
assert_equal(f.to_global_data(), 5)
assert_equal(f.empty_copy().domain, f.domain)
assert_equal(f.empty_copy().dtype, f.dtype)
assert_equal(f.copy().domain, f.domain)
assert_equal(f.copy().dtype, f.dtype)
assert_equal(f.copy().local_data, f.local_data)
assert_equal(f.copy() is f, False)
def test_trivialities(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data)
assert_equal(f1.sum(), f1.sum(0))
f1 = ift.from_global_data(s1, np.arange(10))
assert_equal(f1.min(), 0)
assert_equal(f1.max(), 9)
assert_equal(, 0)
def test_weight(self):
s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.)
f2 = f.copy()
f.weight(1, out=f2)
assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1)
assert_equal(f.total_volume(0), 1)
assert_equal(f.total_volume((0,)), 1)
assert_equal(f.scalar_weight(), 0.1)
assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.)
assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None)
@expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128]))
def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt)
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
assert_allclose(s1.var(0), 0., atol=1e-14)
assert_allclose(s1.std(), 0., atol=1e-14)
assert_allclose(s1.std(0), 0., atol=1e-14)
def test_err(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz")
with assert_raises(TypeError):
if f1:
with assert_raises(TypeError):
f1.full((2, 4, 6))
with assert_raises(TypeError):
f2 = ift.Field(None, None)
with assert_raises(ValueError):
f2 = ift.Field(s1, None)
with assert_raises(ValueError):
with assert_raises(TypeError):
with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.))
with assert_raises(TypeError):
with assert_raises(ValueError):
f1.copy_content_from(ift.Field(s2, 1.))
with assert_raises(TypeError):
ift.full(s1, [2, 3])
def test_stdfunc(self):
s = ift.RGSpace((200,))
f = ift.Field(s, 27)
assert_equal(f.local_data, 27)
assert_equal(f.shape, (200,))
fx = ift.Field.empty_like(f)
fx = ift.empty(f.domain, f.dtype)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
fx = ift.Field.zeros_like(f)
fx = ift.full(f.domain, 0)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 0)
fx = ift.Field.ones_like(f)
fx = ift.full(f.domain, 1)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 1)
fx = ift.Field.full_like(f, 67.)
fx = ift.full(f.domain, 67.)
assert_equal(f.shape, fx.shape)
assert_equal(fx.local_data, 67.)
f = ift.Field.from_random("normal", s)
......@@ -30,7 +30,7 @@ spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
minimizers = ['ift.VL_BFGS(IC)',
'ift.NonlinearCG(IC, "Polak-Ribiere")',
#'ift.NonlinearCG(IC, "Hestenes-Stiefel"),
# 'ift.NonlinearCG(IC, "Hestenes-Stiefel"),
'ift.NonlinearCG(IC, "Fletcher-Reeves")',
'ift.NonlinearCG(IC, "5.49")',
'ift.NewtonCG(xtol=1e-5, maxiter=1000)',
......@@ -53,7 +53,7 @@ class Test_Minimizers(unittest.TestCase):
covariance_diagonal = ift.Field.from_random(
'uniform', domain=space) + 0.5
covariance = ift.DiagonalOperator(covariance_diagonal)
required_result = ift.Field.ones(space, dtype=np.float64)
required_result = ift.full(space, 1.)
minimizer = eval(minimizer)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.