Commit 994893e5 authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

moment method redone.

parent 480470b4
......@@ -960,7 +960,7 @@ class Tensors(AbstractNdarray):
raise ValueError("Multiple occurences of value {}"
.format(tensor))
def moments(self, moment):
def moment(self, moment, weights=None):
"""
Returns:
Moments of the distribution.
......@@ -969,8 +969,27 @@ class Tensors(AbstractNdarray):
second as variance etc. Not 0 as it is mathematicaly correct.
Args:
moment (int): n-th moment
Examples:
>>> import tfields
Skalars
>>> t = tfields.Tensors(range(1, 6))
>>> assert t.moment(1, weights=[-2, -1, 20, 1, 2]) == 0.5
>>> assert t.moment(2, weights=[0.25, 1, 17.5, 1, 0.25]) == 0.2
Vectors
>>> t = tfields.Tensors(list(zip(range(1, 6), range(1, 6))))
>>> t.moment(1, weights=[-2, -1, 20, 1, 2])
Tensors([0.5, 0.5])
>>> t.moment(1, weights=list(zip([-2, -1, 10, 1, 2],
... [-2, -1, 20, 1, 2])))
Tensors([1. , 0.5])
"""
return tfields.lib.stats.moments(self, moment)
array = tfields.lib.stats.moment(self, moment, weights=weights)
if self.rank == 0: # scalar
array = [array]
return Tensors(array, coord_sys=self.coord_sys)
def closest(self, other, **kwargs):
"""
......@@ -1222,7 +1241,7 @@ class Tensors(AbstractNdarray):
def plot(self, **kwargs):
"""
Forwarding to tfields.lib.plotting.plotArray
Forwarding to tfields.lib.plotting.plot_array
"""
artist = tfields.plotting.plot_array(self, **kwargs)
return artist
......
......@@ -71,28 +71,184 @@ mean = np.mean
median = np.median
def getMoment(array, moment):
def _chk_asarray(a, axis):
"""
Returns:
Moments of the distribution.
Note:
The first moment is given as the mean,
second as variance etc. Not 0 as it is mathematicaly correct.
Args:
moment (int): n-th moment
copied from scipy.stats
"""
if axis is None:
a = np.ravel(a)
outaxis = 0
else:
a = np.asarray(a)
outaxis = axis
if a.ndim == 0:
a = np.atleast_1d(a)
return a, outaxis
def _contains_nan(a, nan_policy='propagate'):
"""
copied from scipy.stats
"""
policies = ['propagate', 'raise', 'omit']
if nan_policy not in policies:
raise ValueError("nan_policy must be one of {%s}" %
', '.join("'%s'" % s for s in policies))
try:
# Calling np.sum to avoid creating a huge array into memory
# e.g. np.isnan(a).any()
with np.errstate(invalid='ignore'):
contains_nan = np.isnan(np.sum(a))
except TypeError:
# If the check cannot be properly performed we fallback to omitting
# nan values and raising a warning. This can happen when attempting to
# sum things that are not numbers (e.g. as in the function `mode`).
contains_nan = False
nan_policy = 'omit'
warnings.warn("The input array could not be properly checked for nan "
"values. nan values will be ignored.", RuntimeWarning)
if contains_nan and nan_policy == 'raise':
raise ValueError("The input contains nan values")
return (contains_nan, nan_policy)
def moment(a, moment=1, axis=0, weights=None, nan_policy='propagate'):
r"""
Calculate the nth moment about the mean for a sample.
A moment is a specific quantitative measure of the shape of a set of
points. It is often used to calculate coefficients of skewness and kurtosis
due to its close relationship with them.
Parameters
----------
a : array_like
data
moment : int or array_like of ints, optional
order of central moment that is returned. Default is 1.
axis : int or None, optional
Axis along which the central moment is computed. Default is 0.
If None, compute over the whole array `a`.
nan_policy : {'propagate', 'raise', 'omit'}, optional
Defines how to handle when input contains nan. 'propagate' returns nan,
'raise' throws an error, 'omit' performs the calculations ignoring nan
values. Default is 'propagate'.
Returns
-------
n-th central moment : ndarray or float
The appropriate moment along the given axis or over all values if axis
is None. The denominator for the moment calculation is the number of
observations, no degrees of freedom correction is done.
See also
--------
kurtosis, skew, describe
Notes
-----
The k-th weighted central moment of a data sample is:
.. math::
m_k = \frac{1}{\sum_{j = 1}^n w_i} \sum_{i = 1}^n w_i (x_i - \bar{x})^k
Where n is the number of samples and x-bar is the mean. This function uses
exponentiation by squares [1]_ for efficiency.
References
----------
.. [1] http://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms
Examples
--------
>>> from tfields.lib.stats import moment
>>> moment([1, 2, 3, 4, 5], moment=0)
1.0
>>> moment([1, 2, 3, 4, 5], moment=1)
0.0
>>> moment([1, 2, 3, 4, 5], moment=2)
2.0
Expansion of the scipy.stats moment function by weights:
>>> moment([1, 2, 3, 4, 5], moment=1, weights=[-2, -1, 20, 1, 2])
0.5
>>> moment([1, 2, 3, 4, 5], moment=2, weights=[5, 4, 3, 2, 1])
2.0
>>> moment([1, 2, 3, 4, 5], moment=2, weights=[5, 4, 3, 2, 1])
2.0
>>> moment([1, 2, 3, 4, 5], moment=2, weights=[0.25, 1, 17.5, 1, 0.25])
0.2
>>> moment([1, 2, 3, 4, 5], moment=2, weights=[0, 0, 1, 0, 0])
0.0
"""
a, axis = _chk_asarray(a, axis)
contains_nan, nan_policy = _contains_nan(a, nan_policy)
if contains_nan and nan_policy == 'omit':
a = ma.masked_invalid(a)
return scipy.mstats_basic.moment(a, moment, axis)
if a.size == 0:
# empty array, return nan(s) with shape matching `moment`
if np.isscalar(moment):
return np.nan
else:
return np.ones(np.asarray(moment).shape, dtype=np.float64) * np.nan
# for array_like moment input, return a value for each.
if not np.isscalar(moment):
mmnt = [_moment(a, i, axis, weights=weights) for i in moment]
return np.array(mmnt)
else:
return _moment(a, moment, axis, weights=weights)
def _moment(a, moment, axis, weights=None):
if np.abs(moment - np.round(moment)) > 0:
raise ValueError("All moment parameters must be integers")
if moment == 0:
return 0
if moment == 1: # center of mass
return np.average(array, axis=0)
elif moment == 2: # variance
return np.var(array, axis=0)
elif moment == 3 and scipy.stats: # skewness
return scipy.stats.skew(array, axis=0)
elif moment == 4 and scipy.stats: # kurtosis
return scipy.stats.kurtosis(array, axis=0)
# When moment equals 0, the result is 1, by definition.
shape = list(a.shape)
del shape[axis]
if shape:
# return an actual array of the appropriate shape
return np.ones(shape, dtype=float)
else:
# the input was 1D, so return a scalar instead of a rank-0 array
return 1.0
elif weights is None and moment == 1:
# By definition the first moment about the mean is 0.
shape = list(a.shape)
del shape[axis]
if shape:
# return an actual array of the appropriate shape
return np.zeros(shape, dtype=float)
else:
# the input was 1D, so return a scalar instead of a rank-0 array
return np.float64(0.0)
else:
raise NotImplementedError("Moment %i not implemented." % moment)
# Exponentiation by squares: form exponent sequence
n_list = [moment]
current_n = moment
while current_n > 2:
if current_n % 2:
current_n = (current_n - 1) / 2
else:
current_n /= 2
n_list.append(current_n)
# Starting point for exponentiation by squares
a_zero_mean = a - np.expand_dims(np.mean(a, axis), axis)
if n_list[-1] == 1:
s = a_zero_mean.copy()
else:
s = a_zero_mean**2
# Perform multiplications
for n in n_list[-2::-1]:
s = s**2
if n % 2:
s *= a_zero_mean
return np.average(s, axis, weights=weights)
if __name__ == '__main__':
......
......@@ -169,6 +169,7 @@ def plot_array(array, **kwargs):
Points3D plotting method.
Args:
array (numpy array)
axis (matplotlib.Axis) object
xAxis (int): coordinate index that should be on xAxis
yAxis (int): coordinate index that should be on yAxis
......@@ -184,9 +185,9 @@ def plot_array(array, **kwargs):
tfields.plotting.set_default(kwargs, 'methodName', 'scatter')
po = tfields.plotting.PlotOptions(kwargs)
labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)'])
labels = po.pop('labels', ['x (m)', 'y (m)', 'z (m)'])
xAxis, yAxis, zAxis = po.getXYZAxis()
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labels))
if zAxis is None:
args = [array[:, xAxis],
array[:, yAxis]]
......@@ -314,8 +315,8 @@ def plot_mesh(vertices, faces, **kwargs):
artist._edgecolors2d = None
artist._facecolors2d = None
labelList = ['x (m)', 'y (m)', 'z (m)']
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
labels = ['x (m)', 'y (m)', 'z (m)']
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labels))
else:
raise NotImplementedError("Dimension != 2|3")
......@@ -432,9 +433,9 @@ def plot_function(fun, **kwargs):
Better Artist not list of Artists
"""
import numpy as np
labelList = ['x', 'f(x)']
labels = ['x', 'f(x)']
po = tfields.plotting.PlotOptions(kwargs)
tfields.plotting.set_labels(po.axis, *labelList)
tfields.plotting.set_labels(po.axis, *labels)
xMin, xMax = po.pop('xMin', 0), po.pop('xMax', 1)
n = po.pop('n', 100)
vals = np.linspace(xMin, xMax, n)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment