Commit 901bc10c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

revamp combined operators

parent b2330148
......@@ -53,7 +53,7 @@ if __name__ == '__main__':
mask = np.ones(position_space.shape)
elif mode == 1:
# Two dimensional regular grid with chess mask
position_space = ift.RGSpace([128, 128])
position_space = ift.RGSpace([4096, 4096])
mask = make_chess_mask(position_space)
else:
# Sphere with half of its locations randomly masked
......@@ -95,7 +95,7 @@ if __name__ == '__main__':
j = R.adjoint_times(N.inverse_times(data))
D_inv = R.adjoint(N.inverse(R)) + S.inverse
# Make it invertible
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3, name="blah")
D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
# WIENER FILTER
......
......@@ -77,8 +77,8 @@ if __name__ == '__main__':
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=100)
minimizer = ift.RelaxedNewton(ic_newton)
# minimizer = ift.VL_BFGS(ic_newton)
# minimizer = ift.NewtonCG(1e-10, 100, True)
# minimizer = ift.L_BFGS_B(1e-10, 1e-5, 100, 10, True)
# minimizer = ift.NewtonCG(xtol=1e-10, maxiter=100, disp=True)
# minimizer = ift.L_BFGS_B(ftol=1e-10, gtol=1e-5, maxiter=100, maxcor=20, disp=True)
# build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling)
......@@ -95,6 +95,7 @@ if __name__ == '__main__':
N_samples = 20
for i in range(2):
metric = H(ift.Linearization.make_var(position)).metric
print(metric)
samples = [metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
......
......@@ -614,7 +614,10 @@ class Field(object):
return self
def unite(self, other):
return self + other
return self+other
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def positive_tanh(self):
return 0.5*(1.+self.tanh())
......
......@@ -61,22 +61,28 @@ class Linearization(object):
def real(self):
return Linearization(self._val.real, self._jac.real)
def __add__(self, other):
def _myadd(self, other, neg):
if isinstance(other, Linearization):
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, False)
met = self._metric._myadd(other._metric, neg)
return Linearization(
self._val.unite(other._val),
self._jac._myadd(other._jac, False), met)
self._val.flexible_addsub(other._val, neg),
self._jac._myadd(other._jac, neg), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
if neg:
return Linearization(self._val-other, self._jac, self._metric)
else:
return Linearization(self._val+other, self._jac, self._metric)
def __add__(self, other):
return self._myadd(other, False)
def __radd__(self, other):
return self.__add__(other)
return self._myadd(other, False)
def __sub__(self, other):
return self.__add__(-other)
return self._myadd(other, True)
def __rsub__(self, other):
return (-self).__add__(other)
......
......@@ -213,6 +213,17 @@ class MultiField(object):
res[key] = res[key]+val if key in res else val
return MultiField.from_dict(res)
def flexible_addsub(self, other, neg):
if self._domain is other._domain:
return self-other if neg else self+other
res = self.to_dict()
for key, val in other.items():
if key in res:
res[key] = res[key]-val if neg else res[key]+val
else:
res[key] = -val if neg else val
return MultiField.from_dict(res)
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
......
......@@ -22,6 +22,9 @@ import numpy as np
from ..compat import *
from .linear_operator import LinearOperator
from .. import utilities
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
from .simple_linear_operators import NullOperator
......@@ -40,24 +43,19 @@ class ChainOperator(LinearOperator):
@staticmethod
def simplify(ops):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
# verify domains
for i in range(len(ops)-1):
if ops[i+1].target != ops[i].domain:
raise ValueError("domain mismatch")
# Step 2: unpack ChainOperators
# unpack ChainOperators
opsnew = []
for op in ops:
if isinstance(op, ChainOperator):
opsnew += op._ops
else:
opsnew.append(op)
opsnew += op._ops if isinstance(op, ChainOperator) else [op]
ops = opsnew
# Step 2.5: check for NullOperators
# check for NullOperators
if any(isinstance(op, NullOperator) for op in ops):
ops = (NullOperator(ops[-1].domain, ops[0].target),)
# Step 3: collect ScalingOperators
# collect ScalingOperators
fct = 1.
opsnew = []
lastdom = ops[-1].domain
......@@ -77,7 +75,7 @@ class ChainOperator(LinearOperator):
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(fct, lastdom))
ops = opsnew
# Step 4: combine DiagonalOperators where possible
# combine DiagonalOperators where possible
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
......@@ -87,7 +85,7 @@ class ChainOperator(LinearOperator):
else:
opsnew.append(op)
ops = opsnew
# Step 5: combine BlockDiagonalOperators where possible
# combine BlockDiagonalOperators where possible
from .block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
......@@ -137,3 +135,18 @@ class ChainOperator(LinearOperator):
for op in t_ops:
x = op.apply(x, mode)
return x
def draw_sample(self, from_inverse=False, dtype=np.float64):
from ..sugar import from_random
if len(self._ops) == 1:
return self._ops[0].draw_sample(from_inverse, dtype)
samp = from_random(random_type="normal", domain=self._domain,
dtype=dtype)
for op in self._ops:
samp = op.process_sample(samp, from_inverse)
return samp
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n"+utilities.indent(subs)
......@@ -81,9 +81,6 @@ class LinearOperator(Operator):
def _tgt(self, mode):
return self.domain if (mode & 6) else self.target
def __init__(self):
pass
def _flip_modes(self, trafo):
from .operator_adapter import OperatorAdapter
return self if trafo == 0 else OperatorAdapter(self, trafo)
......@@ -117,11 +114,8 @@ class LinearOperator(Operator):
return Operator.__rmatmul__(self, other)
def _myadd(self, other, oneg):
if self.domain == other.domain and self.target == other.target:
from .sum_operator import SumOperator
return SumOperator.make((self, other), (False, oneg))
from .relaxed_sum_operator import RelaxedSumOperator
return RelaxedSumOperator((self, -other if oneg else other))
from .sum_operator import SumOperator
return SumOperator.make((self, other), (False, oneg))
def __add__(self, other):
if isinstance(other, LinearOperator):
......
......@@ -144,16 +144,16 @@ class _OpProd(Operator):
return Linearization(lin1._val*lin2._val, op(x.jac))
class _OpSum(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
from ..sugar import domain_union
super(_OpSum, self).__init__(ops, _callingfrommake)
self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops])
def apply(self, x):
res = None
for op in self._ops:
tmp = op(x.extract(op.domain))
res = tmp if res is None else res.unite(tmp)
return res
# class _OpSum(_CombinedOperator):
# def __init__(self, ops, _callingfrommake=False):
# from ..sugar import domain_union
# super(_OpSum, self).__init__(ops, _callingfrommake)
# self._domain = domain_union([op.domain for op in self._ops])
# self._target = domain_union([op.target for op in self._ops])
#
# def apply(self, x):
# res = None
# for op in self._ops:
# tmp = op(x.extract(op.domain))
# res = tmp if res is None else res.unite(tmp)
# return res
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..utilities import my_sum
from .linear_operator import LinearOperator
from ..sugar import domain_union
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
class RelaxedSumOperator(LinearOperator):
"""Class representing sums of operators with compatible domains."""
def __init__(self, ops):
self._ops = ops
self._domain = domain_union([op.domain for op in ops])
self._target = domain_union([op.target for op in ops])
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
@property
def adjoint(self):
return RelaxedSumOperator([op.adjoint for op in self._ops])
def apply(self, x, mode):
self._check_input(x, mode)
res = None
for op in self._ops:
tmp = op.apply(x.extract(op._dom(mode)), mode)
res = tmp if res is None else res.unite(tmp)
return res
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError(
"cannot draw from inverse of this operator")
res = None
for op in self._ops:
tmp = op.draw_sample(from_inverse, dtype)
res = tmp if res is None else res.unite(tmp)
return res
......@@ -23,6 +23,11 @@ import numpy as np
from ..compat import *
from ..utilities import my_sum
from .linear_operator import LinearOperator
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
from .block_diagonal_operator import BlockDiagonalOperator
from collections import defaultdict
from ..sugar import domain_union
class SumOperator(LinearOperator):
......@@ -43,14 +48,8 @@ class SumOperator(LinearOperator):
def simplify(ops, neg):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
dom = ops[0].domain
tgt = ops[0].target
for op in ops[1:]:
if dom is not op.domain or tgt is not op.target:
raise ValueError("Domain mismatch")
# Step 2: unpack SumOperators
# unpack SumOperators
opsnew = []
negnew = []
for op, ng in zip(ops, neg):
......@@ -65,75 +64,90 @@ class SumOperator(LinearOperator):
negnew.append(ng)
ops = opsnew
neg = negnew
# Step 3: collect ScalingOperators
sum = 0.
opsnew = []
negnew = []
lastdom = ops[-1].domain
# sort operators according to domains
sorted = defaultdict(list)
for op, ng in zip(ops, neg):
if isinstance(op, ScalingOperator):
sum += op._factor * (-1 if ng else 1)
else:
opsnew.append(op)
negnew.append(ng)
if sum != 0.:
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = opsnew[i]._add(sum)
sum = 0.
break
if sum != 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(sum, lastdom))
negnew.append(False)
ops = opsnew
neg = negnew
# Step 4: combine DiagonalOperators where possible
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], DiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
sorted[(op.domain, op.target)].append((op, ng))
xxops = []
xxneg = []
for opset in sorted.values():
# collect ScalingOperators
sum = 0.
opsnew = []
negnew = []
for op, ng in opset:
if isinstance(op, ScalingOperator):
sum += op._factor * (-1 if ng else 1)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
# Step 5: combine BlockDiagonalOperators where possible
from .block_diagonal_operator import BlockDiagonalOperator
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg, dom, tgt
negnew.append(ng)
lastdom = opset[0][0].domain
if sum != 0.:
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = opsnew[i]._add(sum)
sum = 0.
break
if sum != 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(sum, lastdom))
negnew.append(False)
ops = opsnew
neg = negnew
# Step 4: combine DiagonalOperators where possible
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], DiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
# combine BlockDiagonalOperators where possible
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
xxops += opsnew
xxneg += negnew
dom = domain_union([op.domain for op in xxops])
tgt = domain_union([op.target for op in xxops])
return xxops, xxneg, dom, tgt
@staticmethod
def make(ops, neg):
......@@ -154,8 +168,8 @@ class SumOperator(LinearOperator):
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg, dom, tgt = SumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
return ops[0]
if len(ops) == 1:
return -ops[0] if neg[0] else ops[0]
return SumOperator(ops, neg, dom, tgt, _callingfrommake=True)
@property
......@@ -166,18 +180,19 @@ class SumOperator(LinearOperator):
self._check_mode(mode)
res = None
for op, neg in zip(self._ops, self._neg):
tmp = op.apply(x.extract(op._dom(mode)), mode)
if res is None:
res = -op.apply(x, mode) if neg else op.apply(x, mode)
res = -tmp if neg else tmp
else:
if neg:
res = res - op.apply(x, mode)
else:
res = res + op.apply(x, mode)
res = res.flexible_addsub(tmp, neg)
return res
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError(
"cannot draw from inverse of this operator")
return my_sum(map(lambda op: op.draw_sample(from_inverse, dtype),
self._ops))
res = None
for op in self._ops:
tmp = op.draw_sample(from_inverse, dtype)
res = tmp if res is None else res.unite(tmp)
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment