Commit 59a72c4b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'patch_optimiser' into 'NIFTy_6'

Fix optimiser

See merge request !506
parents 2884f274 1362c613
Pipeline #75544 passed with stages
in 10 minutes and 21 seconds
......@@ -266,24 +266,29 @@ from .multi_field import MultiField
from numpy import allclose
def optimise_operator(op):
Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.
Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants
For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once.
It is supposed to be used on the whole operator chain before doing minimisation.
Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants
``ChainOp`` and ``SumOperator``.
op: Operator
op : Operator
Operator with a tree structure.
op_optimised : Operator
Operator with same input/output, but optimised tree structure.
Since operators are compared by id best results are achieved when the following code
Operators are compared only by id, so best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain()
......@@ -291,16 +296,18 @@ def optimise_operator(op):
>>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to
>>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni
>>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to
After optimisation the operator is as fast as
>>> op = (2*uni)**2
op_optimised = deepcopy(op)
op_optimised = _optimise_operator(op_optimised)
test_field = from_random('normal', op.domain)
test_field = from_random(op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)
......@@ -39,7 +39,6 @@ class CountingOp(ift.Operator):
def test_operator_tree_optimiser():
dom = ift.RGSpace(10, harmonic=True)
hdom = dom.get_default_codomain()
cop1 = CountingOp(dom)
op1 = (ift.UniformOperator(dom, -1, 2)@cop1).ducktape('a')
cop2 = CountingOp(dom)
......@@ -58,3 +57,5 @@ def test_operator_tree_optimiser():
op = ift.operator_tree_optimiser._optimise_operator(op)
assert_allclose(op(fld).val, op_orig(fld).val, rtol=np.finfo(np.float64).eps)
assert_(1 == ( (cop4.count-1) * cop3.count * cop2.count * cop1.count))
# test testing
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment