Commit 613bae17 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more locking and diagnostics

parent 99ebd010
Pipeline #25097 passed with stages
in 6 minutes and 38 seconds
......@@ -5,11 +5,11 @@ np.random.seed(42)
def adjust_zero_mode(m0, t0):
mtmp = m0.to_global_data()
mtmp = m0.to_global_data().copy()
zero_position = len(m0.shape)*(0,)
zero_mode = mtmp[zero_position]
mtmp[zero_position] = zero_mode / abs(zero_mode)
ttmp = t0.to_global_data()
ttmp = t0.to_global_data().copy()
ttmp[0] += 2 * np.log(abs(zero_mode))
return (ift.Field.from_global_data(m0.domain, mtmp),
ift.Field.from_global_data(t0.domain, ttmp))
......
......@@ -52,16 +52,22 @@ class PowerSpace(StructuredDomain):
def linear_binbounds(nbin, first_bound, last_bound):
"""Produces linearly spaced bin bounds.
This will produce a binbounds array with nbin-1 entries with
binbounds[0]=first_bound and binbounds[-1]=last_bound and the remaining
values equidistantly spaced (in linear scale) between these two.
Parameters
----------
nbin : int
the number of bins
first_bound, last_bound : float
the k values for the right boundary of the first bin and the left
boundary of the last bin, respectively. They are given in length
units of the harmonic partner space.
Returns
-------
numpy.ndarray(numpy.float64)
binbounds array with nbin-1 entries with
binbounds[0]=first_bound and binbounds[-1]=last_bound and the
remaining values equidistantly spaced (in linear scale) between
these two.
"""
nbin = int(nbin)
if nbin < 3:
......@@ -72,17 +78,22 @@ class PowerSpace(StructuredDomain):
def logarithmic_binbounds(nbin, first_bound, last_bound):
"""Produces logarithmically spaced bin bounds.
This will produce a binbounds array with nbin-1 entries with
binbounds[0]=first_bound and binbounds[-1]=last_bound and the remaining
values equidistantly spaced (in natural logarithmic scale)
between these two.
Parameters
----------
nbin : int
the number of bins
first_bound, last_bound : float
the k values for the right boundary of the first bin and the left
boundary of the last bin, respectively. They are given in length
units of the harmonic partner space.
Returns
-------
numpy.ndarray(numpy.float64)
binbounds array with nbin-1 entries with
binbounds[0]=first_bound and binbounds[-1]=last_bound and the
remaining values equidistantly spaced (in natural logarithmic
scale) between these two.
"""
nbin = int(nbin)
if nbin < 3:
......@@ -95,19 +106,24 @@ class PowerSpace(StructuredDomain):
def useful_binbounds(space, logarithmic, nbin=None):
"""Produces bin bounds suitable for a given domain.
This will produce a binbounds array with `nbin-1` entries, if `nbin` is
supplied, or the maximum number of entries that does not produce empty
bins, if `nbin` is not supplied.
The first and last bin boundary are inferred from `space`.
Parameters
----------
space : StructuredDomain
the domain for which the binbounds will be computed.
logarithmic : bool
If True bins will have equal size in linear space; otherwise they
will have equali size in logarithmic space.
will have equal size in logarithmic space.
nbin : int, optional
the number of bins
If None, the highest possible number of bins will be used
Returns
-------
numpy.ndarray(numpy.float64)
Binbounds array with `nbin-1` entries, if `nbin` is
supplied, or the maximum number of entries that does not produce
empty bins, if `nbin` is not supplied.
The first and last bin boundary are inferred from `space`.
"""
if not (isinstance(space, StructuredDomain) and space.harmonic):
raise ValueError("first argument must be a harmonic space.")
......
......@@ -95,7 +95,7 @@ class RGSpace(StructuredDomain):
def get_k_length_array(self):
if (not self.harmonic):
raise NotImplementedError
out = Field((self,), dtype=np.float64)
out = Field(self, dtype=np.float64)
oloc = out.local_data
ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
......
......@@ -112,7 +112,7 @@ class CriticalPowerEnergy(Energy):
gradient = -self._theta
gradient += self.alpha-0.5
gradient += Tt
self._gradient = gradient
self._gradient = gradient.lock()
def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
......
......@@ -102,6 +102,7 @@ class NonlinearPowerEnergy(Energy):
self._value += 0.5 * self.position.vdot(Tpos)
self._gradient *= -1. / len(self.xi_sample_list)
self._gradient += Tpos
self._gradient.lock()
def at(self, position):
return self.__class__(position, self.d, self.N, self.xi, self.D,
......
......@@ -26,7 +26,7 @@ class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.d = d.lock()
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.ht = ht
......@@ -44,6 +44,7 @@ class NonlinearWienerFilterEnergy(Energy):
tmp = self.position.vdot(t1) + residual.vdot(t2)
self._value = 0.5 * tmp.real
self._gradient = t1 - self.LinearizedResponse.adjoint_times(t2)
self._gradient.lock()
def at(self, position):
return self.__class__(position, self.d, self.Instrument,
......
......@@ -54,6 +54,7 @@ class WienerFilterEnergy(Energy):
Dx = self._curvature(self.position)
self._value = 0.5*self.position.vdot(Dx) - self._j.vdot(self.position)
self._gradient = Dx - self._j
self._gradient.lock()
def at(self, position):
return self.__class__(position=position, d=None, R=self.R, N=self.N,
......
......@@ -86,7 +86,7 @@ class ConjugateGradient(Minimizer):
return energy, controller.ERROR
q *= -alpha
r += q
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
......
......@@ -51,7 +51,7 @@ class Energy(with_metaclass(NiftyMeta, type('NewBase', (object,), {}))):
def __init__(self, position):
super(Energy, self).__init__()
self._position = position.copy()
self._position = position.lock()
def at(self, position):
""" Initializes and returns a new Energy object at the new position.
......
......@@ -48,7 +48,7 @@ class LineEnergy(object):
def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__()
self._line_position = float(line_position)
self._line_direction = line_direction
self._line_direction = line_direction.lock()
if self._line_position == float(offset):
self._energy = energy
......
......@@ -35,6 +35,7 @@ class QuadraticEnergy(Energy):
else:
Ax = self._A(self.position)
self._grad = Ax - self._b
self._grad.lock()
self._value = 0.5*self.position.vdot(Ax) - b.vdot(self.position)
def at(self, position):
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from __future__ import division, print_function
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
......@@ -57,7 +57,7 @@ class ScipyMinimizer(Minimizer):
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy = self._energy.at(pos)
self._energy = self._energy.at(pos.locked_copy())
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
......@@ -68,13 +68,13 @@ class ScipyMinimizer(Minimizer):
def jac(self, x):
self._update(x)
return self._energy.gradient.val.reshape(-1)
return self._energy.gradient.val.flatten()
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
return res.val.flatten()
import scipy.optimize as opt
hlp = _MinHelper(self._controller, energy)
......@@ -82,19 +82,24 @@ class ScipyMinimizer(Minimizer):
status = self._controller.start(hlp._energy)
if status != self._controller.CONTINUE:
return hlp._energy, status
x = hlp._energy.position.val.flatten()
try:
if self._need_hessp:
opt.minimize(hlp.fun, hlp._energy.position.val.reshape(-1),
r = opt.minimize(hlp.fun, x,
method=self._method, jac=hlp.jac,
hessp=hlp.hessp,
options=self._options)
else:
opt.minimize(hlp.fun, hlp._energy.position.val.reshape(-1),
r = opt.minimize(hlp.fun, x,
method=self._method, jac=hlp.jac,
options=self._options)
except _MinimizationDone:
status = self._controller.check(hlp._energy)
return hlp._energy, self._controller.check(hlp._energy)
if not r.success:
print("Problem in Scipy minimization:", r.message)
else:
print("Problem in Scipy minimization")
return hlp._energy, self._controller.ERROR
......
......@@ -74,7 +74,7 @@ class DiagonalOperator(EndomorphicOperator):
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
self._diagonal = diagonal.locked_copy()
self._diagonal = diagonal.lock()
if self._spaces is not None:
active_axes = []
......
......@@ -79,10 +79,8 @@ class Test_Functionality(unittest.TestCase):
ps1 += sp.sum(spaces=1)/fp2.sum()
ps2 += sp.sum(spaces=0)/fp1.sum()
assert_allclose((ps1/samples).to_global_data(),
fp1.to_global_data(), rtol=0.2)
assert_allclose((ps2/samples).to_global_data(),
fp2.to_global_data(), rtol=0.2)
assert_allclose((ps1/samples).local_data, fp1.local_data, rtol=0.2)
assert_allclose((ps2/samples).local_data, fp2.local_data, rtol=0.2)
@expand(product([ift.RGSpace((8,), harmonic=True),
ift.RGSpace((8, 8), harmonic=True, distances=0.123)],
......@@ -118,10 +116,8 @@ class Test_Functionality(unittest.TestCase):
ps1 += sp.sum(spaces=1)/fp2.sum()
ps2 += sp.sum(spaces=0)/fp1.sum()
assert_allclose((ps1/samples).to_global_data(),
fp1.to_global_data(), rtol=0.2)
assert_allclose((ps2/samples).to_global_data(),
fp2.to_global_data(), rtol=0.2)
assert_allclose((ps1/samples).local_data, fp1.local_data, rtol=0.2)
assert_allclose((ps2/samples).local_data, fp2.local_data, rtol=0.2)
def test_vdot(self):
s = ift.RGSpace((10,))
......@@ -135,4 +131,4 @@ class Test_Functionality(unittest.TestCase):
x2 = ift.RGSpace((150,))
m = ift.Field.full((x1, x2), .5)
res = m.vdot(m, spaces=1)
assert_allclose(res.to_global_data(), 37.5)
assert_allclose(res.local_data, 37.5)
......@@ -45,8 +45,7 @@ class FFTOperatorTests(unittest.TestCase):
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.to_global_data(), out.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
a, b = b, a
......@@ -54,8 +53,7 @@ class FFTOperatorTests(unittest.TestCase):
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.to_global_data(), out.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
@expand(product([12, 15], [9, 12], [0.1, 1, 3.7],
[0.4, 1, 2.7],
......@@ -70,8 +68,7 @@ class FFTOperatorTests(unittest.TestCase):
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.to_global_data(), out.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
a, b = b, a
......@@ -79,8 +76,7 @@ class FFTOperatorTests(unittest.TestCase):
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.to_global_data(), out.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
@expand(product([0, 1, 2],
[np.float64, np.float32, np.complex64, np.complex128]))
......@@ -93,8 +89,7 @@ class FFTOperatorTests(unittest.TestCase):
inp = ift.Field.from_random(domain=(a1, a2, a3), random_type='normal',
std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.to_global_data(), out.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(128, distances=3.76, harmonic=True),
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
......@@ -112,5 +107,4 @@ class FFTOperatorTests(unittest.TestCase):
zero_idx = tuple([0]*len(space.shape))
assert_allclose(inp.to_global_data()[zero_idx], out.integrate(),
rtol=tol, atol=tol)
assert_allclose(out.to_global_data(), out2.to_global_data(),
rtol=tol, atol=tol)
assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment