Commit 7f8c4fdf authored by Reimar Heinrich Leike's avatar Reimar Heinrich Leike

now clipper takes fields or multifields

parent eba1fd40
Pipeline #43679 passed with stages
in 8 minutes and 6 seconds
...@@ -636,6 +636,10 @@ class Field(object): ...@@ -636,6 +636,10 @@ class Field(object):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def clip(self, min=None, max=None): def clip(self, min=None, max=None):
if isinstance(min, Field):
min = min.local_data
if isinstance(max, Field):
max = max.local_data
return Field(self._domain, dobj.clip(self._val, min, max)) return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self): def one_over(self):
......
...@@ -192,8 +192,18 @@ class MultiField(object): ...@@ -192,8 +192,18 @@ class MultiField(object):
return self._transform(lambda x: x.conjugate()) return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None): def clip(self, min=None, max=None):
return MultiField(self._domain, fields = []
tuple(clip(v, min, max) for v in self._val)) for i in range(len(self._val)):
if isinstance(min, MultiField):
this_min = min._val[i]
else:
this_min = min
if isinstance(max, MultiField):
this_max = max._val[i]
else:
this_max = max
fields += [self._val[i].clip(this_min, this_max)]
return MultiField(self._domain, tuple(fields))
def all(self): def all(self):
for v in self._val: for v in self._val:
......
...@@ -87,6 +87,10 @@ def testBinary(type1, type2, space, seed): ...@@ -87,6 +87,10 @@ def testBinary(type1, type2, space, seed):
model = select_s1.clip(-1, 1) model = select_s1.clip(-1, 1)
pos = ift.from_random("normal", dom1) pos = ift.from_random("normal", dom1)
ift.extra.check_jacobian_consistency(model, pos, ntries=20) ift.extra.check_jacobian_consistency(model, pos, ntries=20)
f = ift.from_random("normal", space)
model = select_s1.clip(f-0.1, f+1.)
pos = ift.from_random("normal", dom1)
ift.extra.check_jacobian_consistency(model, pos, ntries=20)
if isinstance(space, ift.RGSpace): if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(select_s1*select_s2) model = ift.FFTOperator(space)(select_s1*select_s2)
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment