From 08e0ae665357d737f2494ba34fe9bbbd475b35fd Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Thu, 2 Aug 2018 11:47:57 +0200 Subject: [PATCH] adjust a few demos --- demos/bernoulli_demo.py | 2 +- demos/getting_started_2.py | 2 +- nifty5/operators/operator.py | 26 +++++++++++++++++++++----- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index b4bd8f21..01f6f413 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -53,7 +53,7 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = lambda inp: HT(A*inp).positive_tanh() + sky = HT.chain(ift.makeOp(A)).positive_tanh() GR = ift.GeometryRemover(position_space) # Set up instrumental response diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index 46f5ad37..79fc7bb6 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -70,7 +70,7 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = lambda inp: (HT(inp*A)).exp() + sky = HT.chain(ift.makeOp(A)).exp() M = ift.DiagonalOperator(exposure) GR = ift.GeometryRemover(position_space) diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 99f54d49..3135a7f7 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -25,10 +25,7 @@ class Operator(NiftyMetaBase()): def __matmul__(self, x): if not isinstance(x, Operator): return NotImplemented - return OpChain.make((self, x)) - ops1 = self._ops if isinstance(self, OpChain) else (self,) - ops2 = x._ops if isinstance(x, OpChain) else (x,) - return OpChain(ops1+ops2) + return _OpChain.make((self, x)) def chain(self, x): res = self.__matmul__(x) @@ -52,6 +49,25 @@ class Operator(NiftyMetaBase()): raise NotImplementedError +for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: + def func(f): + def func2(self): + fa = _FunctionApplier(self._target, f) + return _OpChain.make((fa, self)) + return func2 + setattr(Operator, f, func(f)) + + +class _FunctionApplier(Operator): + def __init__(self, domain, funcname): + from ..sugar import makeDomain + self._domain = self._target = makeDomain(domain) + self._funcname = funcname + + def __call__(self, x): + return getattr(x, self._funcname)() + + class _CombinedOperator(Operator): def __init__(self, ops, _callingfrommake=False): if not _callingfrommake: @@ -62,7 +78,7 @@ class _CombinedOperator(Operator): def unpack(cls, ops, res): for op in ops: if isinstance(op, cls): - res = cls.unpack(op, res) + res = cls.unpack(op._ops, res) else: res = res + [op] return res -- GitLab