Commit a3d9b3ef authored by findesgh's avatar findesgh

Overload + and * for sdist.

parent dbf494d1
......@@ -78,6 +78,77 @@ class SizeDist(object):
r[ind] = self.func(sizes[ind])
return r
def __add__(self, other):
"""
Commutative addition of size distributions.
Overloading of ``+`` operator. Can add an instance of ``SizeDist`` (or
subclass thereof), any callable or a scalar to ``self.func``. No check
on units is performed. Returns an instance of ``self.__class__``, the
``func`` attribute of which is defined to return the corresponding sum.
If ``other`` is an instance of ``SizeDist`` (or subclass), take maximum
of the two ``sizeMin`` attributes and minimum of the two ``sizeMax``
attributes.
Parameters
----------
other : SizeDist, callable or scalar-valued Quantity [1/L]
Is added to ``self.func``.
Returns
-------
: ``self.__class__``
Instance of own class with corresponding ``func`` attribute.
Examples
--------
>>> import astropy.units as u
>>> def f(s): \
return 1.0/s.unit
>>> a = SizeDist(1*u.angstrom, 1*u.micron, f)
>>> b = SizeDist(10*u.angstrom, 10*u.micron, f)
>>> c = a + b
>>> c(_np.logspace(-11, -4, 10)*u.m)
<Quantity [0., 0., 0., 2., 2., 2., 2., 0., 0., 0.] 1 / m>
"""
if issubclass(other.__class__, SizeDist):
# find new size limits
sizeMin = max(self.sizeMin, other.sizeMin)
sizeMax = min(self.sizeMax, other.sizeMax)
# new differential number density is sum
def func(sizes):
return self.func(sizes) + other.func(sizes)
return self.__class__(sizeMin, sizeMax, func)
elif callable(other):
def func(sizes):
return self.func(sizes) + other(sizes)
return self.__class__(self.sizeMin, self.sizeMax, func)
else:
def func(sizes):
return other + self.func(sizes)
return self.__class__(self.sizeMin, self.sizeMax, func)
# make addition commutative
__radd__ = __add__
def __mul__(self, other):
if callable(other):
def func(sizes):
return self.function(sizes) * other(sizes)
return self.__class__(self.sizeMin, self.sizeMax, func)
else:
def func(sizes):
return other * self.function(sizes)
return self.__class__(self.sizeMin, self.sizeMax, func)
# make multiplication commutative
__rmul__ = __mul__
###############################################################################
if __name__ == "__main__":
......
......@@ -55,3 +55,23 @@ class TestSdist(TestCase):
with self.assertRaises(TypeError) as context:
a(0.5*u.micron)
return
def test_add_scalar_quantity_int(self):
scalar = 1/u.micron
a = sdist.SizeDist(3.5*u.angstrom, 1*u.micron, lambda x: 1/x.unit)
b = a + scalar
r = a(self.__class__._sizes)
r[np.where(r != 0)] = r[np.where(r != 0)] + scalar
self.assertTrue(
np.all(b(self.__class__._sizes) == r))
return
def test_add_scalar_quantity_float(self):
scalar = 2.5/u.micron
a = sdist.SizeDist(3.5*u.angstrom, 1*u.micron, lambda x: 1/x.unit)
b = a + scalar
r = a(self.__class__._sizes)
r[np.where(r != 0)] = r[np.where(r != 0)] + scalar
self.assertTrue(
np.all(b(self.__class__._sizes) == r))
return
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment