Commit 17e3b653 authored by Philipp Arras's avatar Philipp Arras

Add simplify_for_constant_input

parent aa58d6ae
......@@ -625,6 +625,11 @@ class Field(object):
raise ValueError("domain mismatch")
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self+other
......
......@@ -212,6 +212,11 @@ class MultiField(object):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def extract_part(self, subset):
if subset is self._domain:
return self
return MultiField.from_dict({key: self[key] for key in subset.keys()
if key in self})
def unite(self, other):
if self._domain is other._domain:
return self + other
......
......@@ -137,6 +137,24 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
from .operator import _ConstantOperator
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
......
......@@ -106,6 +106,12 @@ class Operator(NiftyMetaBase()):
def __repr__(self):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
if c_inp is None or c_inp.domain != self.domain:
return None, self
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']:
def func(f):
......@@ -116,6 +122,63 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']:
setattr(Operator, f, func(f))
class _ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key] for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class _ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
self._check_input(x)
if not isinstance(x, Linearization):
return self._output
return x.new(self._output, NullOperator(self._domain, self._target))
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......@@ -176,6 +239,22 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
class _OpProd(Operator):
def __init__(self, op1, op2):
......@@ -204,6 +283,26 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac))
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = _ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
class _OpSum(Operator):
def __init__(self, op1, op2):
......@@ -231,3 +330,24 @@ class _OpSum(Operator):
if lin1._metric is not None and lin2._metric is not None:
res = res.add_metric(lin1._metric + lin2._metric)
return res
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = _ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
......@@ -177,3 +177,25 @@ class NullOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
class _PartialExtractor(LinearOperator):
def __init__(self, domain, target):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(target, MultiDomain):
raise TypeError("MultiDomain expected")
self._domain = domain
self._target = target
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = x.extract(self._target)
assert res.domain is self.target
return res
fld = {key: x[key] if key in x.domain.keys() else Field.full(self._domain[key], 0.)
for key in self._domain.keys()}
assert MultiField.from_dict(fld).domain is self.domain
return MultiField.from_dict(fld)
......@@ -199,3 +199,36 @@ class SumOperator(LinearOperator):
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
from .operator import _ConstantOperator
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
f=[]
o=[]
for op in self._ops:
tf, to = op.simplify_for_constant_input(
c_inp.extract_part(op.domain))
f.append(tf)
o.append(to)
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
fullop = None
for to, n in zip(o, self._neg):
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return None, fullop
from .operator import _ConstCollector
cc = _ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return cc.constfield, fullop
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
from numpy.testing import assert_allclose, assert_equal
class Simplification_Tests(unittest.TestCase):
def test_simplification(self):
from nifty5.operators.operator import _ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10),2.)
op = ift.FFTOperator(f1.domain)
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10)}
f1 = ift.full(dom,2.)
op = ift.FFTOperator(f1.domain["a"]).ducktape("a")
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)}
f1 = ift.full(dom,2.)
pdom = {"a": ift.RGSpace(10)}
f2 = ift.full(pdom,2.)
o1 = ift.FFTOperator(f1.domain["a"])
o2 = ift.FFTOperator(f1.domain["b"])
op = (o1.ducktape("a").ducktape_left("a") +
o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, _ConstantOperator), True)
assert_allclose(op(f1)["a"].local_data, op2(f1)["a"].local_data)
assert_allclose(op(f1)["b"].local_data, op2(f1)["b"].local_data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment