Commit 9e644972 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'improving_clip' into 'NIFTy_5'

now clipper takes fields or multifields

See merge request !293
parents d8bf9bf8 aa4e0f38
Pipeline #43900 passed with stages
in 18 minutes and 59 seconds
...@@ -636,6 +636,8 @@ class Field(object): ...@@ -636,6 +636,8 @@ class Field(object):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def clip(self, min=None, max=None): def clip(self, min=None, max=None):
min = min.local_data if isinstance(min, Field) else min
max = max.local_data if isinstance(max, Field) else max
return Field(self._domain, dobj.clip(self._val, min, max)) return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self): def one_over(self):
......
...@@ -192,8 +192,12 @@ class MultiField(object): ...@@ -192,8 +192,12 @@ class MultiField(object):
return self._transform(lambda x: x.conjugate()) return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None): def clip(self, min=None, max=None):
return MultiField(self._domain, ncomp = len(self._val)
tuple(clip(v, min, max) for v in self._val)) lmin = min._val if isinstance(min, MultiField) else (min,)*ncomp
lmax = max._val if isinstance(max, MultiField) else (max,)*ncomp
return MultiField(
self._domain,
tuple(self._val[i].clip(lmin[i], lmax[i]) for i in range(ncomp)))
def all(self): def all(self):
for v in self._val: for v in self._val:
......
...@@ -87,6 +87,10 @@ def testBinary(type1, type2, space, seed): ...@@ -87,6 +87,10 @@ def testBinary(type1, type2, space, seed):
model = select_s1.clip(-1, 1) model = select_s1.clip(-1, 1)
pos = ift.from_random("normal", dom1) pos = ift.from_random("normal", dom1)
ift.extra.check_jacobian_consistency(model, pos, ntries=20) ift.extra.check_jacobian_consistency(model, pos, ntries=20)
f = ift.from_random("normal", space)
model = select_s1.clip(f-0.1, f+1.)
pos = ift.from_random("normal", dom1)
ift.extra.check_jacobian_consistency(model, pos, ntries=20)
if isinstance(space, ift.RGSpace): if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(select_s1*select_s2) model = ift.FFTOperator(space)(select_s1*select_s2)
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment