Commit 9f23622d authored by Martin Reinecke's avatar Martin Reinecke

cosmetics and one improvement(?)

parent e1157422
Pipeline #31036 passed with stages
in 1 minute and 24 seconds
......@@ -74,7 +74,8 @@ if __name__ == "__main__":
# Wiener filter
j = R.adjoint_times(N.inverse_times(data))
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
sampling_ctrl = ift.GradientNormController(name="sampling",tol_abs_gradnorm=1e2)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=1e2)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
......
......@@ -48,7 +48,8 @@ if __name__ == "__main__":
# Wiener filter
j = R.adjoint_times(N.inverse_times(data))
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2)
sampling_ctrl = ift.GradientNormController(name="sampling",tol_abs_gradnorm=2e1)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=2e1)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
......
......@@ -3,4 +3,5 @@ from .constant import ConstantModel
from .local_nonlinearity import LocalModel
from .position import PositionModel
__all__ = ['NonlinearOperator', 'ConstantModel', 'LocalModel', 'PositionModel', 'LinearModel']
__all__ = ['NonlinearOperator', 'ConstantModel', 'LocalModel', 'PositionModel',
'LinearModel']
......@@ -16,7 +16,8 @@ class LocalModel(NonlinearOperator):
self._value = nonlinearity(self._inp.value)
# Gradient
self._gradient = makeOp(self._nonlinearity.derivative(self._inp.value))*self._inp.gradient
self._gradient = makeOp(
self._nonlinearity.derivative(self._inp.value))*self._inp.gradient
def at(self, position):
return self.__class__(position, self._inp, self._nonlinearity)
......@@ -48,7 +48,6 @@ class NonlinearOperator(object):
raise NotImplementedError
def _joint_position(op1, op2):
a = op1.position._val
b = op2.position._val
......@@ -60,9 +59,9 @@ def _joint_position(op1, op2):
class Mul(NonlinearOperator):
"""
Please note: If you multiply two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum. You shouldn't do that anyways.
Please note: If you multiply two operators which share some keys in the
position but have different values there, it is not guaranteed which value
will be used for the sum. You shouldn't do that anyways.
"""
def __init__(self, position, op1, op2):
super(Mul, self).__init__(position)
......@@ -71,7 +70,8 @@ class Mul(NonlinearOperator):
self._op2 = op2.at(position)
self._value = self._op1.value * self._op2.value
self._gradient = ift.makeOp(self._op1.value) * self._op2.gradient + ift.makeOp(self._op2.value) * self._op1.gradient
self._gradient = (ift.makeOp(self._op1.value) * self._op2.gradient +
ift.makeOp(self._op2.value) * self._op1.gradient)
@staticmethod
def make(op1, op2):
......
......@@ -11,9 +11,7 @@ class MultiSkyGradientOperator(LinearOperator):
self._domain = MultiDomain.make(domain)
# Check compatibility
# assert gradients_domain.items() <= self.domain.items()
# FIXME This is a python2 hack!
assert all(item in self.domain.items() for item in gradients_domain.items())
assert set(gradients_domain.items()) <= set(self.domain.items())
self._target = target
for grad in gradients.values():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment