Commit 9523951a authored by Martin Reinecke's avatar Martin Reinecke Committed by Philipp Arras
Browse files

intermediate stage; broken

parent b49ed2f5
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -125,7 +125,7 @@ class LMSpace(StructuredDomain):
# evaluate the kernel function at the required thetas
kernel_sphere = Field.from_raw(gl, func(theta))
# normalize the kernel such that the integral over the sphere is 4pi
kernel_sphere = kernel_sphere * (4 * np.pi / kernel_sphere.integrate())
kernel_sphere = kernel_sphere * (4 * np.pi / kernel_sphere.s_integrate())
# compute the spherical harmonic coefficients of the kernel
op = HarmonicTransformOperator(lm0, gl)
kernel_lm = op.adjoint_times(kernel_sphere.weight(1)).val
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -168,7 +168,7 @@ class RGSpace(StructuredDomain):
op = HarmonicTransformOperator(self, self.get_default_codomain())
dist = op.target[0]._get_dist_array()
kernel = Field(op.target, func(dist.val))
kernel = kernel / kernel.integrate()
kernel = kernel / kernel.s_integrate()
return op.adjoint_times(kernel.weight(1))
def get_default_codomain(self):
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -44,8 +44,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
return
f1 = from_random("normal", op.domain, dtype=domain_dtype)
f2 = from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2)
res1 = f1.s_vdot(op.adjoint_times(f2))
res2 = op.times(f1).s_vdot(f2)
if only_r_linear:
res1, res2 = res1.real, res2.real
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
......@@ -218,7 +218,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def _get_acceptable_location(op, loc, lin):
if not np.isfinite(lin.val.sum()):
if not np.isfinite(lin.val.s_sum()):
raise ValueError('Initial value must be finite')
dir = from_random("normal", loc.domain)
dirder = lin.jac(dir)
......@@ -231,7 +231,7 @@ def _get_acceptable_location(op, loc, lin):
try:
loc2 = loc+dir
lin2 = op(Linearization.make_var(loc2, lin.want_metric))
if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20:
if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
break
except FloatingPointError:
pass
......@@ -285,7 +285,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val)
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
if (abs(numgrad-dirder) <= xtol).all():
if (abs(numgrad-dirder) <= xtol).s_all():
break
dir = dir*0.5
dirnorm *= 0.5
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -284,7 +284,7 @@ class Field(object):
from .operators.outer_product_operator import OuterProduct
return OuterProduct(self, x.domain)(x)
def vdot(self, x=None, spaces=None):
def vdot(self, x, spaces=None):
"""Computes the dot product of 'self' with x.
Parameters
......@@ -312,11 +312,33 @@ class Field(object):
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
return np.vdot(self._val, x._val)
return Field.scalar(np.array(np.vdot(self._val, x._val)))
# If we arrive here, we have to do a partial dot product.
# For the moment, do this the explicit, non-optimized way
return (self.conjugate()*x).sum(spaces=spaces)
def s_vdot(self, x):
"""Computes the dot product of 'self' with x.
Parameters
----------
x : Field
x must be defined on the same domain as `self`.
Returns
-------
float or complex
The dot product
"""
if not isinstance(x, Field):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
return np.vdot(self._val, x._val)
def norm(self, ord=2):
"""Computes the L2-norm of the field values.
......@@ -357,7 +379,7 @@ class Field(object):
def _contraction_helper(self, op, spaces):
if spaces is None:
return getattr(self._val, op)()
return Field.scalar(getattr(self._val, op)())
spaces = utilities.parse_spaces(spaces, len(self._domain))
......@@ -371,7 +393,7 @@ class Field(object):
# check if the result is scalar or if a result_field must be constr.
if np.isscalar(data):
return data
return Field.scalar(data)
else:
return_domain = tuple(dom
for i, dom in enumerate(self._domain)
......@@ -379,6 +401,9 @@ class Field(object):
return Field(DomainTuple.make(return_domain), data)
# def _s_contraction_helper(self, op):
# return getattr(self._val, op)()
def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`.
......@@ -390,12 +415,21 @@ class Field(object):
Returns
-------
Field or scalar
The result of the summation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
Field
The result of the summation.
"""
return self._contraction_helper('sum', spaces)
def s_sum(self):
"""Returns the sum over all entries
Returns
-------
scalar
The result of the summation.
"""
return self._val.sum()
def integrate(self, spaces=None):
"""Integrates over the sub-domains given by `spaces`.
......@@ -410,9 +444,8 @@ class Field(object):
Returns
-------
Field or scalar
The result of the integration. If it is carried out over the
entire domain, this is a scalar, otherwise a Field.
Field
The result of the integration.
"""
swgt = self.scalar_weight(spaces)
if swgt is not None:
......@@ -422,6 +455,23 @@ class Field(object):
tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces)
def s_integrate(self):
"""Integrates over the Field.
Integration is performed by summing over `self` multiplied by its
volume factors.
Returns
-------
Scalar
The result of the integration.
"""
swgt = self.scalar_weight()
if swgt is not None:
return self.s_sum()*swgt
tmp = self.weight(1)
return tmp.s_sum()
def prod(self, spaces=None):
"""Computes the product over the sub-domains given by `spaces`.
......@@ -434,18 +484,26 @@ class Field(object):
Returns
-------
Field or scalar
The result of the product. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
Field
The result of the product.
"""
return self._contraction_helper('prod', spaces)
def s_prod(self):
return self._val.prod()
def all(self, spaces=None):
return self._contraction_helper('all', spaces)
def s_all(self):
return self._val.all()
def any(self, spaces=None):
return self._contraction_helper('any', spaces)
def s_any(self):
return self._val.any()
# def min(self, spaces=None):
# """Determines the minimum over the sub-domains given by `spaces`.
#
......@@ -457,9 +515,8 @@ class Field(object):
#
# Returns
# -------
# Field or scalar
# The result of the operation. If it is carried out over the entire
# domain, this is a scalar, otherwise a Field.
# Field
# The result of the operation.
# """
# return self._contraction_helper('min', spaces)
#
......@@ -474,9 +531,8 @@ class Field(object):
#
# Returns
# -------
# Field or scalar
# The result of the operation. If it is carried out over the entire
# domain, this is a scalar, otherwise a Field.
# Field
# The result of the operation.
# """
# return self._contraction_helper('max', spaces)
......@@ -494,9 +550,8 @@ class Field(object):
Returns
-------
Field or scalar
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
Field
The result of the operation.
"""
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('mean', spaces)
......@@ -505,6 +560,19 @@ class Field(object):
tmp = self.weight(1, spaces)
return tmp.sum(spaces)*(1./tmp.total_volume(spaces))
def s_mean(self):
"""Determines the field mean
``x.s_mean()`` is equivalent to
``x.s_integrate()/x.total_volume()``.
Returns
-------
scalar
The result of the operation.
"""
return self.s_integrate()/self.total_volume()
def var(self, spaces=None):
"""Determines the variance over the sub-domains given by `spaces`.
......@@ -517,9 +585,8 @@ class Field(object):
Returns
-------
Field or scalar
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
Field
The result of the operation.
"""
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('var', spaces)
......@@ -531,6 +598,24 @@ class Field(object):
sq = (self-m1)**2
return sq.mean(spaces)
def s_var(self):
"""Determines the field variance
Returns
-------
scalar
The result of the operation.
"""
if self.scalar_weight() is not None:
return self._val.var()
# MR FIXME: not very efficient or accurate
m1 = self.s_mean()
if utilities.iscomplextype(self.dtype):
sq = abs(self-m1)**2
else:
sq = (self-m1)**2
return sq.s_mean()
def std(self, spaces=None):
"""Determines the standard deviation over the sub-domains given by
`spaces`.
......@@ -546,15 +631,29 @@ class Field(object):
Returns
-------
Field or scalar
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
Field
The result of the operation.
"""
from .sugar import sqrt
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces))
def s_std(self):
"""Determines the standard deviation of the Field.
``x.s_std()`` is equivalent to ``sqrt(x.s_var())``.
Returns
-------
scalar
The result of the operation.
"""
from .sugar import sqrt
if self.scalar_weight() is not None:
return self._val.std()
return np.sqrt(self.s_var())
def __repr__(self):
return "<nifty6.Field>"
......
......@@ -234,10 +234,10 @@ class Linearization(object):
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
return self.new(
Field.scalar(self._val.vdot(other)),
self._val.vdot(other),
VdotOperator(other)(self._jac))
return self.new(
Field.scalar(self._val.vdot(other._val)),
self._val.vdot(other._val),
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
......@@ -256,14 +256,9 @@ class Linearization(object):
the (partial) sum
"""
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.sum()),
ContractionOperator(self._jac.target, None)(self._jac))
else:
return self.new(
self._val.sum(spaces),
ContractionOperator(self._jac.target, spaces)(self._jac))
return self.new(
self._val.sum(spaces),
ContractionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None):
"""Computes the (partial) integral over self
......@@ -280,14 +275,9 @@ class Linearization(object):
the (partial) integral
"""
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.integrate()),
ContractionOperator(self._jac.target, None, 1)(self._jac))
else:
return self.new(
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
return self.new(
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self):
tmp = self._val.exp()
......
......@@ -111,13 +111,16 @@ class MultiField(object):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
def s_vdot(self, x):
result = 0.
self._check_domain(x)
for v1, v2 in zip(self._val, x._val):
result += v1.vdot(v2)
result += v1.s_vdot(v2)
return result
def vdot(self, x):
return Field.scalar(self.s_vdot(x))
# @staticmethod
# def build_dtype(dtype, domain):
# if isinstance(dtype, dict):
......
......@@ -60,7 +60,7 @@ class Squared2NormOperator(EnergyOperator):
def apply(self, x, difforder):
self._check_input(x)
res = Field.scalar(x.vdot(x))
res = x.vdot(x)
if difforder == self.VALUE_ONLY:
return res
jac = VdotOperator(2*x)
......@@ -90,7 +90,7 @@ class QuadraticFormOperator(EnergyOperator):
def apply(self, x, difforder):
self._check_input(x)
t1 = self._op(x)
res = Field.scalar(0.5*x.vdot(t1))
res = 0.5*x.vdot(t1)
if difforder == self.VALUE_ONLY:
return res
return Linearization(res, VdotOperator(t1))
......@@ -132,9 +132,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum())
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
if difforder <= self.WITH_JAC:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
return res.add_metric(makeOp(MultiField.from_dict(mf)))
......@@ -240,9 +238,7 @@ class PoissonianEnergy(EnergyOperator):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.sum() - x.log().vdot(self._d)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
if difforder <= self.WITH_JAC:
return res
return res.add_metric(makeOp(1./x.val))
......@@ -284,9 +280,7 @@ class InverseGammaLikelihood(EnergyOperator):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
if difforder <= self.WITH_JAC:
return res
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
......@@ -317,9 +311,7 @@ class StudentTEnergy(EnergyOperator):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
if difforder <= self.WITH_JAC:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
return res.add_metric(met)
......@@ -355,9 +347,7 @@ class BernoulliEnergy(EnergyOperator):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
if difforder <= self.WITH_JAC:
return res
met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met)
......
......@@ -41,7 +41,7 @@ class VdotOperator(LinearOperator):
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES:
return Field.scalar(self._field.vdot(x))
return self._field.vdot(x)
return self._field*x.val[()]
......
......@@ -69,8 +69,8 @@ def test_power_synthesize_analyze(space1, space2):
sk = opfull.draw_sample()
sp = ift.power_analyze(sk, spaces=(0, 1), keep_phase_information=False)
sc1.add(sp.sum(spaces=1)/fp2.sum())
sc2.add(sp.sum(spaces=0)/fp1.sum())
sc1.add(sp.sum(spaces=1)/fp2.s_sum())
sc2.add(sp.sum(spaces=0)/fp1.s_sum())
assert_allclose(sc1.mean.val, fp1.val, rtol=0.2)
assert_allclose(sc2.mean.val, fp2.val, rtol=0.2)
......@@ -98,8 +98,8 @@ def test_DiagonalOperator_power_analyze2(space1, space2):
for ii in range(samples):
sk = S_full.draw_sample()
sp = ift.power_analyze(sk, spaces=(0, 1), keep_phase_information=False)
sc1.add(sp.sum(spaces=1)/fp2.sum())
sc2.add(sp.sum(spaces=0)/fp1.sum())
sc1.add(sp.sum(spaces=1)/fp2.s_sum())
sc2.add(sp.sum(spaces=0)/fp1.s_sum())
assert_allclose(sc1.mean.val, fp1.val, rtol=0.2)
assert_allclose(sc2.mean.val, fp2.val, rtol=0.2)
......@@ -124,8 +124,8 @@ def test_vdot():
s = ift.RGSpace((10,))
f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
assert_allclose(f1.s_vdot(f2), f1.vdot(f2, spaces=0).val)
assert_allclose(f1.s_vdot(f2), np.conj(f2.s_vdot(f1)))
def test_vdot2():
......@@ -154,7 +154,7 @@ def test_sum():
), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.sum()
res1 = m1.s_sum()
res2 = m2.sum(spaces=1)
assert_allclose(res1, 36)
assert_allclose(res2.val, np.full(9, 2*12*0.45))
......@@ -165,7 +165,7 @@ def test_integrate():
x2 = ift.RGSpace((2, 12), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.integrate()
res1 = m1.s_integrate()
res2 = m2.integrate(spaces=1)
assert_allclose(res1, 36*2)
assert_allclose(res2.val, np.full(9, 2*12*0.45*0.3**2))
......@@ -204,13 +204,13 @@ def test_trivialities():
assert_equal(f1.one_over().val, (1./f1).val)
assert_equal(f1.real.val, 27.)
assert_equal(f1.imag.val, 3.)
assert_equal(f1.sum(), f1.sum(0))
assert_equal(f1.s_sum(), f1.sum(0).val)
assert_equal(f1.conjugate().val,
ift.Field.full(s1, 27. - 3j).val)
f1 = ift.makeField(s1, np.arange(10))
# assert_equal(f1.min(), 0)
# assert_equal(f1.max(), 9)
assert_equal(f1.prod(), 0)
assert_equal(f1.s_prod(), 0)
def test_weight():
......@@ -241,12 +241,12 @@ def test_weight():
@pmp('dt', [np.float64, np.complex128])
def test_reduction(dom, dt):
s1 = ift.Field.full(dom, dt(1.))
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
assert_allclose(s1.var(0), 0., atol=1e-14)
assert_allclose(s1.std(), 0., atol=1e-14)
assert_allclose(s1.std(0), 0., atol=1e-14)
assert_allclose(s1.s_mean(), 1.)
assert_allclose(s1.mean(0).val, 1.)
assert_allclose(s1.s_var(), 0., atol=1e-14)
assert_allclose(s1.var(0).val, 0., atol=1e-14)
assert_allclose(s1.s_std(), 0., atol=1e-14)
assert_allclose(s1.std(0).val, 0., atol=1e-14)
def test_err():
......@@ -322,12 +322,12 @@ def test_stdfunc():
def test_emptydomain():
f = ift.Field.full((), 3.)
assert_equal(f.sum(), 3.)
assert_equal(f.prod(), 3.)
assert_equal(f.s_sum(), 3.)
assert_equal(f.s_prod(), 3.)
assert_equal(f.val, 3.)
assert_equal(f.val.shape, ())