Commit 9b53ee63 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into regridding_operator

parents f64f5e3d 6774b35a
......@@ -100,7 +100,7 @@ if __name__ == '__main__':
for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = ift.EnergyAdapter(position, KL, ic_cg)
KL = ift.EnergyAdapter(position, KL, ic_cg, constants=["xi"])
KL, convergence = minimizer(KL)
position = KL.position
......
......@@ -18,9 +18,9 @@
from __future__ import absolute_import, division, print_function
from . import utilities
from .compat import *
from .domains.domain import Domain
from . import utilities
class DomainTuple(object):
......
......@@ -52,6 +52,7 @@ class LogRGSpace(StructuredDomain):
def shape(self):
return self._shape
@property
def scalar_dvol(self):
return self._dvol
......
......@@ -17,10 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..sugar import from_random
from ..compat import *
from ..linearization import Linearization
from ..sugar import from_random
__all__ = ["check_value_gradient_consistency",
"check_value_gradient_metric_consistency"]
......
......@@ -17,10 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..sugar import from_random
from ..compat import *
from ..field import Field
from ..sugar import from_random
__all__ = ["consistency_check"]
......
......@@ -25,8 +25,8 @@ from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..sugar import makeOp, sqrt
from ..operators.operator import Operator
from ..sugar import makeOp, sqrt
def _ceps_kernel(dof_space, k, a, k0):
......
......@@ -21,11 +21,11 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..operators.distributors import PowerDistributor
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.distributors import PowerDistributor
from ..sugar import exp
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import exp
def CorrelatedField(s_space, amplitude_model):
......
......@@ -22,9 +22,9 @@ import numpy as np
from scipy.stats import invgamma, norm
from ..compat import *
from ..operators.operator import Operator
from ..linearization import Linearization
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
......
......@@ -19,6 +19,7 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..logger import logger
from .line_search_strong_wolfe import LineSearchStrongWolfe
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..linearization import Linearization
from ..multi_field import MultiField
import numpy as np
from ..minimization.energy import Energy
from ..operators.block_diagonal_operator import BlockDiagonalOperator
from ..operators.scaling_operator import ScalingOperator
class EnergyAdapter(Energy):
......@@ -25,14 +25,10 @@ class EnergyAdapter(Energy):
if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position))
else:
ctmp = MultiField.from_dict({key: val
for key, val in self._position.items()
if key in self._constants})
vtmp = MultiField.from_dict({key: val
for key, val in self._position.items()
if key not in self._constants})
lin = Linearization.make_var(vtmp) + Linearization.make_const(ctmp)
tmp = self._op(lin)
ops = [ScalingOperator(0. if key in self._constants else 1., dom)
for key, dom in self._position.domain.items()]
bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
tmp = self._op(Linearization(self._position, bdop))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
if self._controller is not None:
......
......@@ -19,15 +19,15 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj
from ..compat import *
from ..field import Field
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
from ..logger import logger
from ..multi_field import MultiField
from ..utilities import iscomplextype
from .iteration_controllers import IterationController
from .minimizer import Minimizer
from ..utilities import iscomplextype
def _multiToArray(fld):
......
......@@ -20,10 +20,10 @@ from __future__ import absolute_import, division, print_function
import numpy as np
from . import utilities
from .compat import *
from .field import Field
from .multi_domain import MultiDomain
from . import utilities
class MultiField(object):
......
......@@ -19,9 +19,9 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from .endomorphic_operator import EndomorphicOperator
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
class BlockDiagonalOperator(EndomorphicOperator):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
import itertools
import numpy as np
from .. import dobj, utilities
from ..compat import *
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .. import dobj
from .linear_operator import LinearOperator
# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
"""Operator that enlarges a fields domain by adding zeros from the middle.
Parameters
---------
domain: Domain, tuple of Domains or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times"
new_shape: tuple
Shape of the target domain.
space: int, optional
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be an RGSpace.
"""
def __init__(self, domain, new_shape, space=0):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
# verify domains
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if dom.harmonic:
......@@ -29,12 +65,15 @@ class CentralZeroPadder(LinearOperator):
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
# make target space
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
# define the axes along which the input field is sliced
slicer = []
axes = self._target.axes[self._space]
for i in range(len(self._domain.shape)):
......@@ -56,7 +95,6 @@ class CentralZeroPadder(LinearOperator):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
axes = self._target.axes[self._space]
if dax in axes:
......@@ -65,10 +103,14 @@ class CentralZeroPadder(LinearOperator):
x = dobj.local_data(x)
if mode == self.TIMES:
# slice along each axis and copy the data to an
# array of zeros which has the shape of the target domain
y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
else:
# slice along each axis and copy the data to an array of zeros
# which has the shape of the input domain to remove excess zeros
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
......
......@@ -18,13 +18,11 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import utilities
from ..compat import *
from .diagonal_operator import DiagonalOperator
from .linear_operator import LinearOperator
from .. import utilities
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
from .simple_linear_operators import NullOperator
......@@ -44,8 +42,8 @@ class ChainOperator(LinearOperator):
@staticmethod
def simplify(ops):
# verify domains
for i in range(len(ops)-1):
if ops[i+1].target != ops[i].domain:
for i in range(len(ops) - 1):
if ops[i + 1].target != ops[i].domain:
raise ValueError("domain mismatch")
# unpack ChainOperators
opsnew = []
......@@ -78,9 +76,8 @@ class ChainOperator(LinearOperator):
# combine DiagonalOperators where possible
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator)):
if (len(opsnew) > 0 and isinstance(opsnew[-1], DiagonalOperator)
and isinstance(op, DiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_prod(op)
else:
opsnew.append(op)
......@@ -89,9 +86,9 @@ class ChainOperator(LinearOperator):
from .block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], BlockDiagonalOperator) and
isinstance(op, BlockDiagonalOperator)):
if (len(opsnew) > 0
and isinstance(opsnew[-1], BlockDiagonalOperator)
and isinstance(op, BlockDiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_chain(op)
else:
opsnew.append(op)
......@@ -123,8 +120,8 @@ class ChainOperator(LinearOperator):
if trafo == 0:
return self
if trafo == ADJ or trafo == INV:
return self.make([op._flip_modes(trafo)
for op in reversed(self._ops)])
return self.make(
[op._flip_modes(trafo) for op in reversed(self._ops)])
if trafo == ADJ | INV:
return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation")
......@@ -136,6 +133,7 @@ class ChainOperator(LinearOperator):
x = op.apply(x, mode)
return x
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
......@@ -149,4 +147,4 @@ class ChainOperator(LinearOperator):
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n"+utilities.indent(subs)
return "ChainOperator:\n" + utilities.indent(subs)
......@@ -159,8 +159,8 @@ class DiagonalOperator(EndomorphicOperator):
def process_sample(self, samp, from_inverse):
if (self._complex or (self._diagmin < 0.) or
(self._diagmin == 0. and from_inverse)):
raise ValueError("operator not positive definite")
(self._diagmin == 0. and from_inverse)):
raise ValueError("operator not positive definite")
if from_inverse:
res = samp.local_data/np.sqrt(self._ldiag)
else:
......
......@@ -20,12 +20,11 @@ from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj
from .. import utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
class DomainDistributor(LinearOperator):
......
......@@ -18,15 +18,15 @@
from __future__ import absolute_import, division, print_function
from .. import utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..sugar import makeOp
from .operator import Operator
from .sandwich_operator import SandwichOperator
from .sampling_enabler import SamplingEnabler
from ..sugar import makeOp
from ..linearization import Linearization
from .. import utilities
from ..field import Field
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator
......
......@@ -26,11 +26,29 @@ from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
class ExpTransform(LinearOperator):
"""
Transforms log-space to target.
This operator creates a log-space subject to the degrees of freedom and
and its target-domain.
Then transforms between this log-space and its target, which lives in
normal units.
E.g: A field in log-log-space can be transformed into log-norm-space,
that is the y-axis stays logarithmic, but the x-axis is transfromed.
Parameters
----------
target : domain, tuple of domains or DomainTuple
The full output domain
dof : int
The degrees of freedom of the log-domain, i.e. the number of bins.
"""
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
......
......@@ -2,13 +2,12 @@ from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
class FieldZeroPadder(LinearOperator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment