Commit 5cf3ed2d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent a53fa55e
Pipeline #28458 passed with stages
in 2 minutes and 38 seconds
......@@ -66,6 +66,8 @@ class DomainTuple(object):
"""
if isinstance(domain, DomainTuple):
return domain
if isinstance(domain, dict):
return domain
domain = DomainTuple._parse_domain(domain)
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
......
from .multi_domain import MultiDomain
from .multi_field import MultiField
#from .multi_linear_operator import MultiLinearOperator
#from .multi_endomorphic_operator import MultiEndomorphicOperator
#from .multi_chain_operator import MultiChainOperator
#from .multi_sum_operator import MultiSumOperator
#from .multi_scaling_operator import MultiScalingOperator
__all__ = ["MultiDomain", "MultiField"]
#, "MultiLinearOperator",
# "MultiEndomorphicOperator", "MultiChainOperator",
# "MultiSumOperator", "MultiScalingOperator"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .multi_linear_operator import MultiLinearOperator
class MultiChainOperator(MultiLinearOperator):
"""Class representing chains of multi-operators."""
def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiChainOperator, self).__init__()
self._ops = ops
self._capability = self._all_ops
for op in ops:
self._capability &= op.capability
@staticmethod
def make(ops):
ops = tuple(ops)
if len(ops) == 1:
return ops[0]
return MultiChainOperator(ops, _callingfrommake=True)
@property
def domain(self):
return self._ops[-1].domain
@property
def target(self):
return self._ops[0].target
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if trafo == ADJ or trafo == INV:
return self.make([op._flip_modes(trafo)
for op in reversed(self._ops)])
if trafo == ADJ | INV:
return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation")
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_input(x, mode)
t_ops = self._ops if mode & self._backwards else reversed(self._ops)
for op in t_ops:
x = op.apply(x, mode)
return x
import numpy as np
from .multi_linear_operator import MultiLinearOperator
class MultiEndomorphicOperator(MultiLinearOperator):
"""
Class for multi endomorphic operators.
By definition, domain and target are the same in
EndomorphicOperator.
"""
@property
def target(self):
"""
MultiDomain : returns :attr:`domain`
Returns `self.domain`, because this is also the target domain
for endomorphic operators.
"""
return self.domain
def draw_sample(self, from_inverse=False, dtype=np.float64):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator. If from_inverse is True, the sample
is drawn from the inverse of the operator.
Returns
-------
MultiField
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
from ..operators.linear_operator import LinearOperator
from .multi_field import MultiField
class MultiLinearOperator(LinearOperator):
@staticmethod
def _toOperator(thing, dom):
#from .multi_scaling_operator import ScalingOperator
if isinstance(thing, MultiLinearOperator):
return thing
#if np.isscalar(thing):
# return MultiScalingOperator(thing, dom)
return NotImplemented
def __mul__(self, other):
from .multi_chain_operator import MultiChainOperator
other = self._toOperator(other, self.domain)
return MultiChainOperator.make([self, other])
def __rmul__(self, other):
from .multi_chain_operator import MultiChainOperator
other = self._toOperator(other, self.target)
return MultiChainOperator.make([other, self])
def __add__(self, other):
from .multi_sum_operator import MultiSumOperator
other = self._toOperator(other, self.domain)
return MultiSumOperator.make([self, other], [False, False])
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
from .multi_sum_operator import MultiSumOperator
other = self._toOperator(other, self.domain)
return MultiSumOperator.make([self, other], [False, True])
def __rsub__(self, other):
from .multi_sum_operator import MultiSumOperator
other = self._toOperator(other, self.domain)
return MultiSumOperator.make([other, self], [False, True])
def _check_input(self, x, mode):
if not isinstance(x, MultiField):
raise ValueError("supplied object is not a `MultiField`.")
self._check_mode(mode)
if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.")
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .multi_endomorphic_operator import MultiEndomorphicOperator
class MultiScalingOperator(MultiEndomorphicOperator):
"""Operator which multiplies a Multifield with a scalar.
The NIFTy MultiScalingOperator class is a subclass derived from the
EndomorphicOperator. It multiplies an input field with a given factor.
Parameters
----------
factor : scalar
The multiplication factor
domain : MultiDomain
The domain on which the Operator's input Field lives.
Notes
-----
Formally, this operator always supports all operation modes (times,
adjoint_times, inverse_times and inverse_adjoint_times), even if `factor`
is 0 or infinity. It is the user's responsibility to apply the operator
only in appropriate ways (e.g. call inverse_times only if `factor` is
nonzero).
This shortcoming will hopefully be fixed in the future.
"""
def __init__(self, factor, domain):
super(MultiScalingOperator, self).__init__()
if not np.isscalar(factor):
raise TypeError("Scalar required")
self._factor = factor
self._domain = domain
def apply(self, x, mode):
self._check_input(x, mode)
if self._factor == 1.:
return x.copy()
if mode == self.TIMES:
return x*self._factor
elif mode == self.ADJOINT_TIMES:
return x*np.conj(self._factor)
elif mode == self.INVERSE_TIMES:
return x*(1./self._factor)
else:
return x*(1./np.conj(self._factor))
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if trafo == ADJ and np.issubdtype(type(self._factor), np.floating):
return self
if trafo == ADJ:
return ScalingOperator(np.conj(self._factor), self._domain)
elif trafo == INV:
return ScalingOperator(1./self._factor, self._domain)
elif trafo == ADJ | INV:
return ScalingOperator(1./np.conj(self._factor), self._domain)
raise ValueError("invalid operator transformation")
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._all_ops
def draw_sample(self, from_inverse=False, dtype=np.float64):
fct = self._factor
if fct.imag != 0. or fct.real <= 0.:
raise ValueError("operator not positive definite")
fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
return Field.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .multi_linear_operator import MultiLinearOperator
import numpy as np
class MultiSumOperator(MultiLinearOperator):
"""Class representing sums of multi-operators."""
def __init__(self, ops, neg, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MultiSumOperator, self).__init__()
self._ops = ops
self._neg = neg
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
@staticmethod
def make(ops, neg):
ops = tuple(ops)
neg = tuple(neg)
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
#ops, neg = MultiSumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
return ops[0]
return MultiSumOperator(ops, neg, _callingfrommake=True)
@property
def domain(self):
return self._ops[0].domain
@property
def target(self):
return self._ops[0].target
@property
def adjoint(self):
return self.make([op.adjoint for op in self._ops], self._neg)
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_input(x, mode)
for i, op in enumerate(self._ops):
if i == 0:
res = -op.apply(x, mode) if self._neg[i] else op.apply(x, mode)
else:
if self._neg[i]:
res -= op.apply(x, mode)
else:
res += op.apply(x, mode)
return res
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
res = self._ops[0].draw_sample(from_inverse, dtype)
for op in self._ops[1:]:
res += op.draw_sample(from_inverse, dtype)
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment