diff --git a/nifty6/operator_tree_optimiser.py b/nifty6/operator_tree_optimiser.py index d45023650be677a7b97c8f7ef8bc11529ec09a00..61a5fef6bc7bc05eea5a8b5c9d623f85c13c0a82 100644 --- a/nifty6/operator_tree_optimiser.py +++ b/nifty6/operator_tree_optimiser.py @@ -266,24 +266,29 @@ from .multi_field import MultiField from numpy import allclose + def optimise_operator(op): """ Merges redundant operations in the tree structure of an operator. - For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once. - Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants + For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once. + It is supposed to be used on the whole operator chain before doing minimisation. + + Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants ``ChainOp`` and ``SumOperator``. Parameters ---------- - op: Operator + op : Operator + Operator with a tree structure. Returns ------- op_optimised : Operator + Operator with same input/output, but optimised tree structure. Notes ----- - Since operators are compared by id best results are achieved when the following code + Operators are compared only by id, so best results are achieved when the following code >>> from nifty6 import UniformOperator, DomainTuple >>> uni1 = UniformOperator(DomainTuple.scalar_domain() @@ -291,16 +296,18 @@ def optimise_operator(op): >>> op = (uni1 + uni2)*(uni1 + uni2) is replaced by something comparable to + >>> uni = UniformOperator(DomainTuple.scalar_domain()) >>> uni_add = uni + uni >>> op = uni_add * uni_add - After optimisation the operator is comparable in speed to + After optimisation the operator is as fast as + >>> op = (2*uni)**2 """ op_optimised = deepcopy(op) op_optimised = _optimise_operator(op_optimised) - test_field = from_random('normal', op.domain) + test_field = from_random(op.domain) if isinstance(op(test_field), MultiField): for key in op(test_field).keys(): assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)