From 9d09bd71aa51a0d508f081402d3df29ff8fd992d Mon Sep 17 00:00:00 2001 From: Rouven Lemmerz <lemmerz@mpa-garching.mpg.de> Date: Tue, 26 May 2020 13:49:52 +0200 Subject: [PATCH] Patch forgotten refactoring, fix docstrings --- nifty6/operator_tree_optimiser.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/nifty6/operator_tree_optimiser.py b/nifty6/operator_tree_optimiser.py index d45023650..61a5fef6b 100644 --- a/nifty6/operator_tree_optimiser.py +++ b/nifty6/operator_tree_optimiser.py @@ -266,24 +266,29 @@ from .multi_field import MultiField from numpy import allclose + def optimise_operator(op): """ Merges redundant operations in the tree structure of an operator. - For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once. - Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants + For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once. + It is supposed to be used on the whole operator chain before doing minimisation. + + Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants ``ChainOp`` and ``SumOperator``. Parameters ---------- - op: Operator + op : Operator + Operator with a tree structure. Returns ------- op_optimised : Operator + Operator with same input/output, but optimised tree structure. Notes ----- - Since operators are compared by id best results are achieved when the following code + Operators are compared only by id, so best results are achieved when the following code >>> from nifty6 import UniformOperator, DomainTuple >>> uni1 = UniformOperator(DomainTuple.scalar_domain() @@ -291,16 +296,18 @@ def optimise_operator(op): >>> op = (uni1 + uni2)*(uni1 + uni2) is replaced by something comparable to + >>> uni = UniformOperator(DomainTuple.scalar_domain()) >>> uni_add = uni + uni >>> op = uni_add * uni_add - After optimisation the operator is comparable in speed to + After optimisation the operator is as fast as + >>> op = (2*uni)**2 """ op_optimised = deepcopy(op) op_optimised = _optimise_operator(op_optimised) - test_field = from_random('normal', op.domain) + test_field = from_random(op.domain) if isinstance(op(test_field), MultiField): for key in op(test_field).keys(): assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10) -- GitLab