Commit 3c2feb14 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge remote-tracking branch 'origin/NIFTy_5' into adjust_variances_but_right

parents f6b56291 8af5299c
......@@ -36,8 +36,10 @@ build_docker_from_cache:
test_python2_with_coverage:
stage: test
variables:
OMPI_MCA_btl_vader_single_copy_mechanism: none
script:
- mpiexec -n 2 --bind-to none pytest -q test 2> /dev/null
- mpiexec -n 2 --bind-to none pytest -q test
- pytest -q --cov=nifty5 test
- >
python -m coverage report --omit "*plotting*,*distributed_do*"
......@@ -46,9 +48,11 @@ test_python2_with_coverage:
test_python3:
stage: test
variables:
OMPI_MCA_btl_vader_single_copy_mechanism: none
script:
- pytest-3 -q
- mpiexec -n 2 --bind-to none pytest-3 -q 2> /dev/null
- mpiexec -n 2 --bind-to none pytest-3 -q
pages:
stage: release
......
......@@ -22,6 +22,7 @@ from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.contraction_operator import ContractionOperator
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
......
......@@ -150,6 +150,8 @@ class DiagonalOperator(EndomorphicOperator):
return Field.from_local_data(x.domain, x.local_data/xdiag)
def _flip_modes(self, trafo):
if trafo == self.ADJOINT_BIT and not self._complex: # shortcut
return self
xdiag = self._ldiag
if self._complex and (trafo & self.ADJOINT_BIT):
xdiag = xdiag.conj()
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator
class DomainTupleFieldInserter(LinearOperator):
def __init__(self, domain, new_space, index, position):
'''Writes the content of a field into one slice of a DomainTuple.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
index : Integer
Index at which new_space shall be added to domain.
position : tuple
Slice in new_space in which the input field shall be written into.
'''
self._domain = DomainTuple.make(domain)
tgt = list(self.domain)
tgt.insert(index, new_space)
self._target = DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
fst_dims = sum(len(dd.shape) for dd in self.domain[:index])
nshp = new_space.shape
if len(position) != len(nshp):
raise ValueError("shape mismatch between new_space and position")
for s, p in zip(nshp, position):
if p < 0 or p >= s:
raise ValueError("bad position value")
self._slc = (slice(None),)*fst_dims + position
def apply(self, x, mode):
self._check_input(x, mode)
# FIXME Make fully MPI compatible without global_data
if mode == self.TIMES:
res = np.zeros(self.target.shape, dtype=x.dtype)
res[self._slc] = x.to_global_data()
return Field.from_global_data(self.target, res)
else:
return Field.from_global_data(self.domain,
x.to_global_data()[self._slc])
......@@ -69,7 +69,6 @@ class FieldZeroPadder(LinearOperator):
i1 = idx + (slice(None, -(Nyquist+1), -1),)
xnew[i1] = x[i1]
# if (x.shape[d] & 1) == 0: # even number of pixels
# print (Nyquist, x.shape[d]-Nyquist)
# i1 = idx+(Nyquist,)
# xnew[i1] *= 0.5
# i1 = idx+(-Nyquist,)
......
......@@ -194,6 +194,14 @@ class Consistency_Tests(unittest.TestCase):
op = ift.ContractionOperator(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
def testDomainTupleFieldInserter(self):
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12),
ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7)
pos = (5,)
op = ift.DomainTupleFieldInserter(domain, new_space, 0, pos)
ift.extra.consistency_check(op)
@expand(product([0, 2], [np.float64, np.complex128]))
def testSymmetrizingOperator(self, space, dtype):
dom = (ift.LogRGSpace(10, [2.], [1.]), ift.UnstructuredDomain(13),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment