Commit f9ce2a6e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

great Field revamp, next part

parent 77ee7dcc
......@@ -112,15 +112,13 @@ class Field(object):
"DomainObject instance.")
return domain
def _get_axes_tuple(self, things_with_shape, start=0):
i = start
def _get_axes_tuple(self, things_with_shape):
i = 0
axes_list = []
for thing in things_with_shape:
l = []
for j in range(len(thing.shape)):
l += [i]
i += 1
axes_list += [tuple(l)]
nax = len(thing.shape)
axes_list += [tuple(range(i,i+nax))]
i += nax
return tuple(axes_list)
def _infer_dtype(self, dtype, val):
......@@ -179,7 +177,7 @@ class Field(object):
sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type)
sample[:]=generator_function(dtype=f.dtype,
sample[()]=generator_function(dtype=f.dtype,
shape=sample.shape,
**kwargs)
return f
......@@ -344,7 +342,7 @@ class Field(object):
local_data = pindex
semiscaled_local_data = local_data.reshape(semiscaled_local_shape)
result_obj = np.empty(target_shape, dtype=pindex.dtype)
result_obj[:] = semiscaled_local_data
result_obj[()] = semiscaled_local_data
return result_obj
......@@ -494,23 +492,13 @@ class Field(object):
def _spec_to_rescaler(self, spec, result_list, power_space_index):
power_space = self.domain[power_space_index]
# weight the random fields with the power spectrum
# therefore get the pindex from the power space
pindex = power_space.pindex
# Now use numpy advanced indexing in order to put the entries of the
# power spectrum into the appropriate places of the pindex array.
# Do this for every 'pindex-slice' in parallel using the 'slice(None)'s
local_pindex = pindex
local_blow_up = [slice(None)]*len(spec.shape)
# it is important to count from behind, since spec potentially grows
# with every iteration
index = self.domain_axes[power_space_index][0]-len(self.shape)
local_blow_up[index] = local_pindex
local_blow_up[index] = power_space.pindex
# here, the power_spectrum is distributed into the new shape
local_rescaler = spec[local_blow_up]
return local_rescaler
return spec[local_blow_up]
# ---Properties---
......@@ -668,19 +656,13 @@ class Field(object):
def real(self):
""" The real part of the field (data is not copied).
"""
real_part = self.val.real
result = self.copy_empty(dtype=real_part.dtype)
result.set_val(new_val=real_part, copy=False)
return result
return Field(self.domain,self.val.real)
@property
def imag(self):
""" The imaginary part of the field (data is not copied).
"""
real_part = self.val.imag
result = self.copy_empty(dtype=real_part.dtype)
result.set_val(new_val=real_part, copy=False)
return result
return Field(self.domain,self.val.imag)
# ---Special unary/binary operations---
......@@ -711,12 +693,9 @@ class Field(object):
"""
copied_val = self.get_val(copy=True)
new_field = self.copy_empty(
domain=domain,
dtype=dtype)
new_field.set_val(new_val=copied_val, copy=False)
return new_field
if domain is None:
domain = self.domain
return Field(domain=domain,val=self._val,dtype=dtype,copy=True)
def copy_empty(self, domain=None, dtype=None):
""" Returns an empty copy of the Field.
......@@ -748,41 +727,9 @@ class Field(object):
if domain is None:
domain = self.domain
else:
domain = self._parse_domain(domain)
if dtype is None:
dtype = self.dtype
else:
dtype = np.dtype(dtype)
fast_copyable = True
try:
for i in range(len(self.domain)):
if self.domain[i] is not domain[i]:
fast_copyable = False
break
except IndexError:
fast_copyable = False
if (fast_copyable and dtype == self.dtype):
new_field = self._fast_copy_empty()
else:
new_field = Field(domain=domain, dtype=dtype)
return new_field
def _fast_copy_empty(self):
# make an empty field
new_field = EmptyField()
# repair its class
new_field.__class__ = self.__class__
# copy domain, codomain and val
for key, value in list(self.__dict__.items()):
if key != '_val':
new_field.__dict__[key] = value
else:
new_field.__dict__[key] = np.empty_like(self.val)
return new_field
return Field(domain=domain, dtype=dtype)
def weight(self, power=1, inplace=False, spaces=None):
""" Weights the pixels of `self` with their invidual pixel-volume.
......@@ -805,10 +752,7 @@ class Field(object):
The weighted field.
"""
if inplace:
new_field = self
else:
new_field = self.copy_empty()
new_field = self if inplace else self.copy_empty()
new_val = self.get_val(copy=False)
......@@ -851,16 +795,10 @@ class Field(object):
"the NIFTy field class")
# Compute the dot respecting the fact of discrete/continuous spaces
if bare:
y = self
else:
y = self.weight(power=1)
y = self if bare else self.weight(power=1)
if spaces is None:
x_val = x.get_val(copy=False)
y_val = y.get_val(copy=False)
result = (y_val.conjugate() * x_val).sum()
return result
return np.vdot(y.val.flatten(),x.val.flatten())
else:
# create a diagonal operator which is capable of taking care of the
# axes-matching
......@@ -899,57 +837,27 @@ class Field(object):
"""
if inplace:
work_field = self
self.imag*=-1
return self
else:
work_field = self.copy_empty()
new_val = self.get_val(copy=False)
new_val = new_val.conjugate()
work_field.set_val(new_val=new_val, copy=False)
return work_field
return Field(self.domain,np.conj(self.val),self.dtype)
# ---General unary/contraction methods---
def __pos__(self):
""" x.__pos__() <==> +x
Returns a (positive) copy of `self`.
"""
return self.copy()
def __neg__(self):
""" x.__neg__() <==> -x
Returns a negative copy of `self`.
"""
return_field = self.copy_empty()
new_val = -self.get_val(copy=False)
return_field.set_val(new_val, copy=False)
return return_field
return Field(self.domain,-self.val,self.dtype)
def __abs__(self):
""" x.__abs__() <==> abs(x)
Returns an absolute valued copy of `self`.
"""
new_val = abs(self.get_val(copy=False))
return_field = self.copy_empty(dtype=new_val.dtype)
return_field.set_val(new_val, copy=False)
return return_field
return Field(self.domain,np.abs(self.val),self.dtype)
def _contraction_helper(self, op, spaces):
# build a list of all axes
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
return getattr(self.val, op)()
# build a list of all axes
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
......@@ -1010,28 +918,14 @@ class Field(object):
# ---General binary methods---
def _binary_helper(self, other, op, inplace=False):
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
try:
assert len(other.domain) == len(self.domain)
for index in range(len(self.domain)):
assert other.domain[index] == self.domain[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
other = other.get_val(copy=False)
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
return Field(self.domain,getattr(self.val,op)(other.val))
self_val = self.get_val(copy=False)
return_val = getattr(self_val, op)(other)
if inplace:
working_field = self
else:
working_field = self.copy_empty(dtype=return_val.dtype)
working_field.set_val(return_val, copy=False)
return working_field
return Field(self.domain,getattr(self.val,op)(other))
def __add__(self, other):
return self._binary_helper(other, op='__add__')
......@@ -1040,7 +934,7 @@ class Field(object):
return self._binary_helper(other, op='__radd__')
def __iadd__(self, other):
return self._binary_helper(other, op='__iadd__', inplace=True)
return self._binary_helper(other, op='__iadd__')
def __sub__(self, other):
return self._binary_helper(other, op='__sub__')
......@@ -1049,7 +943,7 @@ class Field(object):
return self._binary_helper(other, op='__rsub__')
def __isub__(self, other):
return self._binary_helper(other, op='__isub__', inplace=True)
return self._binary_helper(other, op='__isub__')
def __mul__(self, other):
return self._binary_helper(other, op='__mul__')
......@@ -1058,7 +952,7 @@ class Field(object):
return self._binary_helper(other, op='__rmul__')
def __imul__(self, other):
return self._binary_helper(other, op='__imul__', inplace=True)
return self._binary_helper(other, op='__imul__')
def __div__(self, other):
return self._binary_helper(other, op='__div__')
......@@ -1073,7 +967,7 @@ class Field(object):
return self._binary_helper(other, op='__rtruediv__')
def __idiv__(self, other):
return self._binary_helper(other, op='__idiv__', inplace=True)
return self._binary_helper(other, op='__idiv__')
def __pow__(self, other):
return self._binary_helper(other, op='__pow__')
......@@ -1082,7 +976,7 @@ class Field(object):
return self._binary_helper(other, op='__rpow__')
def __ipow__(self, other):
return self._binary_helper(other, op='__ipow__', inplace=True)
return self._binary_helper(other, op='__ipow__')
def __lt__(self, other):
return self._binary_helper(other, op='__lt__')
......@@ -1119,8 +1013,3 @@ class Field(object):
"\n- val = " + repr(self.get_val()) + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean)
class EmptyField(Field):
def __init__(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment