Commit 59a72c4b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'patch_optimiser' into 'NIFTy_6'

Fix optimiser

See merge request !506
parents 2884f274 1362c613
Pipeline #75544 passed with stages
in 10 minutes and 21 seconds
......@@ -266,24 +266,29 @@ from .multi_field import MultiField
from numpy import allclose
def optimise_operator(op):
"""
Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.
Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants
For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once.
It is supposed to be used on the whole operator chain before doing minimisation.
Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants
``ChainOp`` and ``SumOperator``.
Parameters
----------
op: Operator
op : Operator
Operator with a tree structure.
Returns
-------
op_optimised : Operator
Operator with same input/output, but optimised tree structure.
Notes
-----
Since operators are compared by id best results are achieved when the following code
Operators are compared only by id, so best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain()
......@@ -291,16 +296,18 @@ def optimise_operator(op):
>>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to
>>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni
>>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to
After optimisation the operator is as fast as
>>> op = (2*uni)**2
"""
op_optimised = deepcopy(op)
op_optimised = _optimise_operator(op_optimised)
test_field = from_random('normal', op.domain)
test_field = from_random(op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)
......
......@@ -39,7 +39,6 @@ class CountingOp(ift.Operator):
def test_operator_tree_optimiser():
dom = ift.RGSpace(10, harmonic=True)
hdom = dom.get_default_codomain()
cop1 = CountingOp(dom)
op1 = (ift.UniformOperator(dom, -1, 2)@cop1).ducktape('a')
cop2 = CountingOp(dom)
......@@ -58,3 +57,5 @@ def test_operator_tree_optimiser():
op = ift.operator_tree_optimiser._optimise_operator(op)
assert_allclose(op(fld).val, op_orig(fld).val, rtol=np.finfo(np.float64).eps)
assert_(1 == ( (cop4.count-1) * cop3.count * cop2.count * cop1.count))
# test testing
ift.optimise_operator(op_orig)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment