Commit 9f6f3ff2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge master

parents 8db130ad 69db50ec
Pipeline #15745 failed with stage
in 10 minutes and 44 seconds
...@@ -47,7 +47,9 @@ test_mpi_fftw_hdf5: ...@@ -47,7 +47,9 @@ test_mpi_fftw_hdf5:
- ci/install_pyfftw.sh - ci/install_pyfftw.sh
- ci/install_h5py.sh - ci/install_h5py.sh
- python setup.py build_ext --inplace - python setup.py build_ext --inplace
- nosetests -vv --with-coverage --cover-package=nifty --cover-branches - mpiexec --allow-run-as-root -n 2 nosetests -x
- mpiexec --allow-run-as-root -n 4 nosetests -x
- nosetests -x --with-coverage --cover-package=nifty --cover-branches
- > - >
coverage report | grep TOTAL | awk '{ print "TOTAL: "$6; }' coverage report | grep TOTAL | awk '{ print "TOTAL: "$6; }'
......
from nifty import * import numpy as np
from nifty.library.wiener_filter import WienerFilterEnergy from nifty import (VL_BFGS, DiagonalOperator, FFTOperator, Field,
LinearOperator, PowerSpace, RelaxedNewton, RGSpace,
SteepestDescent, create_power_operator, exp, log, sqrt)
from nifty.library.critical_filter import CriticalPowerEnergy from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl from nifty.library.wiener_filter import WienerFilterEnergy
import plotly.graph_objs as go
import plotly.graph_objs as go
import plotly.offline as pl
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
rank = comm.rank rank = comm.rank
np.random.seed(42) np.random.seed(42)
def plot_parameters(m,t,p, p_d): def plot_parameters(m, t, p, p_d):
x = log(t.domain[0].kindex) x = log(t.domain[0].kindex)
m = fft.adjoint_times(m) m = fft.adjoint_times(m)
...@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d): ...@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p = p.val.get_full_data().real p = p.val.get_full_data().real
p_d = p_d.val.get_full_data().real p_d = p_d.val.get_full_data().real
pl.plot([go.Heatmap(z=m)], filename='map.html') pl.plot([go.Heatmap(z=m)], filename='map.html')
pl.plot([go.Scatter(x=x,y=t), go.Scatter(x=x ,y=p), go.Scatter(x=x, y=p_d)], filename="t.html") pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
go.Scatter(x=x, y=p_d)], filename="t.html")
class AdjointFFTResponse(LinearOperator): class AdjointFFTResponse(LinearOperator):
...@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator): ...@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def _adjoint_times(self, x, spaces=None): def _adjoint_times(self, x, spaces=None):
return self.FFT(self.R.adjoint_times(x)) return self.FFT(self.R.adjoint_times(x))
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
...@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator): ...@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def unitary(self): def unitary(self):
return False return False
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'not' distribution_strategy = 'not'
# Set up position space # Set up position space
s_space = RGSpace([128,128]) s_space = RGSpace([128, 128])
# s_space = HPSpace(32) # s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space # Define harmonic transformation and associated harmonic space
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
# Setting up power space # Set up power space
p_space = PowerSpace(h_space, logarithmic=True, p_space = PowerSpace(h_space, logarithmic=True,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Choosing the prior correlation structure and defining correlation operator # Choose the prior correlation structure and defining correlation operator
p_spec = (lambda k: (.5 / (k + 1) ** 3)) p_spec = (lambda k: (.5 / (k + 1) ** 3))
S = create_power_operator(h_space, power_spectrum=p_spec, S = create_power_operator(h_space, power_spectrum=p_spec,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Drawing a sample sh from the prior distribution in harmonic space # Draw a sample sh from the prior distribution in harmonic space
sp = Field(p_space, val=p_spec, sp = Field(p_space, val=p_spec,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
sh = sp.power_synthesize(real_signal=True) sh = sp.power_synthesize(real_signal=True)
# Choose the measurement instrument
# Choosing the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01) # Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument = DiagonalOperator(s_space, diagonal=1.) Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument._diagonal.val[200:400, 200:400] = 0 # Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0 # Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#Adding a harmonic transformation to the instrument # Add a harmonic transformation to the instrument
R = AdjointFFTResponse(fft, Instrument) R = AdjointFFTResponse(fft, Instrument)
noise = 1. noise = 1.
...@@ -92,7 +97,7 @@ if __name__ == "__main__": ...@@ -92,7 +97,7 @@ if __name__ == "__main__":
std=sqrt(noise), std=sqrt(noise),
mean=0) mean=0)
# Creating the mock data # Create mock data
d = R(sh) + n d = R(sh) + n
# The information source # The information source
...@@ -103,52 +108,49 @@ if __name__ == "__main__": ...@@ -103,52 +108,49 @@ if __name__ == "__main__":
if rank == 0: if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html') pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# minimization strategy # Minimization strategy
def convergence_measure(a_energy, iteration): # returns current energy
def convergence_measure(a_energy, iteration): # returns current energy
x = a_energy.value x = a_energy.value
print (x, iteration) print(x, iteration)
minimizer1 = RelaxedNewton(convergence_tolerance=1e-4,
minimizer1 = RelaxedNewton(convergence_tolerance=1e-2, convergence_level=1,
convergence_level=2, iteration_limit=5,
iteration_limit=3, callback=convergence_measure)
callback=convergence_measure) minimizer2 = VL_BFGS(convergence_tolerance=1e-4,
convergence_level=1,
minimizer2 = VL_BFGS(convergence_tolerance=0, iteration_limit=20,
iteration_limit=7, callback=convergence_measure,
callback=convergence_measure, max_history_length=20)
max_history_length=3) minimizer3 = SteepestDescent(convergence_tolerance=1e-4,
iteration_limit=100,
# Setting starting position callback=convergence_measure)
flat_power = Field(p_space,val=1e-8)
# Set starting position
flat_power = Field(p_space, val=1e-8)
m0 = flat_power.power_synthesize(real_signal=True) m0 = flat_power.power_synthesize(real_signal=True)
t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2)) t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2))
for i in range(500): for i in range(500):
S0 = create_power_operator(h_space, power_spectrum=exp(t0), S0 = create_power_operator(h_space, power_spectrum=exp(t0),
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Initializing the nonlinear Wiener Filter energy # Initialize non-linear Wiener Filter energy
map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0) map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
# Solving the Wiener Filter analytically # Solve the Wiener Filter analytically
D0 = map_energy.curvature D0 = map_energy.curvature
m0 = D0.inverse_times(j) m0 = D0.inverse_times(j)
# Initializing the power energy with updated parameters # Initialize power energy with updated parameters
power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0, smoothness_prior=10., samples=3) power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
smoothness_prior=10., samples=3)
(power_energy, convergence) = minimizer1(power_energy)
# Setting new power spectrum
t0.val = power_energy.position.val.real
# Plotting current estimate (power_energy, convergence) = minimizer2(power_energy)
print i
if i%50 == 0:
plot_parameters(m0,t0,log(sp), data_power)
# Set new power spectrum
t0.val = power_energy.position.val.real
# Plot current estimate
print(i)
if i % 5 == 0:
plot_parameters(m0, t0, log(sp), data_power)
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy from __future__ import print_function
class LineEnergy(Energy): class LineEnergy(object):
""" Evaluates an underlying Energy along a certain line direction. """ Evaluates an underlying Energy along a certain line direction.
Given an Energy class and a line direction, its position is parametrized by Given an Energy class and a line direction, its position is parametrized by
...@@ -27,34 +27,31 @@ class LineEnergy(Energy): ...@@ -27,34 +27,31 @@ class LineEnergy(Energy):
Parameters Parameters
---------- ----------
position : float line_position : float
The step length parameter along the given line direction. Defines the full spatial position of this energy via
self.energy.position = zero_point + line_position*line_direction
energy : Energy energy : Energy
The Energy object which will be evaluated along the given direction. The Energy object which will be evaluated along the given direction.
line_direction : Field line_direction : Field
Direction used for line evaluation. Direction used for line evaluation. Does not have to be normalized.
zero_point : Field *optional* offset : float *optional*
Fixing the zero point of the line restriction. Used to memorize this Indirectly defines the zero point of the line via the equation
position in new initializations. By the default the current position energy.position = zero_point + offset*line_direction
of the supplied `energy` instance is used (default : None). (default : 0.).
Attributes Attributes
---------- ----------
position : float line_position : float
The position along the given line direction relative to the zero point. The position along the given line direction relative to the zero point.
value : float value : float
The value of the energy functional at given `position`. The value of the energy functional at the given position
gradient : float directional_derivative : float
The gradient of the underlying energy instance along the line direction The directional derivative at the given position
projected on the line direction.
curvature : callable
A positive semi-definite operator or function describing the curvature
of the potential at given `position`.
line_direction : Field line_direction : Field
Direction along which the movement is restricted. Does not have to be Direction along which the movement is restricted. Does not have to be
normalized. normalized.
energy : Energy energy : Energy
The underlying Energy at the `position` along the line direction. The underlying Energy at the given position
Raises Raises
------ ------
...@@ -72,45 +69,49 @@ class LineEnergy(Energy): ...@@ -72,45 +69,49 @@ class LineEnergy(Energy):
""" """
def __init__(self, position, energy, line_direction, zero_point=None): def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__(position=position) self._line_position = float(line_position)
self.line_direction = line_direction self._line_direction = line_direction
if zero_point is None: pos = energy.position \
zero_point = energy.position + (self._line_position-float(offset))*self._line_direction
self._zero_point = zero_point self.energy = energy.at(position=pos)
position_on_line = self._zero_point + self.position*line_direction def at(self, line_position):
self.energy = energy.at(position=position_on_line)
def at(self, position):
""" Returns LineEnergy at new position, memorizing the zero point. """ Returns LineEnergy at new position, memorizing the zero point.
Parameters Parameters
---------- ----------
position : float line_position : float
Parameter for the new position on the line direction. Parameter for the new position on the line direction.
Returns Returns
------- -------
out : LineEnergy
LineEnergy object at new position with same zero point as `self`. LineEnergy object at new position with same zero point as `self`.
""" """
return self.__class__(position, return self.__class__(line_position,
self.energy, self.energy,
self.line_direction, self.line_direction,
zero_point=self._zero_point) offset=self.line_position)
@property @property
def value(self): def value(self):
return self.energy.value return self.energy.value
@property @property
def gradient(self): def line_position(self):
return self.energy.gradient.vdot(self.line_direction) return self._line_position
@property
def line_direction(self):
return self._line_direction
@property @property
def curvature(self): def directional_derivative(self):
return self.energy.curvature res = self.energy.gradient.vdot(self.line_direction)
if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
print ("directional derivative has non-negligible "
"imaginary part:", res)
return res.real
...@@ -116,7 +116,6 @@ class Field(Loggable, Versionable, object): ...@@ -116,7 +116,6 @@ class Field(Loggable, Versionable, object):
def __init__(self, domain=None, val=None, dtype=None, def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False): distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
...@@ -132,6 +131,7 @@ class Field(Loggable, Versionable, object): ...@@ -132,6 +131,7 @@ class Field(Loggable, Versionable, object):
else: else:
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val=None): def _parse_domain(self, domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
...@@ -672,7 +672,7 @@ class Field(Loggable, Versionable, object): ...@@ -672,7 +672,7 @@ class Field(Loggable, Versionable, object):
result_list[0].domain_axes[power_space_index]) result_list[0].domain_axes[power_space_index])
if pindex.distribution_strategy is not local_distribution_strategy: if pindex.distribution_strategy is not local_distribution_strategy:
self.logger.warn( raise AttributeError(
"The distribution_strategy of pindex does not fit the " "The distribution_strategy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field.") "slice_local distribution strategy of the synthesized field.")
...@@ -779,14 +779,14 @@ class Field(Loggable, Versionable, object): ...@@ -779,14 +779,14 @@ class Field(Loggable, Versionable, object):
dim dim
""" """
if not hasattr(self, '_shape'):
shape_tuple = tuple(sp.shape for sp in self.domain) shape_tuple = tuple(sp.shape for sp in self.domain)
try: try:
global_shape = reduce(lambda x, y: x + y, shape_tuple) global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError: except TypeError:
global_shape = () global_shape = ()
self._shape = global_shape
return global_shape return self._shape
@property @property
def dim(self): def dim(self):
......
...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, theta, T, inverter=None, preconditioner=None): def __init__(self, theta, T, inverter=None, preconditioner=None, **kwargs):
self.theta = DiagonalOperator(theta.domain, diagonal=theta) self.theta = DiagonalOperator(theta.domain, diagonal=theta)
self.T = T self.T = T
...@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self._domain = self.theta.domain self._domain = self.theta.domain
super(CriticalPowerCurvature, self).__init__( super(CriticalPowerCurvature, self).__init__(
inverter=inverter, inverter=inverter,
preconditioner=preconditioner) preconditioner=preconditioner,
**kwargs)
def _times(self, x, spaces): def _times(self, x, spaces):
return self.T(x) + self.theta(x) return self.T(x) + self.theta(x)
......
...@@ -22,7 +22,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -22,7 +22,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
""" """
def __init__(self, R, N, S, inverter=None, preconditioner=None): def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs):
self.R = R self.R = R
self.N = N self.N = N
...@@ -32,7 +32,8 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -32,7 +32,8 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self._domain = self.S.domain self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__( super(WienerFilterCurvature, self).__init__(
inverter=inverter, inverter=inverter,
preconditioner=preconditioner) preconditioner=preconditioner,
**kwargs)
@property @property
def domain(self): def domain(self):
......
...@@ -23,7 +23,7 @@ class WienerFilterEnergy(Energy): ...@@ -23,7 +23,7 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space. The prior signal covariance in harmonic space.
""" """
def __init__(self, position, d, R, N, S, inverter=None): def __init__(self, position, d, R, N, S):
super(WienerFilterEnergy, self).__init__(position=position) super(WienerFilterEnergy, self).__init__(position=position)
self.d = d self.d = d
self.R = R self.R = R
...@@ -32,7 +32,7 @@ class WienerFilterEnergy(Energy): ...@@ -32,7 +32,7 @@ class WienerFilterEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N, return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, inverter=self.inverter) S=self.S)
@property @property
@memo @memo
...@@ -49,6 +49,7 @@ class WienerFilterEnergy(Energy): ...@@ -49,6 +49,7 @@ class WienerFilterEnergy(Energy):
def curvature(self): def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S) return WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
@property
@memo @memo
def _Dx(self): def _Dx(self):
return self.curvature(self.position) return self.curvature(self.position)
......
...@@ -137,7 +137,7 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje ...@@ -137,7 +137,7 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
# compute the the gradient for the current location # compute the the gradient for the current location
gradient = energy.gradient gradient = energy.gradient
gradient_norm = gradient.vdot(gradient) gradient_norm = gradient.norm()
# check if position is at a flat point # check if position is at a flat point
if gradient_norm == 0: if gradient_norm == 0:
...@@ -147,7 +147,6 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje ...@@ -147,7 +147,6 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
# current position is encoded in energy object # current position is encoded in energy object
descent_direction = self.get_descent_direction(energy)