diff --git a/nifty/field.py b/nifty/field.py index 7f1dfb83a89d75abdff37d50517a7ea2ba4a8b70..92f96508000af7d023caba5adff5a050f2fefe86 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -490,15 +490,23 @@ class Field(object): else: dtype = np.dtype(dtype) - casted_x = self._actual_cast(x, dtype=dtype) + for ind, sp in enumerate(self.domain): + casted_x = sp.pre_cast(x, + axes=self.domain_axes[ind]) + + for ind, ft in enumerate(self.field_type): + casted_x = ft.pre_cast(casted_x, + axes=self.field_type_axes[ind]) + + casted_x = self._actual_cast(casted_x, dtype=dtype) for ind, sp in enumerate(self.domain): - casted_x = sp.complement_cast(casted_x, - axes=self.domain_axes[ind]) + casted_x = sp.post_cast(casted_x, + axes=self.domain_axes[ind]) for ind, ft in enumerate(self.field_type): - casted_x = ft.complement_cast(casted_x, - axes=self.field_type_axes[ind]) + casted_x = ft.post_cast(casted_x, + axes=self.field_type_axes[ind]) return casted_x diff --git a/nifty/field_types/field_type.py b/nifty/field_types/field_type.py index b4d838fc78afe422f77f08732c1f67777131a9fc..1810f77aa141dc86e69a9f68960e6fba10468aa7 100644 --- a/nifty/field_types/field_type.py +++ b/nifty/field_types/field_type.py @@ -51,5 +51,8 @@ class FieldType(object): return result_array - def complement_cast(self, x, axes=None): + def pre_cast(self, x, axes=None): + return x + + def post_cast(self, x, axes=None): return x diff --git a/nifty/spaces/power_space/power_space.py b/nifty/spaces/power_space/power_space.py index 4eb0b7b619f63b4cce0d84abbc8255a21d42c458..3be5711892a85059edee441729deb17a4896ba3f 100644 --- a/nifty/spaces/power_space/power_space.py +++ b/nifty/spaces/power_space/power_space.py @@ -14,7 +14,8 @@ class PowerSpace(Space): # ---Overwritten properties and methods--- - def __init__(self, harmonic_domain=RGSpace((1,)), distribution_strategy='not', + def __init__(self, harmonic_domain=RGSpace((1,)), + distribution_strategy='not', log=False, nbin=None, binbounds=None, dtype=np.dtype('float')): @@ -51,6 +52,12 @@ class PowerSpace(Space): raise NotImplementedError(about._errors.cstring( "ERROR: There is no k_array implementation for PowerSpace.")) + def pre_cast(self, x, axes=None): + if callable(x): + return x(self.kindex) + else: + return x + # ---Mandatory properties and methods--- @property diff --git a/nifty/spaces/space/space.py b/nifty/spaces/space/space.py index ca72a05058b0cb730c881fa4d6ae13978be92a0a..ebdc77ba3e695ca93b718ff196b555a1c7de10cd 100644 --- a/nifty/spaces/space/space.py +++ b/nifty/spaces/space/space.py @@ -266,7 +266,10 @@ class Space(object): """ raise NotImplementedError - def complement_cast(self, x, axes=None): + def pre_cast(self, x, axes=None): + return x + + def post_cast(self, x, axes=None): return x def compute_k_array(self, distribution_strategy):