Commit 08e0ae66 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

adjust a few demos

parent 409244a8
......@@ -53,7 +53,7 @@ if __name__ == '__main__':
A = pd(a)
# Set up a sky model
sky = lambda inp: HT(A*inp).positive_tanh()
sky = HT.chain(ift.makeOp(A)).positive_tanh()
GR = ift.GeometryRemover(position_space)
# Set up instrumental response
......
......@@ -70,7 +70,7 @@ if __name__ == '__main__':
A = pd(a)
# Set up a sky model
sky = lambda inp: (HT(inp*A)).exp()
sky = HT.chain(ift.makeOp(A)).exp()
M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space)
......
......@@ -25,10 +25,7 @@ class Operator(NiftyMetaBase()):
def __matmul__(self, x):
if not isinstance(x, Operator):
return NotImplemented
return OpChain.make((self, x))
ops1 = self._ops if isinstance(self, OpChain) else (self,)
ops2 = x._ops if isinstance(x, OpChain) else (x,)
return OpChain(ops1+ops2)
return _OpChain.make((self, x))
def chain(self, x):
res = self.__matmul__(x)
......@@ -52,6 +49,25 @@ class Operator(NiftyMetaBase()):
raise NotImplementedError
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
def func(f):
def func2(self):
fa = _FunctionApplier(self._target, f)
return _OpChain.make((fa, self))
return func2
setattr(Operator, f, func(f))
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._funcname = funcname
def __call__(self, x):
return getattr(x, self._funcname)()
class _CombinedOperator(Operator):
def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake:
......@@ -62,7 +78,7 @@ class _CombinedOperator(Operator):
def unpack(cls, ops, res):
for op in ops:
if isinstance(op, cls):
res = cls.unpack(op, res)
res = cls.unpack(op._ops, res)
else:
res = res + [op]
return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment