Skip to content
Snippets Groups Projects
Commit 5238f400 authored by theos's avatar theos
Browse files

Added keyword arguments mean and std to Field.power_synthesize. Bug fixes in...

Added keyword arguments mean and std to Field.power_synthesize. Bug fixes in Field, SmoothingOperator, LMSpace and Space.
parent 320225fe
No related branches found
No related tags found
No related merge requests found
...@@ -55,6 +55,6 @@ from operators import * ...@@ -55,6 +55,6 @@ from operators import *
from probing import * from probing import *
from demos import get_demo_dir from sugar import *
#import pyximport; pyximport.install(pyimport = True) #import pyximport; pyximport.install(pyimport = True)
...@@ -309,7 +309,8 @@ class Field(object): ...@@ -309,7 +309,8 @@ class Field(object):
return result_obj return result_obj
def power_synthesize(self, spaces=None, real_signal=True): def power_synthesize(self, spaces=None, real_signal=True,
mean=None, std=None):
# assert that all spaces in `self.domain` are either of signal-type or # assert that all spaces in `self.domain` are either of signal-type or
# power_space instances # power_space instances
for sp in self.domain: for sp in self.domain:
...@@ -356,7 +357,9 @@ class Field(object): ...@@ -356,7 +357,9 @@ class Field(object):
result_list = [self.__class__.from_random( result_list = [self.__class__.from_random(
'normal', 'normal',
result_domain, mean=mean,
std=std,
domain=result_domain,
dtype=harmonic_domain.dtype, dtype=harmonic_domain.dtype,
field_type=self.field_type, field_type=self.field_type,
distribution_strategy=self.distribution_strategy) distribution_strategy=self.distribution_strategy)
...@@ -489,8 +492,10 @@ class Field(object): ...@@ -489,8 +492,10 @@ class Field(object):
else: else:
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
casted_x = x
for ind, sp in enumerate(self.domain): for ind, sp in enumerate(self.domain):
casted_x = sp.pre_cast(x, casted_x = sp.pre_cast(casted_x,
axes=self.domain_axes[ind]) axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type): for ind, ft in enumerate(self.field_type):
......
...@@ -9,7 +9,7 @@ from nifty.operators.fft_operator import FFTOperator ...@@ -9,7 +9,7 @@ from nifty.operators.fft_operator import FFTOperator
class SmoothingOperator(EndomorphicOperator): class SmoothingOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), sigma=None): def __init__(self, domain=(), field_type=(), sigma=0):
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type) self._field_type = self._parse_field_type(field_type)
...@@ -86,7 +86,7 @@ class SmoothingOperator(EndomorphicOperator): ...@@ -86,7 +86,7 @@ class SmoothingOperator(EndomorphicOperator):
axes_local_distribution_strategy = \ axes_local_distribution_strategy = \
transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes) transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes)
kernel = codomain.distance_array( kernel = codomain.get_distance_array(
distribution_strategy=axes_local_distribution_strategy) distribution_strategy=axes_local_distribution_strategy)
kernel.apply_scalar_function( kernel.apply_scalar_function(
codomain.get_smoothing_kernel_function(self.sigma), codomain.get_smoothing_kernel_function(self.sigma),
......
...@@ -114,22 +114,6 @@ class LMSpace(Space): ...@@ -114,22 +114,6 @@ class LMSpace(Space):
super(LMSpace, self).__init__(dtype) super(LMSpace, self).__init__(dtype)
self._lmax = self._parse_lmax(lmax) self._lmax = self._parse_lmax(lmax)
def distance_array(self, distribution_strategy):
dists = arange(start=0, stop=self.shape[0],
distribution_strategy=distribution_strategy)
dists = dists.apply_scalar_function(
lambda x: _distance_array_helper(x, self.lmax),
dtype=np.float)
return dists
def get_smoothing_kernel_function(self, sigma):
if sigma is None:
sigma = np.sqrt(2) * np.pi / (self.lmax + 1)
return lambda x: np.exp(-0.5 * x * (x + 1) * sigma**2)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property @property
...@@ -165,6 +149,22 @@ class LMSpace(Space): ...@@ -165,6 +149,22 @@ class LMSpace(Space):
else: else:
return x.copy() return x.copy()
def get_distance_array(self, distribution_strategy):
dists = arange(start=0, stop=self.shape[0],
distribution_strategy=distribution_strategy)
dists = dists.apply_scalar_function(
lambda x: _distance_array_helper(x, self.lmax),
dtype=np.float)
return dists
def get_smoothing_kernel_function(self, sigma):
if sigma is None:
sigma = np.sqrt(2) * np.pi / (self.lmax + 1)
return lambda x: np.exp(-0.5 * x * (x + 1) * sigma**2)
# ---Added properties and methods--- # ---Added properties and methods---
@property @property
......
...@@ -272,9 +272,13 @@ class Space(object): ...@@ -272,9 +272,13 @@ class Space(object):
def post_cast(self, x, axes=None): def post_cast(self, x, axes=None):
return x return x
def compute_k_array(self, distribution_strategy): def get_distance_array(self, distribution_strategy):
raise NotImplementedError( raise NotImplementedError(
"There is no generic k_array for Space base class.") "There is no generic distance structure for Space base class.")
def get_smoothing_kernel_function(self, sigma):
raise NotImplementedError(
"There is no generic co-smoothing kernel for Space base class.")
def hermitian_decomposition(self, x, axes=None): def hermitian_decomposition(self, x, axes=None):
raise NotImplementedError raise NotImplementedError
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment