diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index b4bd8f21523cd060bebfc62b2a1bd2ade6100694..01f6f4135463b634702146af0da020c45bfbf37b 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -53,7 +53,7 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = lambda inp: HT(A*inp).positive_tanh() + sky = HT.chain(ift.makeOp(A)).positive_tanh() GR = ift.GeometryRemover(position_space) # Set up instrumental response diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index 46f5ad37417a2ff563d1c6ff3620dbf4f3343f51..79fc7bb6baef97c582c1a37e9f25bee2a3082e3e 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -70,7 +70,7 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = lambda inp: (HT(inp*A)).exp() + sky = HT.chain(ift.makeOp(A)).exp() M = ift.DiagonalOperator(exposure) GR = ift.GeometryRemover(position_space) diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 99f54d4947e4ee23dea9fd356a521540cd255ab6..3135a7f7df7cc73170cc9d670a9d73d4a4897cf0 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -25,10 +25,7 @@ class Operator(NiftyMetaBase()): def __matmul__(self, x): if not isinstance(x, Operator): return NotImplemented - return OpChain.make((self, x)) - ops1 = self._ops if isinstance(self, OpChain) else (self,) - ops2 = x._ops if isinstance(x, OpChain) else (x,) - return OpChain(ops1+ops2) + return _OpChain.make((self, x)) def chain(self, x): res = self.__matmul__(x) @@ -52,6 +49,25 @@ class Operator(NiftyMetaBase()): raise NotImplementedError +for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: + def func(f): + def func2(self): + fa = _FunctionApplier(self._target, f) + return _OpChain.make((fa, self)) + return func2 + setattr(Operator, f, func(f)) + + +class _FunctionApplier(Operator): + def __init__(self, domain, funcname): + from ..sugar import makeDomain + self._domain = self._target = makeDomain(domain) + self._funcname = funcname + + def __call__(self, x): + return getattr(x, self._funcname)() + + class _CombinedOperator(Operator): def __init__(self, ops, _callingfrommake=False): if not _callingfrommake: @@ -62,7 +78,7 @@ class _CombinedOperator(Operator): def unpack(cls, ops, res): for op in ops: if isinstance(op, cls): - res = cls.unpack(op, res) + res = cls.unpack(op._ops, res) else: res = res + [op] return res