Skip to content
Snippets Groups Projects
Commit 9d09bd71 authored by Rouven Lemmerz's avatar Rouven Lemmerz
Browse files

Patch forgotten refactoring, fix docstrings

parent 2884f274
No related branches found
No related tags found
1 merge request!506Fix optimiser
Pipeline #75542 passed
...@@ -266,24 +266,29 @@ from .multi_field import MultiField ...@@ -266,24 +266,29 @@ from .multi_field import MultiField
from numpy import allclose from numpy import allclose
def optimise_operator(op): def optimise_operator(op):
""" """
Merges redundant operations in the tree structure of an operator. Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once. For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once.
Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants It is supposed to be used on the whole operator chain before doing minimisation.
Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants
``ChainOp`` and ``SumOperator``. ``ChainOp`` and ``SumOperator``.
Parameters Parameters
---------- ----------
op: Operator op : Operator
Operator with a tree structure.
Returns Returns
------- -------
op_optimised : Operator op_optimised : Operator
Operator with same input/output, but optimised tree structure.
Notes Notes
----- -----
Since operators are compared by id best results are achieved when the following code Operators are compared only by id, so best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple >>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain() >>> uni1 = UniformOperator(DomainTuple.scalar_domain()
...@@ -291,16 +296,18 @@ def optimise_operator(op): ...@@ -291,16 +296,18 @@ def optimise_operator(op):
>>> op = (uni1 + uni2)*(uni1 + uni2) >>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to is replaced by something comparable to
>>> uni = UniformOperator(DomainTuple.scalar_domain()) >>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni >>> uni_add = uni + uni
>>> op = uni_add * uni_add >>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to After optimisation the operator is as fast as
>>> op = (2*uni)**2 >>> op = (2*uni)**2
""" """
op_optimised = deepcopy(op) op_optimised = deepcopy(op)
op_optimised = _optimise_operator(op_optimised) op_optimised = _optimise_operator(op_optimised)
test_field = from_random('normal', op.domain) test_field = from_random(op.domain)
if isinstance(op(test_field), MultiField): if isinstance(op(test_field), MultiField):
for key in op(test_field).keys(): for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10) assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment