Commit f685631a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'local_nonlinearities' into 'NIFTy_5'

added a number of local nonlinear functions

See merge request ift/nifty-dev!144
parents 4c0bd584 a17514ff
...@@ -54,7 +54,7 @@ if __name__ == '__main__': ...@@ -54,7 +54,7 @@ if __name__ == '__main__':
A = ift.create_power_operator(harmonic_space, sqrtpspec) A = ift.create_power_operator(harmonic_space, sqrtpspec)
# Set up a sky model and instrumental response # Set up a sky model and instrumental response
sky = ift.positive_tanh(HT(A)) sky = ift.sigmoid(HT(A))
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
R = GR R = GR
......
...@@ -76,7 +76,7 @@ if __name__ == '__main__': ...@@ -76,7 +76,7 @@ if __name__ == '__main__':
# correlated_field = ift.CorrelatedField(position_space, A) # correlated_field = ift.CorrelatedField(position_space, A)
# Apply a nonlinearity # Apply a nonlinearity
signal = ift.positive_tanh(correlated_field) signal = ift.sigmoid(correlated_field)
# Build the line-of-sight response and define signal response # Build the line-of-sight response and define signal response
LOS_starts, LOS_ends = random_los(100) if mode == 1 else radial_los(100) LOS_starts, LOS_ends = random_los(100) if mode == 1 else radial_los(100)
......
...@@ -136,7 +136,7 @@ A :class:`Field` object consists of the following components: ...@@ -136,7 +136,7 @@ A :class:`Field` object consists of the following components:
- a data type (e.g. numpy.float64) - a data type (e.g. numpy.float64)
- an array containing the actual values - an array containing the actual values
Usually, the array is stored in the for of a ``numpy.ndarray``, but for very Usually, the array is stored in the form of a ``numpy.ndarray``, but for very
resource-intensive tasks NIFTy also provides an alternative storage method to resource-intensive tasks NIFTy also provides an alternative storage method to
be used with distributed memory processing. be used with distributed memory processing.
......
...@@ -31,7 +31,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -31,7 +31,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"] "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -211,6 +212,9 @@ class data_object(object): ...@@ -211,6 +212,9 @@ class data_object(object):
else: else:
return data_object(self._shape, tval, self._distaxis) return data_object(self._shape, tval, self._distaxis)
def clip(self, min=None, max=None):
return data_object(self._shape, np.clip(self._data, min, max))
def __neg__(self): def __neg__(self):
return data_object(self._shape, -self._data, self._distaxis) return data_object(self._shape, -self._data, self._distaxis)
...@@ -292,7 +296,8 @@ def _math_helper(x, function, out): ...@@ -292,7 +296,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__] _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]:
def func(f): def func(f):
def func2(x, out=None): def func2(x, out=None):
return _math_helper(x, f, out) return _math_helper(x, f, out)
...@@ -300,8 +305,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: ...@@ -300,8 +305,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
def clipped_exp(a): def clip(x, a_min=None, a_max=None):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)) return data_object(x.shape, np.clip(x.data, a_min, a_max), x.distaxis)
def from_object(object, dtype, copy, set_locked): def from_object(object, dtype, copy, set_locked):
......
...@@ -21,7 +21,8 @@ import numpy as np ...@@ -21,7 +21,8 @@ import numpy as np
from numpy import empty, empty_like, exp, full, log from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object from numpy import ndarray as data_object
from numpy import ones, sqrt, tanh, vdot, zeros from numpy import ones, sqrt, tanh, vdot, zeros
from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign, clip
from .random import Random from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
...@@ -33,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -33,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"] "clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc"]
ntask = 1 ntask = 1
rank = 0 rank = 0
...@@ -149,7 +151,3 @@ def absmax(arr): ...@@ -149,7 +151,3 @@ def absmax(arr):
def norm(arr, ord=2): def norm(arr, ord=2):
return np.linalg.norm(arr.reshape(-1), ord=ord) return np.linalg.norm(arr.reshape(-1), ord=ord)
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
...@@ -628,11 +628,14 @@ class Field(object): ...@@ -628,11 +628,14 @@ class Field(object):
def flexible_addsub(self, other, neg): def flexible_addsub(self, other, neg):
return self-other if neg else self+other return self-other if neg else self+other
def positive_tanh(self): def sigmoid(self):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def clipped_exp(self): def clip(self, min=None, max=None):
return Field(self._domain, dobj.clipped_exp(self._val)) return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self):
return 1/self
def _binary_op(self, other, op): def _binary_op(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
...@@ -669,7 +672,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -669,7 +672,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]: for f in ["sqrt", "exp", "log", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
def func(f): def func(f):
def func2(self): def func2(self):
return Field(self._domain, getattr(dobj, f)(self.val)) return Field(self._domain, getattr(dobj, f)(self.val))
......
...@@ -29,7 +29,7 @@ from ..operators.linear_operator import LinearOperator ...@@ -29,7 +29,7 @@ from ..operators.linear_operator import LinearOperator
def _gaussian_error_function(x): def _gaussian_error_function(x):
return 0.5*erfc(x*np.sqrt(2.)) return 0.5/erfc(x*np.sqrt(2.))
def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf): def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf):
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from .field import Field from .field import Field
from .multi_field import MultiField from .multi_field import MultiField
from .sugar import makeOp from .sugar import makeOp
from .operators.scaling_operator import ScalingOperator
class Linearization(object): class Linearization(object):
...@@ -196,23 +197,71 @@ class Linearization(object): ...@@ -196,23 +197,71 @@ class Linearization(object):
tmp = self._val.exp() tmp = self._val.exp()
return self.new(tmp, makeOp(tmp)(self._jac)) return self.new(tmp, makeOp(tmp)(self._jac))
def clipped_exp(self): def clip(self, min=None, max=None):
tmp = self._val.clipped_exp() tmp = self._val.clip(min, max)
return self.new(tmp, makeOp(tmp)(self._jac)) if (min is None) and (max is None):
return self
elif max is None:
tmp2 = makeOp(1. - (tmp == min))
elif min is None:
tmp2 = makeOp(1. - (tmp == max))
else:
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return self.new(tmp, tmp2(self._jac))
def sin(self):
tmp = self._val.sin()
tmp2 = self._val.cos()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cos(self):
tmp = self._val.cos()
tmp2 = - self._val.sin()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tan(self):
tmp = self._val.tan()
tmp2 = 1./(self._val.cos()**2)
return self.new(tmp, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = (self._val.cos()-tmp)/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def log(self): def log(self):
tmp = self._val.log() tmp = self._val.log()
return self.new(tmp, makeOp(1./self._val)(self._jac)) return self.new(tmp, makeOp(1./self._val)(self._jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cosh(self):
tmp = self._val.cosh()
tmp2 = self._val.sinh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tanh(self): def tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
return self.new(tmp, makeOp(1.-tmp**2)(self._jac)) return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self): def sigmoid(self):
tmp = self._val.tanh() tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp) tmp2 = 0.5*(1.+tmp)
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
tmp = self._val.absolute()
tmp2 = self._val.sign()
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def add_metric(self, metric): def add_metric(self, metric):
return self.new(self._val, self._jac, metric) return self.new(self._val, self._jac, metric)
......
...@@ -190,6 +190,10 @@ class MultiField(object): ...@@ -190,6 +190,10 @@ class MultiField(object):
def conjugate(self): def conjugate(self):
return self._transform(lambda x: x.conjugate()) return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None):
return MultiField(self._domain,
tuple(clip(v, min, max) for v in self._val))
def all(self): def all(self):
for v in self._val: for v in self._val:
if not v.all(): if not v.all():
...@@ -292,7 +296,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -292,7 +296,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh", "clipped_exp"]: for f in ["sqrt", "exp", "log", "tanh"]:
def func(f): def func(f):
def func2(self): def func2(self):
fu = getattr(Field, f) fu = getattr(Field, f)
......
...@@ -94,6 +94,11 @@ class Operator(NiftyMetaBase()): ...@@ -94,6 +94,11 @@ class Operator(NiftyMetaBase()):
return NotImplemented return NotImplemented
return _OpChain.make((_PowerOp(self.target, power), self)) return _OpChain.make((_PowerOp(self.target, power), self))
def clip(self, min=None, max=None):
if min is None and max is None:
return self
return _OpChain.make((_Clipper(self.target, min, max), self))
def apply(self, x): def apply(self, x):
raise NotImplementedError raise NotImplementedError
...@@ -123,7 +128,8 @@ class Operator(NiftyMetaBase()): ...@@ -123,7 +128,8 @@ class Operator(NiftyMetaBase()):
return self.__class__.__name__ return self.__class__.__name__
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']: for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
def func(f): def func(f):
def func2(self): def func2(self):
fa = _FunctionApplier(self.target, f) fa = _FunctionApplier(self.target, f)
...@@ -143,6 +149,18 @@ class _FunctionApplier(Operator): ...@@ -143,6 +149,18 @@ class _FunctionApplier(Operator):
return getattr(x, self._funcname)() return getattr(x, self._funcname)()
class _Clipper(Operator):
def __init__(self, domain, min=None, max=None):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._min = min
self._max = max
def apply(self, x):
self._check_input(x)
return x.clip(self._min, self._max)
class _PowerOp(Operator): class _PowerOp(Operator):
def __init__(self, domain, power): def __init__(self, domain, power):
from ..sugar import makeDomain from ..sugar import makeDomain
......
...@@ -33,7 +33,9 @@ from .operators.distributors import PowerDistributor ...@@ -33,7 +33,9 @@ from .operators.distributors import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'from_global_data', 'from_local_data', 'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'positive_tanh', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'clip', 'sinc',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union', 'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain'] 'get_default_codomain']
...@@ -253,7 +255,9 @@ def domain_union(domains): ...@@ -253,7 +255,9 @@ def domain_union(domains):
_current_module = sys.modules[__name__] _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'sinc']:
def func(f): def func(f):
def func2(x): def func2(x):
from .linearization import Linearization from .linearization import Linearization
...@@ -266,6 +270,10 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]: ...@@ -266,6 +270,10 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]:
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
def clip(a, a_min=None, a_max=None):
return a.clip(a_min, a_max)
def get_default_codomain(domainoid, space=None): def get_default_codomain(domainoid, space=None):
"""For `RGSpace`, returns the harmonic partner domain. """For `RGSpace`, returns the harmonic partner domain.
For `DomainTuple`, returns a copy of the object in which the domain For `DomainTuple`, returns a copy of the object in which the domain
......
...@@ -111,7 +111,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -111,7 +111,7 @@ class Energy_Tests(unittest.TestCase):
def testBernoulli(self, space, seed): def testBernoulli(self, space, seed):
model = self.make_model( model = self.make_model(
space_key='s1', space=space, seed=seed)['s1'] space_key='s1', space=space, seed=seed)['s1']
model = model.positive_tanh() model = model.sigmoid()
d = np.random.binomial(1, 0.1, size=space.shape) d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(d) energy = ift.BernoulliEnergy(d)
......
...@@ -291,3 +291,14 @@ class Test_Functionality(unittest.TestCase): ...@@ -291,3 +291,14 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.local_data.shape, ()) assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1) assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.) assert_equal(f.vdot(f), 9.)
@expand(product([float(5), 5.],
[ift.RGSpace((8,), harmonic=True), ()],
["exp", "log", "sin", "cos", "tan", "sinh", "cosh", "sinc",
"absolute", "sign"]))
def test_funcs(self, num, dom, func):
num = 5
f = ift.Field.full(dom, num)
res = getattr(f, func)()
res2 = getattr(np, func)(num)
assert_allclose(res.local_data, res2)
...@@ -75,7 +75,7 @@ class Model_Tests(unittest.TestCase): ...@@ -75,7 +75,7 @@ class Model_Tests(unittest.TestCase):
model = ift.ScalingOperator(2.456, space)(select_s1*select_s2) model = ift.ScalingOperator(2.456, space)(select_s1*select_s2)
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)( model = ift.sigmoid(ift.ScalingOperator(2.456, space)(
select_s1*select_s2)) select_s1*select_s2))
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment