Commit f685631a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'local_nonlinearities' into 'NIFTy_5'

added a number of local nonlinear functions

See merge request ift/nifty-dev!144
parents 4c0bd584 a17514ff
......@@ -54,7 +54,7 @@ if __name__ == '__main__':
A = ift.create_power_operator(harmonic_space, sqrtpspec)
# Set up a sky model and instrumental response
sky = ift.positive_tanh(HT(A))
sky = ift.sigmoid(HT(A))
GR = ift.GeometryRemover(position_space)
R = GR
......@@ -76,7 +76,7 @@ if __name__ == '__main__':
# correlated_field = ift.CorrelatedField(position_space, A)
# Apply a nonlinearity
signal = ift.positive_tanh(correlated_field)
signal = ift.sigmoid(correlated_field)
# Build the line-of-sight response and define signal response
LOS_starts, LOS_ends = random_los(100) if mode == 1 else radial_los(100)
......@@ -136,7 +136,7 @@ A :class:`Field` object consists of the following components:
- a data type (e.g. numpy.float64)
- an array containing the actual values
Usually, the array is stored in the for of a ``numpy.ndarray``, but for very
Usually, the array is stored in the form of a ``numpy.ndarray``, but for very
resource-intensive tasks NIFTy also provides an alternative storage method to
be used with distributed memory processing.
......@@ -31,7 +31,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
ntask = _comm.Get_size()
......@@ -211,6 +212,9 @@ class data_object(object):
return data_object(self._shape, tval, self._distaxis)
def clip(self, min=None, max=None):
return data_object(self._shape, np.clip(self._data, min, max))
def __neg__(self):
return data_object(self._shape, -self._data, self._distaxis)
......@@ -292,7 +296,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
......@@ -300,8 +305,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f))
def clipped_exp(a):
return data_object(x.shape, np.exp(np.clip(, -300, 300), x.distaxis))
def clip(x, a_min=None, a_max=None):
return data_object(x.shape, np.clip(, a_min, a_max), x.distaxis)
def from_object(object, dtype, copy, set_locked):
......@@ -21,7 +21,8 @@ import numpy as np
from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object
from numpy import ones, sqrt, tanh, vdot, zeros
from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign, clip
from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
......@@ -33,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc"]
ntask = 1
rank = 0
......@@ -149,7 +151,3 @@ def absmax(arr):
def norm(arr, ord=2):
return np.linalg.norm(arr.reshape(-1), ord=ord)
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
......@@ -628,11 +628,14 @@ class Field(object):
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def positive_tanh(self):
def sigmoid(self):
return 0.5*(1.+self.tanh())
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def clip(self, min=None, max=None):
return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self):
return 1/self
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
......@@ -669,7 +672,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
for f in ["sqrt", "exp", "log", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(dobj, f)(self.val))
......@@ -29,7 +29,7 @@ from ..operators.linear_operator import LinearOperator
def _gaussian_error_function(x):
return 0.5*erfc(x*np.sqrt(2.))
return 0.5/erfc(x*np.sqrt(2.))
def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf):
......@@ -20,6 +20,7 @@ import numpy as np
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from .operators.scaling_operator import ScalingOperator
class Linearization(object):
......@@ -196,23 +197,71 @@ class Linearization(object):
tmp = self._val.exp()
return, makeOp(tmp)(self._jac))
def clipped_exp(self):
tmp = self._val.clipped_exp()
return, makeOp(tmp)(self._jac))
def clip(self, min=None, max=None):
tmp = self._val.clip(min, max)
if (min is None) and (max is None):
return self
elif max is None:
tmp2 = makeOp(1. - (tmp == min))
elif min is None:
tmp2 = makeOp(1. - (tmp == max))
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return, tmp2(self._jac))
def sin(self):
tmp = self._val.sin()
tmp2 = self._val.cos()
return, makeOp(tmp2)(self._jac))
def cos(self):
tmp = self._val.cos()
tmp2 = - self._val.sin()
return, makeOp(tmp2)(self._jac))
def tan(self):
tmp = self._val.tan()
tmp2 = 1./(self._val.cos()**2)
return, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = (self._val.cos()-tmp)/self._val
return, makeOp(tmp2)(self._jac))
def log(self):
tmp = self._val.log()
return, makeOp(1./self._val)(self._jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
return, makeOp(tmp2)(self._jac))
def cosh(self):
tmp = self._val.cosh()
tmp2 = self._val.sinh()
return, makeOp(tmp2)(self._jac))
def tanh(self):
tmp = self._val.tanh()
return, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self):
def sigmoid(self):
tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp)
return, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
tmp = self._val.absolute()
tmp2 = self._val.sign()
return, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
return, makeOp(tmp2)(self._jac))
def add_metric(self, metric):
return, self._jac, metric)
......@@ -190,6 +190,10 @@ class MultiField(object):
def conjugate(self):
return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None):
return MultiField(self._domain,
tuple(clip(v, min, max) for v in self._val))
def all(self):
for v in self._val:
if not v.all():
......@@ -292,7 +296,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh", "clipped_exp"]:
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
......@@ -94,6 +94,11 @@ class Operator(NiftyMetaBase()):
return NotImplemented
return _OpChain.make((_PowerOp(, power), self))
def clip(self, min=None, max=None):
if min is None and max is None:
return self
return _OpChain.make((_Clipper(, min, max), self))
def apply(self, x):
raise NotImplementedError
......@@ -123,7 +128,8 @@ class Operator(NiftyMetaBase()):
return self.__class__.__name__
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']:
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
def func(f):
def func2(self):
fa = _FunctionApplier(, f)
......@@ -143,6 +149,18 @@ class _FunctionApplier(Operator):
return getattr(x, self._funcname)()
class _Clipper(Operator):
def __init__(self, domain, min=None, max=None):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._min = min
self._max = max
def apply(self, x):
return x.clip(self._min, self._max)
class _PowerOp(Operator):
def __init__(self, domain, power):
from ..sugar import makeDomain
......@@ -33,7 +33,9 @@ from .operators.distributors import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'positive_tanh',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'clip', 'sinc',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
......@@ -253,7 +255,9 @@ def domain_union(domains):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]:
for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'sinc']:
def func(f):
def func2(x):
from .linearization import Linearization
......@@ -266,6 +270,10 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]:
setattr(_current_module, f, func(f))
def clip(a, a_min=None, a_max=None):
return a.clip(a_min, a_max)
def get_default_codomain(domainoid, space=None):
"""For `RGSpace`, returns the harmonic partner domain.
For `DomainTuple`, returns a copy of the object in which the domain
......@@ -111,7 +111,7 @@ class Energy_Tests(unittest.TestCase):
def testBernoulli(self, space, seed):
model = self.make_model(
space_key='s1', space=space, seed=seed)['s1']
model = model.positive_tanh()
model = model.sigmoid()
d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(d)
......@@ -291,3 +291,14 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.)
@expand(product([float(5), 5.],
[ift.RGSpace((8,), harmonic=True), ()],
["exp", "log", "sin", "cos", "tan", "sinh", "cosh", "sinc",
"absolute", "sign"]))
def test_funcs(self, num, dom, func):
num = 5
f = ift.Field.full(dom, num)
res = getattr(f, func)()
res2 = getattr(np, func)(num)
assert_allclose(res.local_data, res2)
......@@ -75,7 +75,7 @@ class Model_Tests(unittest.TestCase):
model = ift.ScalingOperator(2.456, space)(select_s1*select_s2)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
model = ift.sigmoid(ift.ScalingOperator(2.456, space)(
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment