Commit 8dc99b9f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'preptw' into 'NIFTy_7'

Implemente pre-nonlineariaties

See merge request ift/nifty!610
parents e8ba6f2e 6b2b4adb
Pipeline #96747 passed with stages
in 12 minutes and 22 seconds
......@@ -330,6 +330,9 @@ class Operator(metaclass=NiftyMeta):
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(, op, *args, **kwargs), self))
def ptw_pre(self, op, *args, **kwargs):
return _OpChain.make((self, _FunctionApplier(self.domain, op, *args, **kwargs)))
for f in pointwise.ptw_dict.keys():
def func(f):
......@@ -337,6 +340,11 @@ for f in pointwise.ptw_dict.keys():
return self.ptw(f, *args, **kwargs)
return func2
setattr(Operator, f, func(f))
def func(f):
def func2(self, *args, **kwargs):
return self.ptw_pre(f, *args, **kwargs)
return func2
setattr(Operator, f + "_pre", func(f))
class _FunctionApplier(Operator):
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2021 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
import nifty7 as ift
from ..common import setup_function, teardown_function
pmp = pytest.mark.parametrize
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', 'expm1', 'softplus',
('power', 2.), ('exponentiate', 1.1)
def test_ptw_pre(f):
if not isinstance(f, tuple):
f = (f,)
op = ift.FFTOperator(ift.RGSpace(10))
op0 = op @ ift.ScalingOperator(op.domain, 1.).ptw(*f)
op1 = op.ptw_pre(*f)
pos = ift.from_random(op0.domain)
if f[0] in ['log', 'sqrt', 'log10', 'log1p']:
pos = pos.exp()
ift.extra.assert_equal(op0(pos), op1(pos))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment