From fab8b8b036bd3bae2e8fc0499ffc86a6de4dcbf4 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Thu, 4 Oct 2018 15:47:05 +0200 Subject: [PATCH] Move ValueInserter into separate file --- nifty5/__init__.py | 3 +- nifty5/operators/simple_linear_operators.py | 35 ------------ nifty5/operators/value_inserter.py | 60 +++++++++++++++++++++ 3 files changed, 62 insertions(+), 36 deletions(-) create mode 100644 nifty5/operators/value_inserter.py diff --git a/nifty5/__init__.py b/nifty5/__init__.py index df1e4b47b..81604dbcc 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -46,7 +46,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.outer_product_operator import OuterProduct from .operators.simple_linear_operators import ( VdotOperator, ConjugationOperator, Realizer, - FieldAdapter, GeometryRemover, NullOperator, ValueInserter) + FieldAdapter, GeometryRemover, NullOperator) +from .operators.value_inserter import ValueInserter from .operators.energy_operators import ( EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence) diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 1bce95fc8..f048eab91 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -18,8 +18,6 @@ from __future__ import absolute_import, division, print_function -import numpy as np - from ..compat import * from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain @@ -142,36 +140,3 @@ class NullOperator(LinearOperator): def apply(self, x, mode): self._check_input(x, mode) return self._nullfield(self._tgt(mode)) - - -class ValueInserter(LinearOperator): - """Operator which inserts one value into a field. - - Parameters - ---------- - target : Domain, tuple of Domain or DomainTuple - index : tuple - The index of the target into which the value of the domain shall be - written. - """ - - def __init__(self, target, index): - from ..sugar import makeDomain - self._domain = makeDomain(UnstructuredDomain(1)) - self._target = DomainTuple.make(target) - if not isinstance(index, tuple): - raise TypeError - self._index = index - self._capability = self.TIMES | self.ADJOINT_TIMES - # Check if index is in bounds - np.empty(self.target.shape)[self._index] - - def apply(self, x, mode): - self._check_input(x, mode) - x = x.to_global_data() - if mode == self.TIMES: - res = np.zeros(self.target.shape, dtype=x.dtype) - res[self._index] = x - else: - res = x[self._index] - return Field.from_global_data(self._tgt(mode), res) diff --git a/nifty5/operators/value_inserter.py b/nifty5/operators/value_inserter.py new file mode 100644 index 000000000..f9da44800 --- /dev/null +++ b/nifty5/operators/value_inserter.py @@ -0,0 +1,60 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +from ..compat import * +from ..domain_tuple import DomainTuple +from ..domains.unstructured_domain import UnstructuredDomain +from ..field import Field +from .linear_operator import LinearOperator + + +class ValueInserter(LinearOperator): + """Operator which inserts one value into a field. + + Parameters + ---------- + target : Domain, tuple of Domain or DomainTuple + index : tuple + The index of the target into which the value of the domain shall be + written. + """ + + def __init__(self, target, index): + from ..sugar import makeDomain + self._domain = makeDomain(UnstructuredDomain(1)) + self._target = DomainTuple.make(target) + if not isinstance(index, tuple): + raise TypeError + self._index = index + self._capability = self.TIMES | self.ADJOINT_TIMES + # Check if index is in bounds + np.empty(self.target.shape)[self._index] + + def apply(self, x, mode): + self._check_input(x, mode) + x = x.to_global_data() + if mode == self.TIMES: + res = np.zeros(self.target.shape, dtype=x.dtype) + res[self._index] = x + else: + res = x[self._index] + return Field.from_global_data(self._tgt(mode), res) -- GitLab