Commit 1f9e0de6 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'work_on_fields' into 'NIFTy_5'

Work on fields

See merge request ift/nifty-dev!41
parents 0992e538 98b59a58
......@@ -22,9 +22,8 @@ if __name__ == '__main__':
ht = ift.HarmonicTransformOperator(harmonic_space, position_space)
power_space = A.value.domain[0]
power_distributor = ift.PowerDistributor(harmonic_space, power_space)
position = {}
position['xi'] = ift.Field.from_random('normal', harmonic_space)
position = ift.MultiField(position)
position = ift.MultiField.from_dict(
{'xi': ift.Field.from_random('normal', harmonic_space)})
xi = ift.Variable(position)['xi']
Amp = power_distributor(A)
......@@ -35,6 +34,7 @@ if __name__ == '__main__':
# apply some nonlinearity
signal = ift.PointwisePositiveTanh(correlated_field)
# Building the Line of Sight response
LOS_starts, LOS_ends = get_random_LOS(100)
R = ift.LOSResponse(position_space, starts=LOS_starts,
......
from __future__ import absolute_import, division, print_function
from builtins import (ascii, bytes, chr, dict, filter, hex, input,
map, next, oct, open, pow, range, round,
super, zip)
from functools import reduce
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .random import Random
from mpi4py import MPI
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
......
......@@ -16,6 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from .compat import *
try:
from mpi4py import MPI
if MPI.COMM_WORLD.Get_size() == 1:
......
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
from __future__ import absolute_import, division, print_function
from .compat import *
from .domains.domain import Domain
......@@ -138,26 +137,11 @@ class DomainTuple(object):
def __eq__(self, x):
if self is x:
return True
x = DomainTuple.make(x)
return self is x
return self is DomainTuple.make(x)
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
return self.__eq__(x)
def subsetOf(self, x):
return self.__eq__(x)
def unitedWith(self, x):
if self is x:
return self
x = DomainTuple.make(x)
if self is not x:
raise ValueError("domain mismatch")
return self
def __str__(self):
res = "DomainTuple, len: " + str(len(self))
for i in self:
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import abc
from ..utilities import NiftyMetaBase
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field
......
from functools import reduce
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..sugar import exp
import numpy as np
from .. import dobj
from ..field import Field
from .structured_domain import StructuredDomain
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from .. import dobj
......
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from functools import reduce
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import abc
from .domain import Domain
import numpy as np
......
......@@ -16,8 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from .domain import Domain
from functools import reduce
class UnstructuredDomain(Domain):
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..library.gaussian_energy import GaussianEnergy
from ..minimization.energy import Energy
from ..models.variable import Variable
......
from builtins import *
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..utilities import memo, my_sum
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..sugar import from_random
from ..minimization.energy import Energy
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..sugar import from_random
from ..field import Field
......
......@@ -16,12 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from .compat import *
import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
from functools import reduce
from . import dobj
......@@ -110,7 +109,7 @@ class Field(object):
@staticmethod
def from_local_data(domain, arr):
return Field(DomainTuple.make(domain),
dobj.from_local_data(domain.shape, arr))
dobj.from_local_data(domain.shape, arr))
def to_global_data(self):
"""Returns an array containing the full data of the field.
......@@ -215,14 +214,14 @@ class Field(object):
"""Field : The real part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating):
return self
return Field(self._domain, self.val.real)
return Field(self._domain, self._val.real)
@property
def imag(self):
"""Field : The imaginary part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating):
raise ValueError(".imag called on a non-complex Field")
return Field(self._domain, self.val.imag)
return Field(self._domain, self._val.imag)
def scalar_weight(self, spaces=None):
"""Returns the uniform volume element for a sub-domain of `self`.
......@@ -338,14 +337,14 @@ class Field(object):
raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x._domain != self._domain:
if x._domain is not self._domain:
raise ValueError("Domain mismatch")
ndom = len(self._domain)
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
return dobj.vdot(self.val, x.val)
return dobj.vdot(self._val, x._val)
# If we arrive here, we have to do a partial dot product.
# For the moment, do this the explicit, non-optimized way
return (self.conjugate()*x).sum(spaces=spaces)
......@@ -378,7 +377,7 @@ class Field(object):
Field
The complex conjugated field.
"""
return Field(self._domain, self.val.conjugate())
return Field(self._domain, self._val.conjugate())
# ---General unary/contraction methods---
......@@ -386,14 +385,14 @@ class Field(object):
return self
def __neg__(self):
return Field(self._domain, -self.val)
return Field(self._domain, -self._val)
def __abs__(self):
return Field(self._domain, abs(self.val))
return Field(self._domain, abs(self._val))
def _contraction_helper(self, op, spaces):
if spaces is None:
return getattr(self.val, op)()
return getattr(self._val, op)()
spaces = utilities.parse_spaces(spaces, len(self._domain))
......@@ -403,7 +402,7 @@ class Field(object):
axes_list = reduce(lambda x, y: x+y, axes_list)
# perform the contraction on the data
data = getattr(self.val, op)(axis=axes_list)
data = getattr(self._val, op)(axis=axes_list)
# check if the result is scalar or if a result_field must be constr.
if np.isscalar(data):
......@@ -594,7 +593,7 @@ class Field(object):
def __str__(self):
return "nifty5.Field instance\n- domain = " + \
self._domain.__str__() + \
"\n- val = " + repr(self.val)
"\n- val = " + repr(self._val)
def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
......@@ -603,7 +602,7 @@ class Field(object):
return True
if not isinstance(other, Field):
return False
if self._domain != other._domain:
if self._domain is not other._domain:
return False
return (self._val == other._val).all()
......@@ -625,14 +624,14 @@ for op in ["__add__", "__radd__",
def func2(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
tval = getattr(self._val, op)(other._val)
return Field(self._domain, tval)
if (np.isscalar(other) or
isinstance(other, (dobj.data_object, np.ndarray))):
tval = getattr(self.val, op)(other)
tval = getattr(self._val, op)(other)
return Field(self._domain, tval)
raise TypeError("should not arrive here")
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
......@@ -58,7 +77,7 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
fields = {keys[0]: Field.from_random('normal', dof_space),
keys[1]: Field.from_random('normal', param_space)}
position = MultiField(fields)
position = MultiField.from_dict(fields)
dof_space = position[keys[0]].domain[0]
kern = lambda k: _ceps_kernel(dof_space, k, ceps_a, ceps_k)
......@@ -104,9 +123,7 @@ def create_cepstrum_amplitude_field(domain, cepstrum):
for i in range(dim):
ks = np.minimum(shape[i] - np.arange(shape[i]) +
1, np.arange(shape[i])) * dist[i]
fst_dims = (1,) * i
lst_dims = (1,) * (dim - i - 1)
q_array[i] += ks.reshape(fst_dims + (shape[i],) + lst_dims)
q_array[i] += ks.reshape((1,)*i + (shape[i],) + (1,)*(dim-i-1))
# Fill cepstrum field (all non-zero modes)
no_zero_modes = (slice(1, None),) * dim
......@@ -117,10 +134,9 @@ def create_cepstrum_amplitude_field(domain, cepstrum):
# Fill cepstrum field (zero-mode subspaces)
for i in range(dim):
# Prepare indices
fst_dims = (slice(None),) * i
lst_dims = (slice(None),) * (dim - i - 1)
sl = fst_dims + (slice(1, None),) + lst_dims
sl2 = fst_dims + (0,) + lst_dims
fst_dims = (slice(None),)*i
sl = fst_dims + (slice(1, None),)
sl2 = fst_dims + (0,)
# Do summation
cepstrum_field[sl2] = np.sum(cepstrum_field[sl], axis=i)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
def ApplyData(data, var, model_data):
# TODO This is rather confusing. Delete that eventually.
from ..operators.diagonal_operator import DiagonalOperator
......
......@@ -16,8 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from numpy import inf, isnan
from ..minimization.energy import Energy
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import log, makeOp
......@@ -35,10 +36,10 @@ class BernoulliEnergy(Energy):
self._d = d
p_val = self._p.value
self._value = -self._d.vdot(log(p_val)) - (1. - d).vdot(log(1.-p_val))
self._value = -self._d.vdot(log(p_val)) - (1.-d).vdot(log(1.-p_val))
if isnan(self._value):
self._value = inf
metric = makeOp(1./((p_val) * (1.-p_val)))
metric = makeOp(1. / (p_val * (1.-p_val)))
self._gradient = self._p.jacobian.adjoint_times(metric(p_val-d))
self._metric = SandwichOperator.make(self._p.jacobian, metric)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operators.fft_operator import FFTOperator
from ..field import Field
from ..multi.multi_field import MultiField
from ..models.local_nonlinearity import PointwiseExponential
from ..operators.power_distributor import PowerDistributor
from ..models.variable import Variable
from ..domain_tuple import DomainTuple
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_transform_operator \
import HarmonicTransformOperator
def make_correlated_field(s_space, amplitude_model):
'''
Method for construction of correlated fields
......@@ -8,21 +39,14 @@ def make_correlated_field(s_space, amplitude_model):
amplitude_model : model for correlation structure
'''
from ..operators.fft_operator import FFTOperator
from ..field import Field
from ..multi.multi_field import MultiField
from ..models.local_nonlinearity import PointwiseExponential
from ..operators.power_distributor import PowerDistributor
from ..models.variable import Variable
h_space = s_space.get_default_codomain()
ht = FFTOperator(h_space, s_space)
p_space = amplitude_model.value.domain[0]
power_distributor = PowerDistributor(h_space, p_space)
position = {}
position['xi'] = Field.from_random('normal', h_space)
position['tau'] = amplitude_model.position['tau']
position['phi'] = amplitude_model.position['phi']
position = MultiField(position)
position = MultiField.from_dict({
'xi': Field.from_random('normal', h_space),
'tau': amplitude_model.position['tau'],
'phi': amplitude_model.position['phi']})
xi = Variable(position)['xi']
A = power_distributor(amplitude_model)
......@@ -39,16 +63,6 @@ def make_mf_correlated_field(s_space_spatial, s_space_energy,
'''
Method for construction of correlated multi-frequency fields
'''
from ..operators.fft_operator import FFTOperator
from ..field import Field
from ..multi.multi_field import MultiField
from ..models.local_nonlinearity import PointwiseExponential
from ..operators.power_distributor import PowerDistributor
from ..models.variable import Variable
from ..domain_tuple import DomainTuple
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_transform_operator \
import HarmonicTransformOperator
h_space_spatial = s_space_spatial.get_default_codomain()
h_space_energy = s_space_energy.get_default_codomain()
h_space = DomainTuple.make((h_space_spatial, h_space_energy))
......@@ -70,7 +84,7 @@ def make_mf_correlated_field(s_space_spatial, s_space_energy,
a = a_spatial*a_energy
A = pd(a)
position = MultiField({'xi': Field.from_random('normal', h_space)})
position = MultiField.from_dict({'xi': Field.from_random('normal', h_space)})
xi = Variable(position)['xi']
correlated_field_h = A*xi
correlated_field = ht(correlated_field_h)
......
......@@ -16,11 +16,14 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *