Commit 59a72c4b by Martin Reinecke

### Merge branch 'patch_optimiser' into 'NIFTy_6'

```Fix optimiser

See merge request !506```
parents 2884f274 1362c613
Pipeline #75544 passed with stages
in 10 minutes and 21 seconds
 ... ... @@ -266,24 +266,29 @@ from .multi_field import MultiField from numpy import allclose def optimise_operator(op): """ Merges redundant operations in the tree structure of an operator. For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once. Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once. It is supposed to be used on the whole operator chain before doing minimisation. Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants ``ChainOp`` and ``SumOperator``. Parameters ---------- op: Operator op : Operator Operator with a tree structure. Returns ------- op_optimised : Operator Operator with same input/output, but optimised tree structure. Notes ----- Since operators are compared by id best results are achieved when the following code Operators are compared only by id, so best results are achieved when the following code >>> from nifty6 import UniformOperator, DomainTuple >>> uni1 = UniformOperator(DomainTuple.scalar_domain() ... ... @@ -291,16 +296,18 @@ def optimise_operator(op): >>> op = (uni1 + uni2)*(uni1 + uni2) is replaced by something comparable to >>> uni = UniformOperator(DomainTuple.scalar_domain()) >>> uni_add = uni + uni >>> op = uni_add * uni_add After optimisation the operator is comparable in speed to After optimisation the operator is as fast as >>> op = (2*uni)**2 """ op_optimised = deepcopy(op) op_optimised = _optimise_operator(op_optimised) test_field = from_random('normal', op.domain) test_field = from_random(op.domain) if isinstance(op(test_field), MultiField): for key in op(test_field).keys(): assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10) ... ...
 ... ... @@ -39,7 +39,6 @@ class CountingOp(ift.Operator): def test_operator_tree_optimiser(): dom = ift.RGSpace(10, harmonic=True) hdom = dom.get_default_codomain() cop1 = CountingOp(dom) op1 = (ift.UniformOperator(dom, -1, 2)@cop1).ducktape('a') cop2 = CountingOp(dom) ... ... @@ -58,3 +57,5 @@ def test_operator_tree_optimiser(): op = ift.operator_tree_optimiser._optimise_operator(op) assert_allclose(op(fld).val, op_orig(fld).val, rtol=np.finfo(np.float64).eps) assert_(1 == ( (cop4.count-1) * cop3.count * cop2.count * cop1.count)) # test testing ift.optimise_operator(op_orig)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!