Commit 4b54c37b authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix Distributor in CorrelatedFields

parent 3c27689a
Pipeline #75297 passed with stages
in 26 minutes and 11 seconds
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Philipp Frank, Philipp Arras, Philipp Haim # Authors: Philipp Frank, Philipp Arras, Philipp Haim
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -25,7 +25,6 @@ from ..domain_tuple import DomainTuple ...@@ -25,7 +25,6 @@ from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..logger import logger from ..logger import logger
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.adder import Adder from ..operators.adder import Adder
...@@ -244,10 +243,9 @@ class _SpecialSum(EndomorphicOperator): ...@@ -244,10 +243,9 @@ class _SpecialSum(EndomorphicOperator):
class _Distributor(LinearOperator): class _Distributor(LinearOperator):
def __init__(self, dofdex, domain, target): def __init__(self, dofdex, domain, target):
self._dofdex = dofdex self._dofdex = np.array(dofdex)
self._target = DomainTuple.make(target)
self._target = makeDomain(target) self._domain = DomainTuple.make(domain)
self._domain = makeDomain(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
...@@ -256,7 +254,7 @@ class _Distributor(LinearOperator): ...@@ -256,7 +254,7 @@ class _Distributor(LinearOperator):
if mode == self.TIMES: if mode == self.TIMES:
res = x[self._dofdex] res = x[self._dofdex]
else: else:
res = np.empty(self._tgt(mode).shape) res = np.zeros(self._tgt(mode).shape, dtype=x.dtype)
res[self._dofdex] = x res[self._dofdex] = x
return makeField(self._tgt(mode), res) return makeField(self._tgt(mode), res)
......
...@@ -326,3 +326,11 @@ def testSlowFieldAdapter(seed): ...@@ -326,3 +326,11 @@ def testSlowFieldAdapter(seed):
dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)} dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a') op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a')
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@pmp('dofdex', [(0,), (1,), (0, 1), (1, 0)])
def testCorFldDistr(dofdex):
tgt = ift.UnstructuredDomain(len(dofdex))
dom = ift.UnstructuredDomain(2)
op = ift.library.correlated_fields._Distributor(dofdex, dom, tgt)
ift.extra.consistency_check(op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment