Commit 86b8fa17 authored by Philipp Arras's avatar Philipp Arras

Move NonlinearOperators to nifty

parent a940abf2
Pipeline #31011 passed with stages
in 1 minute and 23 seconds
......@@ -5,6 +5,7 @@ from .domains import *
from .domain_tuple import DomainTuple
from .field import Field
from .operators import *
from .nonlinear_operators import *
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator
from .minimization import *
......
from .nonlinear_operator import NonlinearOperator, LinearModel
from .constant import ConstantModel
from .local_nonlinearity import LocalModel
from .position import PositionModel
__all__ = ['NonlinearOperator', 'ConstantModel', 'LocalModel', 'PositionModel', 'LinearModel']
from ..operators import MultiSkyGradientOperator
from .nonlinear_operator import NonlinearOperator
class ConstantModel(NonlinearOperator):
def __init__(self, position, constant):
super(ConstantModel, self).__init__(position)
self._constant = constant
self._value = self._constant
self._gradient = MultiSkyGradientOperator({},
position.domain,
self.value.domain)
def at(self, position):
return self.__class__(position, self._constant)
from nifty4.sugar import makeOp
from .nonlinear_operator import NonlinearOperator
class LocalModel(NonlinearOperator):
def __init__(self, position, inp, nonlinearity):
"""
Computes nonlinearity(inp)
"""
super(LocalModel, self).__init__(position)
self._inp = inp.at(self.position)
self._nonlinearity = nonlinearity
self._value = nonlinearity(self._inp.value)
# Gradient
self._gradient = makeOp(self._nonlinearity.derivative(self._inp.value))*self._inp.gradient
def at(self, position):
return self.__class__(position, self._inp, self._nonlinearity)
import nifty4 as ift
from ..operators import LinearOperator
from .selection_operator import SelectionOperator
class NonlinearOperator(object):
def __init__(self, position):
self._position = position
def at(self, position):
raise NotImplementedError
@property
def position(self):
return self._position
@property
def value(self):
return self._value
@property
def gradient(self):
return self._gradient
def __getitem__(self, key):
sel = SelectionOperator(self.value.domain, key)
return LinearModel(self.position, self, sel)
# TODO Support addition and multiplication with fields
def __add__(self, other):
assert isinstance(other, NonlinearOperator)
return Add.make(self, other)
def __sub__(self, other):
assert isinstance(other, NonlinearOperator)
return Add.make(self, (-1) * other)
def __mul__(self, other):
if isinstance(other, (float, int)):
return ScalarMul(self._position, other, self)
if isinstance(other, NonlinearOperator):
return Mul.make(self, other)
raise NotImplementedError
def __rmul__(self, other):
if isinstance(other, (float, int)):
return self.__mul__(other)
raise NotImplementedError
def _joint_position(op1, op2):
a = op1.position._val
b = op2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
return ift.MultiField(ab)
class Mul(NonlinearOperator):
"""
Please note: If you multiply two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum. You shouldn't do that anyways.
"""
def __init__(self, position, op1, op2):
super(Mul, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value * self._op2.value
self._gradient = ift.makeOp(self._op1.value) * self._op2.gradient + ift.makeOp(self._op2.value) * self._op1.gradient
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Mul(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
class Add(NonlinearOperator):
"""
Please note: If you add two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum. You shouldn't do that anyways.
"""
def __init__(self, position, op1, op2):
super(Add, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value + self._op2.value
self._gradient = self._op1.gradient + self._op2.gradient
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Add(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
class ScalarMul(NonlinearOperator):
def __init__(self, position, factor, op):
super(ScalarMul, self).__init__(position)
assert isinstance(factor, (float, int))
assert isinstance(op, NonlinearOperator)
self._op = op.at(position)
self._factor = factor
self._value = self._factor * self._op.value
self._gradient = self._factor * self._op.gradient
def at(self, position):
return self.__class__(position, self._factor, self._op)
class LinearModel(NonlinearOperator):
def __init__(self, position, inp, lin_op):
"""
Computes lin_op(inp) where lin_op is a Linear Operator
"""
super(LinearModel, self).__init__(position)
if not isinstance(lin_op, LinearOperator):
raise TypeError("needs a LinearOperator as input")
self._inp = inp.at(position)
self._lin_op = lin_op
# FIXME This is a dirty hack!
if isinstance(self._lin_op, SelectionOperator):
self._lin_op = SelectionOperator(self._inp.value.domain,
self._lin_op._key)
self._value = self._lin_op(self._inp.value)
self._gradient = self._lin_op*self._inp.gradient
def at(self, position):
return self.__class__(position, self._inp, self._lin_op)
import nifty4 as ift
from .nonlinear_operator import NonlinearOperator
class PositionModel(NonlinearOperator):
"""
Returns the MultiField.
"""
def __init__(self, position):
super(PositionModel, self).__init__(position)
self._value = position
self._gradient = ift.ScalingOperator(1., position.domain)
def at(self, position):
return self.__class__(position)
from ..multi import MultiDomain, MultiField
from ..operators import LinearOperator
from ..sugar import full
class SelectionOperator(LinearOperator):
def __init__(self, domain, key):
if not isinstance(domain, MultiDomain):
raise TypeError("Domain must be a MultiDomain")
self._target = domain[key]
self._domain = domain
self._key = key
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._key].copy()
else:
result = {}
for key, val in self.domain.items():
if key != self._key:
result[key] = full(val, 0.)
else:
result[key] = x.copy()
return MultiField(result)
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
from .harmonic_transform_operator import HarmonicTransformOperator
from .dof_distributor import DOFDistributor
from .endomorphic_operator import EndomorphicOperator
from .fft_operator import FFTOperator
from .fft_smoothing_operator import FFTSmoothingOperator
from .geometry_remover import GeometryRemover
from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator
from .smoothness_operator import SmoothnessOperator
from .linear_operator import LinearOperator
from .power_distributor import PowerDistributor
from .inversion_enabler import InversionEnabler
from .sandwich_operator import SandwichOperator
from .sampling_enabler import SamplingEnabler
from .dof_distributor import DOFDistributor
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .sky_gradient_operator import MultiSkyGradientOperator
from .smoothness_operator import SmoothnessOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"DiagonalOperator", "HarmonicTransformOperator", "FFTOperator",
"FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator", "SamplingEnabler",
"DOFDistributor"]
"DOFDistributor", "MultiSkyGradientOperator"]
......@@ -18,7 +18,6 @@
import abc
from ..utilities import NiftyMetaBase
from ..field import Field
import numpy as np
......
from ..multi import MultiDomain, MultiField
from ..sugar import full
from .linear_operator import LinearOperator
class MultiSkyGradientOperator(LinearOperator):
def __init__(self, gradients, domain, target):
super(MultiSkyGradientOperator, self).__init__()
self._gradients = gradients
gradients_domain = MultiField(self._gradients).domain
self._domain = MultiDomain.make(domain)
# Check compatibility
# assert gradients_domain.items() <= self.domain.items()
# FIXME This is a python2 hack!
assert all(item in self.domain.items() for item in gradients_domain.items())
self._target = target
for grad in gradients.values():
if self._target != grad.target:
raise TypeError(
'All gradients have to have the same target domain')
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
@property
def gradients(self):
return self._gradients
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = None
for key, op in self._gradients.items():
if res is None:
res = op(x[key])
else:
res += op(x[key])
# Needed if gradients == {}
if res is None:
res = full(self.target, 0.)
assert res.domain == self.target
else:
grad_keys = self._gradients.keys()
res = {}
for dd in self.domain:
if dd in grad_keys:
res[dd] = self._gradients[dd].adjoint_times(x)
else:
res[dd] = full(self.domain[dd], 0.)
res = MultiField(res)
assert res.domain == self.domain
return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment