Commit 871c8d12 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: SmoothOperator for RGSpace

parent 30e5db9f
......@@ -27,6 +27,8 @@ from diagonal_operator import DiagonalOperator
from endomorphic_operator import EndomorphicOperator
from smooth_operator import SmoothOperator
from fft_operator import *
from nifty_operators import operator,\
......@@ -48,4 +50,4 @@ from nifty_probing import prober,\
inverse_diagonal_prober
from nifty_minimization import conjugate_gradient,\
steepest_descent
\ No newline at end of file
steepest_descent
......@@ -9,52 +9,44 @@ from nifty.operators.fft_operator import FFTOperator
class SmoothOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), inplace=False,
sigma = None, implemented=False):
def __init__(self, domain=(), field_type=(), inplace=False, sigma=None):
super(SmoothOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
field_type=field_type)
if len(self.domain) != 1:
raise ValueError(
about._errors.cstring(
'ERROR: SmoothOperator accepts only exactly one '
'space as input domain.')
)
if self.field_type != ():
raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator field-type must be an '
'ERROR: SmoothOperator field-type must be an '
'empty tuple.'
))
self._sigma = sigma
self._inplace = inplace
self._implemented = bool(implemented)
self._inplace = bool(inplace)
def _inverse_times(self, x, spaces, types):
return self._smooth_helper(x, spaces, types, inverse=True)
def _times(self, x, spaces, types):
if sigma == 0:
return x if self.inplace else x.copy()
return self._smooth_helper(x, spaces, types)
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
# ---Mandatory properties and methods---
@property
def implemented(self):
return True
if spaces is None:
return x if self.inplace else x.copy()
@property
def symmetric(self):
return False
for space in spaces:
axes = x.domain_axes[space]
for space_axis, val_axis in zip(
range(len(x.domain[space].shape)), axes):
transform = FFTOperator(x.domain[space])
kernel = x.domain[space].get_codomain_mask(
self.sigma, space_axis)
if isinstance(x.domain[space], RGSpace):
new_shape = np.ones(len(x.shape), dtype=np.int)
new_shape[val_axis] = len(kernel)
kernel = kernel.reshape(new_shape)
# transform
transformed_inp = transform(x)
transformed_inp *= kernel
elif isinstance(x.domain[space], LMSpace):
pass
else:
raise ValueError(about._errors.cstring(
'ERROR: SmoothOperator cannot smooth space ' +
str(x.domain[space]))
@property
def unitary(self):
return False
# ---Added properties and methods---
@property
......@@ -64,3 +56,53 @@ class SmoothOperator(EndomorphicOperator):
@property
def inplace(self):
return self._inplace
def _smooth_helper(self, x, spaces, types, inverse=False):
if self.sigma == 0:
return x if self.inplace else x.copy()
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
return x if self.inplace else x.copy()
# copy for doing the actual smoothing
smooth_out = x.copy()
space_obj = x.domain[spaces[0]]
axes = x.domain_axes[spaces[0]]
for space_axis, val_axis in zip(range(len(space_obj.shape)), axes):
transform = FFTOperator(space_obj)
kernel = space_obj.get_codomain_smoothing_kernel(
self.sigma, space_axis
)
if isinstance(space_obj, RGSpace):
new_shape = np.ones(len(x.shape), dtype=np.int)
new_shape[val_axis] = len(kernel)
kernel = kernel.reshape(new_shape)
# transform
smooth_out = transform(smooth_out, spaces=spaces[0])
# multiply kernel
if inverse:
smooth_out.val /= kernel
else:
smooth_out.val *= kernel
# inverse transform
smooth_out = transform.inverse_times(smooth_out,
spaces=spaces[0])
elif isinstance(space_obj, LMSpace):
pass
else:
raise ValueError(about._errors.cstring(
'ERROR: SmoothOperator cannot smooth space ' +
str(space_obj)))
if self.inplace:
x.set_val(val=smooth_out.val)
return x
else:
return smooth_out
......@@ -310,10 +310,14 @@ class RGSpace(Space):
temp[:] = zerocenter
return tuple(temp)
def get_codomain_mask(self, sigma, axis):
def get_codomain_smoothing_kernel(self, sigma, axis):
if sigma is None:
sigma = np.sqrt(2) * np.max(self.distances)
mask = np.fft.fftfreq(self.shape[axis], d=self.distances[axis])
gaussian = lambda x: np.exp(-2. * np.pi**2 * x**2 * sigma**2)
k = np.fft.fftfreq(self.shape[axis], d=self.distances[axis])
return mask if self.zerocenter[axis] else np.fft.fftshift(mask)
if self.zerocenter[axis]:
k = np.fft.fftshift(k)
return np.array(gaussian(k))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment