Scheduled maintenance on Monday 2019-06-24 between 10:00-11:00 CEST

Commit abc660d4 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'simplify_for_const' into 'NIFTy_5'

Simplify for const

See merge request !295
parents 11f686dd 61c290f5
Pipeline #44819 passed with stages
in 19 minutes and 22 seconds
......@@ -626,6 +626,11 @@ class Field(object):
raise ValueError("domain mismatch")
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self+other
......
......@@ -217,6 +217,12 @@ class MultiField(object):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def extract_part(self, subset):
if subset is self._domain:
return self
return MultiField.from_dict({key: self[key] for key in subset.keys()
if key in self})
def unite(self, other):
"""Merges two MultiFields on potentially different MultiDomains.
......
......@@ -138,6 +138,17 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
......
......@@ -146,6 +146,17 @@ class Operator(metaclass=NiftyMeta):
def __repr__(self):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
......@@ -157,6 +168,72 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
setattr(Operator, f, func(f))
class _ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class _ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
from ..domain_tuple import DomainTuple
self._check_input(x)
if not isinstance(x, Linearization):
return self._output
if x.want_metric and self._target is DomainTuple.scalar_domain():
met = NullOperator(self._domain, self._domain)
else:
met = None
return x.new(self._output, NullOperator(self._domain, self._target),
met)
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......@@ -229,6 +306,17 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "_OpChain:\n" + indent(subs)
......@@ -261,6 +349,21 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac))
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = _ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpProd:\n"+indent(subs)
......@@ -293,6 +396,21 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric + lin2._metric)
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = _ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpSum:\n"+indent(subs)
......@@ -35,14 +35,6 @@ class ScalingOperator(EndomorphicOperator):
-----
:class:`Operator` supports the multiplication with a scalar. So one does
not need instantiate :class:`ScalingOperator` explicitly in most cases.
Formally, this operator always supports all operation modes (times,
adjoint_times, inverse_times and inverse_adjoint_times), even if `factor`
is 0 or infinity. It is the user's responsibility to apply the operator
only in appropriate ways (e.g. call inverse_times only if `factor` is
nonzero).
This shortcoming will hopefully be fixed in the future.
"""
def __init__(self, factor, domain):
......@@ -52,7 +44,10 @@ class ScalingOperator(EndomorphicOperator):
raise TypeError("Scalar required")
self._factor = factor
self._domain = makeDomain(domain)
self._capability = self._all_ops
if self._factor == 0.:
self._capability = self.TIMES | self.ADJOINT_TIMES
else:
self._capability = self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -315,3 +315,23 @@ class NullOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
class _PartialExtractor(LinearOperator):
def __init__(self, domain, target):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(target, MultiDomain):
raise TypeError("MultiDomain expected")
self._domain = domain
self._target = target
for key in self._target.keys():
if not (self._domain[key] is not self._target[key]):
raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
......@@ -23,6 +23,7 @@ from ..sugar import domain_union
from ..utilities import indent
from .block_diagonal_operator import BlockDiagonalOperator
from .linear_operator import LinearOperator
from .simple_linear_operators import NullOperator
class SumOperator(LinearOperator):
......@@ -59,6 +60,9 @@ class SumOperator(LinearOperator):
negnew += [not n for n in op._neg]
else:
negnew += list(op._neg)
# FIXME: this needs some more work to keep the domain and target unchanged!
# elif isinstance(op, NullOperator):
# pass
else:
opsnew.append(op)
negnew.append(ng)
......@@ -193,6 +197,9 @@ class SumOperator(LinearOperator):
"cannot draw from inverse of this operator")
res = None
for op in self._ops:
from .simple_linear_operators import NullOperator
if isinstance(op, NullOperator):
continue
tmp = op.draw_sample(from_inverse, dtype)
res = tmp if res is None else res.unite(tmp)
return res
......@@ -200,3 +207,29 @@ class SumOperator(LinearOperator):
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
f = []
o = []
for op in self._ops:
tf, to = op.simplify_for_constant_input(
c_inp.extract_part(op.domain))
f.append(tf)
o.append(to)
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
fullop = None
for to, n in zip(o, self._neg):
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return None, fullop
from .operator import _ConstCollector
cc = _ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return cc.constfield, fullop
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose, assert_equal
import nifty5 as ift
def test_simplification():
from nifty5.operators.operator import _ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10),2.)
op = ift.FFTOperator(f1.domain)
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10)}
f1 = ift.full(dom,2.)
op = ift.FFTOperator(f1.domain["a"]).ducktape("a")
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)}
f1 = ift.full(dom,2.)
pdom = {"a": ift.RGSpace(10)}
f2 = ift.full(pdom,2.)
o1 = ift.FFTOperator(f1.domain["a"])
o2 = ift.FFTOperator(f1.domain["b"])
op = (o1.ducktape("a").ducktape_left("a") +
o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, _ConstantOperator), True)
assert_allclose(op(f1)["a"].local_data, op2(f1)["a"].local_data)
assert_allclose(op(f1)["b"].local_data, op2(f1)["b"].local_data)
lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True)
assert_allclose(op(lin).val["a"].local_data,
op2(lin).val["a"].local_data)
assert_allclose(op(lin).val["b"].local_data,
op2(lin).val["b"].local_data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment