Commit 618cf01e authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Refactor SmoothOperator

- Refactor _smooth_helper
- Remove inplace property
- Rename method compute_k_array to distance_array in RGSpace
- Fix compute_k_array usage
parent 3a42a026
...@@ -4,8 +4,7 @@ import numpy as np ...@@ -4,8 +4,7 @@ import numpy as np
from d2o import distributed_data_object,\ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\ from nifty.config import about, nifty_configuration as gc
nifty_configuration as gc,\
from nifty.field_types import FieldType from nifty.field_types import FieldType
......
...@@ -9,7 +9,7 @@ from nifty.operators.fft_operator import FFTOperator ...@@ -9,7 +9,7 @@ from nifty.operators.fft_operator import FFTOperator
class SmoothOperator(EndomorphicOperator): class SmoothOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), inplace=False, sigma=None): def __init__(self, domain=(), field_type=(), sigma=None):
super(SmoothOperator, self).__init__(domain=domain, super(SmoothOperator, self).__init__(domain=domain,
field_type=field_type) field_type=field_type)
...@@ -27,7 +27,6 @@ class SmoothOperator(EndomorphicOperator): ...@@ -27,7 +27,6 @@ class SmoothOperator(EndomorphicOperator):
)) ))
self._sigma = sigma self._sigma = sigma
self._inplace = bool(inplace)
def _inverse_times(self, x, spaces, types): def _inverse_times(self, x, spaces, types):
return self._smooth_helper(x, spaces, types, inverse=True) return self._smooth_helper(x, spaces, types, inverse=True)
...@@ -53,56 +52,46 @@ class SmoothOperator(EndomorphicOperator): ...@@ -53,56 +52,46 @@ class SmoothOperator(EndomorphicOperator):
def sigma(self): def sigma(self):
return self._sigma return self._sigma
@property
def inplace(self):
return self._inplace
def _smooth_helper(self, x, spaces, types, inverse=False): def _smooth_helper(self, x, spaces, types, inverse=False):
if self.sigma == 0:
return x if self.inplace else x.copy()
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
return x if self.inplace else x.copy()
# copy for doing the actual smoothing # copy for doing the actual smoothing
smooth_out = x.copy() smooth_out = x.copy()
space_obj = x.domain[spaces[0]] if spaces is not None and self.sigma != 0:
axes = x.domain_axes[spaces[0]] spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
for space_axis, val_axis in zip(range(len(space_obj.shape)), axes):
space_obj = x.domain[spaces[0]]
axes = x.domain_axes[spaces[0]]
transform = FFTOperator(space_obj) transform = FFTOperator(space_obj)
kernel = space_obj.get_codomain_smoothing_kernel(
self.sigma, space_axis
)
if isinstance(space_obj, RGSpace): # create the kernel
new_shape = np.ones(len(x.shape), dtype=np.int) kernel = space_obj.distance_array(
new_shape[val_axis] = len(kernel) x.val.get_axes_local_distribution_strategy(axes=axes))
kernel = kernel.reshape(new_shape) kernel = kernel.apply_scalar_function(
space_obj.get_codomain_smoothing_function(self.sigma))
# transform
smooth_out = transform(smooth_out, spaces=spaces[0]) # transform
smooth_out = transform(smooth_out, spaces=spaces[0])
# multiply kernel
if inverse: # local data
smooth_out.val /= kernel local_val = smooth_out.val.get_local_data(copy=False)
else:
smooth_out.val *= kernel # extract local kernel and reshape
local_kernel = kernel.get_local_data(copy=False)
# inverse transform new_shape = np.ones(len(local_val.shape), dtype=np.int)
smooth_out = transform.inverse_times(smooth_out, for space_axis, val_axis in zip(range(len(space_obj.shape)), axes):
spaces=spaces[0]) new_shape[val_axis] = local_kernel.shape[space_axis]
elif isinstance(space_obj, LMSpace): local_kernel = local_kernel.reshape(new_shape)
pass
# multiply kernel
if inverse:
local_val /= kernel
else: else:
raise ValueError(about._errors.cstring( local_val *= kernel
'ERROR: SmoothOperator cannot smooth space ' +
str(space_obj))) smooth_out.val.set_local_data(local_val, copy=False)
if self.inplace: # inverse transform
x.set_val(val=smooth_out.val) smooth_out = transform.inverse_times(smooth_out, spaces=spaces[0])
return x
else: return smooth_out
return smooth_out
...@@ -45,7 +45,7 @@ class PowerIndices(object): ...@@ -45,7 +45,7 @@ class PowerIndices(object):
self.distribution_strategy = distribution_strategy self.distribution_strategy = distribution_strategy
# Compute the global k_array # Compute the global k_array
self.k_array = self.domain.compute_k_array(distribution_strategy) self.k_array = self.domain.distance_array(distribution_strategy)
# Initialize the dictonary which stores all individual index-dicts # Initialize the dictonary which stores all individual index-dicts
self.global_dict = {} self.global_dict = {}
# Set self.default_parameters # Set self.default_parameters
......
...@@ -152,7 +152,7 @@ class RGSpace(Space): ...@@ -152,7 +152,7 @@ class RGSpace(Space):
self._distances = self._parse_distances(distances) self._distances = self._parse_distances(distances)
self._zerocenter = self._parse_zerocenter(zerocenter) self._zerocenter = self._parse_zerocenter(zerocenter)
def compute_k_array(self, distribution_strategy): def distance_array(self, distribution_strategy):
""" """
Calculates an n-dimensional array with its entries being the Calculates an n-dimensional array with its entries being the
lengths of the k-vectors from the zero point of the grid. lengths of the k-vectors from the zero point of the grid.
...@@ -181,12 +181,12 @@ class RGSpace(Space): ...@@ -181,12 +181,12 @@ class RGSpace(Space):
else: else:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: Unsupported distribution strategy")) "ERROR: Unsupported distribution strategy"))
dists = self._compute_k_array_helper(slice_of_first_dimension) dists = self._distance_array_helper(slice_of_first_dimension)
nkdict.set_local_data(dists) nkdict.set_local_data(dists)
return nkdict return nkdict
def _compute_k_array_helper(self, slice_of_first_dimension): def _distance_array_helper(self, slice_of_first_dimension):
dk = self.distances dk = self.distances
shape = self.shape shape = self.shape
...@@ -317,14 +317,8 @@ class RGSpace(Space): ...@@ -317,14 +317,8 @@ class RGSpace(Space):
temp[:] = zerocenter temp[:] = zerocenter
return tuple(temp) return tuple(temp)
def get_codomain_smoothing_kernel(self, sigma, axis): def get_codomain_smoothing_function(self, sigma):
if sigma is None: if sigma is None:
sigma = np.sqrt(2) * np.max(self.distances) sigma = np.sqrt(2) * np.max(self.distances)
gaussian = lambda x: np.exp(-2. * np.pi**2 * x**2 * sigma**2) return lambda x: np.exp(-2. * np.pi**2 * x**2 * sigma**2)
k = np.fft.fftfreq(self.shape[axis], d=self.distances[axis])
if self.zerocenter[axis]:
k = np.fft.fftshift(k)
return np.array(gaussian(k))
...@@ -269,9 +269,9 @@ class Space(object): ...@@ -269,9 +269,9 @@ class Space(object):
def complement_cast(self, x, axes=None): def complement_cast(self, x, axes=None):
return x return x
def compute_k_array(self, distribution_strategy): def distance_array(self, distribution_strategy):
raise NotImplementedError(about._errors.cstring( raise NotImplementedError(about._errors.cstring(
"ERROR: There is no generic k_array for Space base class.")) "ERROR: There is no generic distance_array for Space base class."))
def hermitian_decomposition(self, x, axes=None): def hermitian_decomposition(self, x, axes=None):
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment