Commit 5da23204 authored by Martin Reinecke's avatar Martin Reinecke

more restructuring

parent 237b0a3f
......@@ -26,17 +26,14 @@ from .operators.domain_distributor import DomainDistributor
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.fft_operator import FFTOperator
from .operators.field_adapter import FieldAdapter
from .operators.field_zero_padder import FieldZeroPadder
from .operators.hartley_operator import HartleyOperator
from .operators.harmonic_smoothing_operator import HarmonicSmoothingOperator
from .operators.geometry_remover import GeometryRemover
from .operators.harmonic_transform_operator import HarmonicTransformOperator
from .operators.inversion_enabler import InversionEnabler
from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator
from .operators.null_operator import NullOperator
from .operators.power_distributor import PowerDistributor
from .operators.qht_operator import QHTOperator
from .operators.sampling_enabler import SamplingEnabler
......@@ -45,8 +42,10 @@ from .operators.scaling_operator import ScalingOperator
from .operators.slope_operator import SlopeOperator
from .operators.smoothness_operator import SmoothnessOperator
from .operators.symmetrizing_operator import SymmetrizingOperator
from .operators.vdot_operator import VdotOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.simple_linear_operators import (
VdotOperator, SumReductionOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, BernoulliEnergy,
Hamiltonian, SampledKullbachLeiblerDivergence)
......
......@@ -42,7 +42,7 @@ class Linearization(object):
return self._metric
def __getitem__(self, name):
from .operators.field_adapter import FieldAdapter
from .operators.simple_linear_operators import FieldAdapter
return Linearization(self._val[name], FieldAdapter(self.domain, name))
def __neg__(self):
......@@ -99,7 +99,7 @@ class Linearization(object):
def vdot(self, other):
from .domain_tuple import DomainTuple
from .operators.vdot_operator import VdotOperator
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
......@@ -110,7 +110,7 @@ class Linearization(object):
VdotOperator(other._val)(self._jac))
def sum(self):
from .operators.vdot_operator import SumReductionOperator
from .operators.simple_linear_operators import SumReductionOperator
from .sugar import full
return Linearization(
Field(DomainTuple.scalar_domain(), self._val.sum()),
......@@ -143,5 +143,5 @@ class Linearization(object):
@staticmethod
def make_const(field):
from .operators.null_operator import NullOperator
from .operators.simple_linear_operators import NullOperator
return Linearization(field, NullOperator({}, field.domain))
......@@ -22,7 +22,7 @@ import numpy as np
from ..compat import *
from .linear_operator import LinearOperator
from .null_operator import NullOperator
from .simple_linear_operators import NullOperator
class ChainOperator(LinearOperator):
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from .linear_operator import LinearOperator
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..field import Field
class FieldAdapter(LinearOperator):
def __init__(self, dom, name_dom):
self._domain = MultiDomain.make(dom)
self._name = name_dom
self._target = dom[name_dom]
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._name]
values = tuple(Field.full(dom, 0.) if key != self._name else x
for key, dom in self._domain.items())
return MultiField(self._domain, values)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from .linear_operator import LinearOperator
class GeometryRemover(LinearOperator):
"""Operator which transforms between a structured and an unstructured
domain.
Parameters
----------
domain: Domain, tuple of Domain, or DomainTuple:
the full input domain of the operator.
Notes
-----
The operator will convert every sub-domain of its input domain to an
UnstructuredDomain with the same shape. No weighting by volume factors
is carried out.
"""
def __init__(self, domain):
super(GeometryRemover, self).__init__()
self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(target_list)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.cast_domain(self._target)
return x.cast_domain(self._domain)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi_field import MultiField
from ..operators.linear_operator import LinearOperator
class NullOperator(LinearOperator):
"""Operator corresponding to a matrix of all zeros.
Parameters
----------
domain : DomainTuple or MultiDomain
input domain
target : DomainTuple or MultiDomain
output domain
"""
def __init__(self, domain, target):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
@staticmethod
def _nullfield(dom):
if isinstance(dom, DomainTuple):
return Field.full(dom, 0)
else:
return MultiField.full(dom, 0)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return self._nullfield(self._target)
return self._nullfield(self._domain)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -22,11 +22,13 @@ import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..domains.unstructured_domain import UnstructuredDomain
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from ..sugar import full
from ..field import Field
from ..multi_field import MultiField
class VdotOperator(LinearOperator):
......@@ -113,3 +115,115 @@ class Realizer(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return x.real
class FieldAdapter(LinearOperator):
def __init__(self, dom, name_dom):
self._domain = MultiDomain.make(dom)
self._name = name_dom
self._target = dom[name_dom]
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._name]
values = tuple(Field.full(dom, 0.) if key != self._name else x
for key, dom in self._domain.items())
return MultiField(self._domain, values)
class GeometryRemover(LinearOperator):
"""Operator which transforms between a structured and an unstructured
domain.
Parameters
----------
domain: Domain, tuple of Domain, or DomainTuple:
the full input domain of the operator.
Notes
-----
The operator will convert every sub-domain of its input domain to an
UnstructuredDomain with the same shape. No weighting by volume factors
is carried out.
"""
def __init__(self, domain):
super(GeometryRemover, self).__init__()
self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(target_list)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.cast_domain(self._target)
return x.cast_domain(self._domain)
class NullOperator(LinearOperator):
"""Operator corresponding to a matrix of all zeros.
Parameters
----------
domain : DomainTuple or MultiDomain
input domain
target : DomainTuple or MultiDomain
output domain
"""
def __init__(self, domain, target):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
@staticmethod
def _nullfield(dom):
if isinstance(dom, DomainTuple):
return Field.full(dom, 0)
else:
return MultiField.full(dom, 0)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return self._nullfield(self._target)
return self._nullfield(self._domain)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment