Commit c51456ea authored by Martin Reinecke's avatar Martin Reinecke

various tweaks, improve coverage

parent 77c29cf3
Pipeline #29511 failed with stages
in 3 minutes and 54 seconds
......@@ -145,20 +145,29 @@ class data_object(object):
def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis)
def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis)
def mean(self):
return self.sum()/self.size
def mean(self, axis=None):
if axis is None:
sz = self.size
else:
sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis])
return self.sum(axis)/sz
def std(self):
return np.sqrt(self.var())
def std(self, axis=None):
return np.sqrt(self.var(axis))
# FIXME: to be improved!
def var(self):
def var(self, axis=None):
if axis is not None and len(axis) != len(self.shape):
raise ValueError("functionality not yet supported")
return (abs(self-self.mean())**2).mean()
def _binary_helper(self, other, op):
......
......@@ -235,6 +235,7 @@ class Field(object):
The value to fill the field with.
"""
self._val.fill(fill_value)
return self
def lock(self):
"""Write-protect the data content of `self`.
......@@ -318,6 +319,17 @@ class Field(object):
"""
return Field(val=self, copy=True)
def empty_copy(self):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return Field(self._domain, dtype=self.dtype)
def locked_copy(self):
""" Returns a read-only version of the Field.
......@@ -451,8 +463,8 @@ class Field(object):
or Field (for partial dot products)
"""
if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class")
raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
......@@ -642,7 +654,8 @@ class Field(object):
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('mean', spaces)
# MR FIXME: not very efficient
tmp = self.weight(1)
# MR FIXME: do we need "spaces" here?
tmp = self.weight(1, spaces)
return tmp.sum(spaces)*(1./tmp.total_volume(spaces))
def var(self, spaces=None):
......@@ -665,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate
m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating):
sq = abs(self)**2
m1 = abs(m1)**2
sq = abs(self-m1)**2
else:
sq = self**2
m1 **= 2
return sq.mean(spaces) - m1
sq = (self-m1)**2
return sq.mean(spaces)
def std(self, spaces=None):
"""Determines the standard deviation over the sub-domains given by
......@@ -690,8 +701,10 @@ class Field(object):
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
from .sugar import sqrt
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces)
print(self.var(spaces))
return sqrt(self.var(spaces))
def copy_content_from(self, other):
......
......@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
# RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j*0., D_inv, j)
energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
......
......@@ -57,6 +57,18 @@ class MultiField(object):
dtype[key], **kwargs)
for key in domain.keys()})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for val in self._val.values():
val.fill(fill_value)
return self
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
......@@ -76,6 +88,9 @@ class MultiField(object):
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
def empty_copy(self):
return MultiField({key: val.empty_copy() for key, val in self.items()})
@staticmethod
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
......
......@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
x0 = x*0.
x0 = x.empty_copy().fill(0.)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation
......
......@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return x*0.
return x.empty_copy().fill(0.)
if mode == self.TIMES:
return x*self._factor
......
......@@ -218,16 +218,16 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()})
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
elif isinstance(x, Field):
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
else:
return Field(domain=x._domain, val=fu(x.val))
return getattr(np, f)(x, out)
return func2
setattr(_current_module, f, func(f))
......@@ -18,7 +18,7 @@
import unittest
import numpy as np
from numpy.testing import assert_equal, assert_allclose
from numpy.testing import assert_equal, assert_allclose, assert_raises
from itertools import product
import nifty4 as ift
from test.common import expand
......@@ -124,6 +124,122 @@ class Test_Functionality(unittest.TestCase):
res = m.vdot(m, spaces=1)
assert_allclose(res.local_data, 37.5)
def test_lock(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.locked, False)
f1.lock()
assert_equal(f1.locked, True)
with assert_raises(ValueError):
f1 += f1
assert_equal(f1.locked_copy() is f1, True)
def test_fill(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.fill(10).local_data, 10)
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
gd = np.arange(s1.shape[0])
assert_equal(ld, ift.from_local_data(s1, ld).local_data)
assert_equal(gd, ift.from_global_data(s1, gd).to_global_data())
def test_cast_domain(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((10,), distances=20.)
d = np.arange(s1.shape[0])
d2 = ift.from_global_data(s1, d).cast_domain(s2).to_global_data()
assert_equal(d, d2)
def test_empty_domain(self):
f = ift.Field((), 5)
assert_equal(f.to_global_data(), 5)
f = ift.Field(None, 5)
assert_equal(f.to_global_data(), 5)
assert_equal(f.empty_copy().domain, f.domain)
assert_equal(f.empty_copy().dtype, f.dtype)
assert_equal(f.copy().domain, f.domain)
assert_equal(f.copy().dtype, f.dtype)
assert_equal(f.copy().local_data, f.local_data)
assert_equal(f.copy() is f, False)
def test_trivialities(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data)
assert_equal(f1.sum(), f1.sum(0))
f1 = ift.from_global_data(s1, np.arange(10))
assert_equal(f1.min(), 0)
assert_equal(f1.max(), 9)
assert_equal(f1.prod(), 0)
def test_weight(self):
s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.)
f2 = f.copy()
f.weight(1, out=f2)
assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1)
assert_equal(f.total_volume(0), 1)
assert_equal(f.total_volume((0,)), 1)
assert_equal(f.scalar_weight(), 0.1)
assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.)
assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None)
@expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128]))
def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt)
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
assert_allclose(s1.var(0), 0., atol=1e-14)
assert_allclose(s1.std(), 0., atol=1e-14)
assert_allclose(s1.std(0), 0., atol=1e-14)
def test_err(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz")
with assert_raises(TypeError):
if f1:
pass
with assert_raises(TypeError):
f1.full((2, 4, 6))
with assert_raises(TypeError):
f2 = ift.Field(None, None)
with assert_raises(ValueError):
f2 = ift.Field(s1, None)
with assert_raises(ValueError):
f1.imag
with assert_raises(TypeError):
f1.vdot(42)
with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.))
with assert_raises(TypeError):
f1.copy_content_from(1)
with assert_raises(ValueError):
f1.copy_content_from(ift.Field(s2, 1.))
with assert_raises(TypeError):
ift.full(s1, [2, 3])
def test_stdfunc(self):
s = ift.RGSpace((200,))
f = ift.Field(s, 27)
......
......@@ -30,7 +30,7 @@ spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
minimizers = ['ift.VL_BFGS(IC)',
'ift.NonlinearCG(IC, "Polak-Ribiere")',
#'ift.NonlinearCG(IC, "Hestenes-Stiefel"),
# 'ift.NonlinearCG(IC, "Hestenes-Stiefel"),
'ift.NonlinearCG(IC, "Fletcher-Reeves")',
'ift.NonlinearCG(IC, "5.49")',
'ift.NewtonCG(xtol=1e-5, maxiter=1000)',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment