Commit 08e0ae66 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

adjust a few demos

parent 409244a8
...@@ -53,7 +53,7 @@ if __name__ == '__main__': ...@@ -53,7 +53,7 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Set up a sky model # Set up a sky model
sky = lambda inp: HT(A*inp).positive_tanh() sky = HT.chain(ift.makeOp(A)).positive_tanh()
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
# Set up instrumental response # Set up instrumental response
......
...@@ -70,7 +70,7 @@ if __name__ == '__main__': ...@@ -70,7 +70,7 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Set up a sky model # Set up a sky model
sky = lambda inp: (HT(inp*A)).exp() sky = HT.chain(ift.makeOp(A)).exp()
M = ift.DiagonalOperator(exposure) M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
......
...@@ -25,10 +25,7 @@ class Operator(NiftyMetaBase()): ...@@ -25,10 +25,7 @@ class Operator(NiftyMetaBase()):
def __matmul__(self, x): def __matmul__(self, x):
if not isinstance(x, Operator): if not isinstance(x, Operator):
return NotImplemented return NotImplemented
return OpChain.make((self, x)) return _OpChain.make((self, x))
ops1 = self._ops if isinstance(self, OpChain) else (self,)
ops2 = x._ops if isinstance(x, OpChain) else (x,)
return OpChain(ops1+ops2)
def chain(self, x): def chain(self, x):
res = self.__matmul__(x) res = self.__matmul__(x)
...@@ -52,6 +49,25 @@ class Operator(NiftyMetaBase()): ...@@ -52,6 +49,25 @@ class Operator(NiftyMetaBase()):
raise NotImplementedError raise NotImplementedError
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
def func(f):
def func2(self):
fa = _FunctionApplier(self._target, f)
return _OpChain.make((fa, self))
return func2
setattr(Operator, f, func(f))
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._funcname = funcname
def __call__(self, x):
return getattr(x, self._funcname)()
class _CombinedOperator(Operator): class _CombinedOperator(Operator):
def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
...@@ -62,7 +78,7 @@ class _CombinedOperator(Operator): ...@@ -62,7 +78,7 @@ class _CombinedOperator(Operator):
def unpack(cls, ops, res): def unpack(cls, ops, res):
for op in ops: for op in ops:
if isinstance(op, cls): if isinstance(op, cls):
res = cls.unpack(op, res) res = cls.unpack(op._ops, res)
else: else:
res = res + [op] res = res + [op]
return res return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment