Commit 4424bdee authored by theos's avatar theos

Added inplace parameter to Space.weight.

parent 39b1cc50
......@@ -220,7 +220,10 @@ class Field(object):
self.set_val(new_val=val, copy=copy)
def _infer_dtype(self, domain=None, dtype=None, field_type=None):
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
if dtype is None:
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
else:
dtype_tuple = (np.dtype(dtype))
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
......@@ -331,6 +334,7 @@ class Field(object):
def copy(self, domain=None, codomain=None, field_type=None, **kwargs):
copied_val = self._unary_operation(self.get_val(), op='copy', **kwargs)
# TODO: respect distribution_strategy
new_field = self.copy_empty(domain=domain,
codomain=codomain,
field_type=field_type)
......@@ -391,6 +395,7 @@ class Field(object):
**kwargs)
return new_field
# TODO: use property for val
def set_val(self, new_val=None, copy=False):
"""
Resets the field values.
......@@ -431,6 +436,7 @@ class Field(object):
return global_shape
# use space.dim and field_type.dim
@property
def dim(self):
"""
......@@ -512,6 +518,7 @@ class Field(object):
shape = self.shape
# Case 1: x is a distributed_data_object
# TODO: Use d2o casting for this case directly, too.
if isinstance(x, distributed_data_object):
if x.comm is not self._comm:
raise ValueError(about._errors.cstring(
......@@ -608,9 +615,10 @@ class Field(object):
spaces = range(len(self.shape))
for ind, sp in enumerate(self.domain):
new_val = sp.calc_weight(new_val,
power=power,
axes=self.domain_axes[ind])
new_val = sp.weight(new_val,
power=power,
axes=self.domain_axes[ind],
inplace=inplace)
new_field.set_val(new_val=new_val, copy=False)
return new_field
......@@ -1164,6 +1172,7 @@ class Field(object):
return self._unary_operation(self.get_val(), op='var',
**kwargs)
# TODO: replace `split` by `def argmin_nonflat`
def argmin(self, split=False, **kwargs):
"""
Returns the index of the minimum field value.
......@@ -1348,7 +1357,8 @@ class Field(object):
def __add__(self, other):
return self._binary_helper(other, op='add')
__radd__ = __add__
def __radd__(self, other):
return self._binary_helper(other, op='radd')
def __iadd__(self, other):
return self._binary_helper(other, op='iadd', inplace=True)
......@@ -1365,7 +1375,8 @@ class Field(object):
def __mul__(self, other):
return self._binary_helper(other, op='mul')
__rmul__ = __mul__
def __rmul__(self, other):
return self._binary_helper(other, op='rmul')
def __imul__(self, other):
return self._binary_helper(other, op='imul', inplace=True)
......@@ -1379,9 +1390,6 @@ class Field(object):
def __idiv__(self, other):
return self._binary_helper(other, op='idiv', inplace=True)
__truediv__ = __div__
__itruediv__ = __idiv__
def __pow__(self, other):
return self._binary_helper(other, op='pow')
......
......@@ -35,7 +35,7 @@ class PowerSpace(Space):
# every power-pixel has a volume of 1
return reduce(lambda x, y: x*y, self.paradict['pindex'].shape)
def weight(self, x, power=1, axes=None):
def weight(self, x, power=1, axes=None, inplace=False):
total_shape = x.shape
axes = cast_axis_to_tuple(axes, len(total_shape))
......@@ -49,7 +49,12 @@ class PowerSpace(Space):
weight = self.paradict['rho'].reshape(reshaper)
if power != 1:
weight = weight ** power
result_x = x * weight
if inplace:
x *= weight
result_x = x
else:
result_x = x*weight
return result_x
......
......@@ -175,9 +175,14 @@ class RGSpace(Space):
def total_volume(self):
return self.dim * reduce(lambda x, y: x*y, self.paradict['distances'])
def weight(self, x, power=1, axes=None):
def weight(self, x, power=1, axes=None, inplace=False):
weight = reduce(lambda x, y: x*y, self.paradict['distances'])**power
return x * weight
if inplace:
x *= weight
result_x = x
else:
result_x = x*weight
return result_x
def compute_k_array(self, distribution_strategy):
"""
......
......@@ -262,7 +262,7 @@ class Space(object):
def complement_cast(self, x, axes=None):
return x
def weight(self, x, power=1, axes=None):
def weight(self, x, power=1, axes=None, inplace=False):
"""
Weights a given array of field values with the pixel volumes (not
the meta volumes) to a given power.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment