Commit 25c0b11c authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'master' into 'mpitests'

Master

See merge request !174
parents 1a0ebf7e 04af1dae
Pipeline #15665 failed with stage
in 17 minutes and 2 seconds
from nifty import * import numpy as np
from nifty.library.wiener_filter import WienerFilterEnergy from nifty import (VL_BFGS, DiagonalOperator, FFTOperator, Field,
LinearOperator, PowerSpace, RelaxedNewton, RGSpace,
SteepestDescent, create_power_operator, exp, log, sqrt)
from nifty.library.critical_filter import CriticalPowerEnergy from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl from nifty.library.wiener_filter import WienerFilterEnergy
import plotly.graph_objs as go
import plotly.graph_objs as go
import plotly.offline as pl
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
rank = comm.rank rank = comm.rank
np.random.seed(42) np.random.seed(42)
def plot_parameters(m,t,p, p_d): def plot_parameters(m, t, p, p_d):
x = log(t.domain[0].kindex) x = log(t.domain[0].kindex)
m = fft.adjoint_times(m) m = fft.adjoint_times(m)
...@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d): ...@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p = p.val.get_full_data().real p = p.val.get_full_data().real
p_d = p_d.val.get_full_data().real p_d = p_d.val.get_full_data().real
pl.plot([go.Heatmap(z=m)], filename='map.html') pl.plot([go.Heatmap(z=m)], filename='map.html')
pl.plot([go.Scatter(x=x,y=t), go.Scatter(x=x ,y=p), go.Scatter(x=x, y=p_d)], filename="t.html") pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
go.Scatter(x=x, y=p_d)], filename="t.html")
class AdjointFFTResponse(LinearOperator): class AdjointFFTResponse(LinearOperator):
...@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator): ...@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def _adjoint_times(self, x, spaces=None): def _adjoint_times(self, x, spaces=None):
return self.FFT(self.R.adjoint_times(x)) return self.FFT(self.R.adjoint_times(x))
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
...@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator): ...@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def unitary(self): def unitary(self):
return False return False
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'not' distribution_strategy = 'not'
# Set up position space # Set up position space
s_space = RGSpace([128,128]) s_space = RGSpace([128, 128])
# s_space = HPSpace(32) # s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space # Define harmonic transformation and associated harmonic space
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
# Setting up power space # Set up power space
p_space = PowerSpace(h_space, logarithmic=True, p_space = PowerSpace(h_space, logarithmic=True,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Choosing the prior correlation structure and defining correlation operator # Choose the prior correlation structure and defining correlation operator
p_spec = (lambda k: (.5 / (k + 1) ** 3)) p_spec = (lambda k: (.5 / (k + 1) ** 3))
S = create_power_operator(h_space, power_spectrum=p_spec, S = create_power_operator(h_space, power_spectrum=p_spec,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Drawing a sample sh from the prior distribution in harmonic space # Draw a sample sh from the prior distribution in harmonic space
sp = Field(p_space, val=p_spec, sp = Field(p_space, val=p_spec,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
sh = sp.power_synthesize(real_signal=True) sh = sp.power_synthesize(real_signal=True)
# Choose the measurement instrument
# Choosing the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01) # Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument = DiagonalOperator(s_space, diagonal=1.) Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument._diagonal.val[200:400, 200:400] = 0 # Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0 # Instrument._diagonal.val[64:512-64, 64:512-64] = 0
# Add a harmonic transformation to the instrument
#Adding a harmonic transformation to the instrument
R = AdjointFFTResponse(fft, Instrument) R = AdjointFFTResponse(fft, Instrument)
noise = 1. noise = 1.
...@@ -92,7 +97,7 @@ if __name__ == "__main__": ...@@ -92,7 +97,7 @@ if __name__ == "__main__":
std=sqrt(noise), std=sqrt(noise),
mean=0) mean=0)
# Creating the mock data # Create mock data
d = R(sh) + n d = R(sh) + n
# The information source # The information source
...@@ -103,52 +108,49 @@ if __name__ == "__main__": ...@@ -103,52 +108,49 @@ if __name__ == "__main__":
if rank == 0: if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html') pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# minimization strategy # Minimization strategy
def convergence_measure(a_energy, iteration): # returns current energy def convergence_measure(a_energy, iteration): # returns current energy
x = a_energy.value x = a_energy.value
print (x, iteration) print(x, iteration)
minimizer1 = RelaxedNewton(convergence_tolerance=1e-2, minimizer1 = RelaxedNewton(convergence_tolerance=1e-4,
convergence_level=2, convergence_level=1,
iteration_limit=3, iteration_limit=5,
callback=convergence_measure) callback=convergence_measure)
minimizer2 = VL_BFGS(convergence_tolerance=1e-4,
minimizer2 = VL_BFGS(convergence_tolerance=0, convergence_level=1,
iteration_limit=7, iteration_limit=20,
callback=convergence_measure, callback=convergence_measure,
max_history_length=3) max_history_length=20)
minimizer3 = SteepestDescent(convergence_tolerance=1e-4,
iteration_limit=100,
callback=convergence_measure)
# Setting starting position # Set starting position
flat_power = Field(p_space,val=1e-8) flat_power = Field(p_space, val=1e-8)
m0 = flat_power.power_synthesize(real_signal=True) m0 = flat_power.power_synthesize(real_signal=True)
t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2)) t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2))
for i in range(500): for i in range(500):
S0 = create_power_operator(h_space, power_spectrum=exp(t0), S0 = create_power_operator(h_space, power_spectrum=exp(t0),
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# Initializing the nonlinear Wiener Filter energy # Initialize non-linear Wiener Filter energy
map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0) map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
# Solving the Wiener Filter analytically # Solve the Wiener Filter analytically
D0 = map_energy.curvature D0 = map_energy.curvature
m0 = D0.inverse_times(j) m0 = D0.inverse_times(j)
# Initializing the power energy with updated parameters # Initialize power energy with updated parameters
power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0, smoothness_prior=10., samples=3) power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
smoothness_prior=10., samples=3)
(power_energy, convergence) = minimizer1(power_energy)
(power_energy, convergence) = minimizer2(power_energy)
# Setting new power spectrum # Set new power spectrum
t0.val = power_energy.position.val.real t0.val = power_energy.position.val.real
# Plotting current estimate # Plot current estimate
print i print(i)
if i%50 == 0: if i % 5 == 0:
plot_parameters(m0,t0,log(sp), data_power) plot_parameters(m0, t0, log(sp), data_power)
...@@ -16,10 +16,8 @@ ...@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
class LineEnergy(object):
class LineEnergy(Energy):
""" Evaluates an underlying Energy along a certain line direction. """ Evaluates an underlying Energy along a certain line direction.
Given an Energy class and a line direction, its position is parametrized by Given an Energy class and a line direction, its position is parametrized by
...@@ -27,34 +25,31 @@ class LineEnergy(Energy): ...@@ -27,34 +25,31 @@ class LineEnergy(Energy):
Parameters Parameters
---------- ----------
position : float line_position : float
The step length parameter along the given line direction. Defines the full spatial position of this energy via
self.energy.position = zero_point + line_position*line_direction
energy : Energy energy : Energy
The Energy object which will be evaluated along the given direction. The Energy object which will be evaluated along the given direction.
line_direction : Field line_direction : Field
Direction used for line evaluation. Direction used for line evaluation. Does not have to be normalized.
zero_point : Field *optional* offset : float *optional*
Fixing the zero point of the line restriction. Used to memorize this Indirectly defines the zero point of the line via the equation
position in new initializations. By the default the current position energy.position = zero_point + offset*line_direction
of the supplied `energy` instance is used (default : None). (default : 0.).
Attributes Attributes
---------- ----------
position : float line_position : float
The position along the given line direction relative to the zero point. The position along the given line direction relative to the zero point.
value : float value : float
The value of the energy functional at given `position`. The value of the energy functional at the given position
gradient : float directional_derivative : float
The gradient of the underlying energy instance along the line direction The directional derivative at the given position
projected on the line direction.
curvature : callable
A positive semi-definite operator or function describing the curvature
of the potential at given `position`.
line_direction : Field line_direction : Field
Direction along which the movement is restricted. Does not have to be Direction along which the movement is restricted. Does not have to be
normalized. normalized.
energy : Energy energy : Energy
The underlying Energy at the `position` along the line direction. The underlying Energy at the given position
Raises Raises
------ ------
...@@ -72,45 +67,49 @@ class LineEnergy(Energy): ...@@ -72,45 +67,49 @@ class LineEnergy(Energy):
""" """
def __init__(self, position, energy, line_direction, zero_point=None): def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__(position=position) self._line_position = float(line_position)
self.line_direction = line_direction self._line_direction = line_direction
if zero_point is None:
zero_point = energy.position
self._zero_point = zero_point
position_on_line = self._zero_point + self.position*line_direction pos = energy.position \
self.energy = energy.at(position=position_on_line) + (self._line_position-float(offset))*self._line_direction
self.energy = energy.at(position=pos)
def at(self, position): def at(self, line_position):
""" Returns LineEnergy at new position, memorizing the zero point. """ Returns LineEnergy at new position, memorizing the zero point.
Parameters Parameters
---------- ----------
position : float line_position : float
Parameter for the new position on the line direction. Parameter for the new position on the line direction.
Returns Returns
------- -------
out : LineEnergy
LineEnergy object at new position with same zero point as `self`. LineEnergy object at new position with same zero point as `self`.
""" """
return self.__class__(position, return self.__class__(line_position,
self.energy, self.energy,
self.line_direction, self.line_direction,
zero_point=self._zero_point) offset=self.line_position)
@property @property
def value(self): def value(self):
return self.energy.value return self.energy.value
@property @property
def gradient(self): def line_position(self):
return self.energy.gradient.vdot(self.line_direction) return self._line_position
@property
def line_direction(self):
return self._line_direction
@property @property
def curvature(self): def directional_derivative(self):
return self.energy.curvature res = self.energy.gradient.vdot(self.line_direction)
if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
print "directional derivative has non-negligible " \
"imaginary part:", res
return res.real
...@@ -112,7 +112,6 @@ class Field(Loggable, Versionable, object): ...@@ -112,7 +112,6 @@ class Field(Loggable, Versionable, object):
def __init__(self, domain=None, val=None, dtype=None, def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False): distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
...@@ -128,6 +127,7 @@ class Field(Loggable, Versionable, object): ...@@ -128,6 +127,7 @@ class Field(Loggable, Versionable, object):
else: else:
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val=None): def _parse_domain(self, domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
...@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object): ...@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object):
return result_obj return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True, def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
mean=None, std=None): mean=None, std=None, distribution_strategy=None):
""" Yields a sampled field with `self`**2 as its power spectrum. """ Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner This method draws a Gaussian random field in the harmonic partner
...@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object): ...@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object):
else: else:
result_list = [None, None] result_list = [None, None]
if distribution_strategy is None:
distribution_strategy = gc['default_distribution_strategy']
result_list = [self.__class__.from_random( result_list = [self.__class__.from_random(
'normal', 'normal',
mean=mean, mean=mean,
std=std, std=std,
domain=result_domain, domain=result_domain,
dtype=np.complex, dtype=np.complex,
distribution_strategy=self.distribution_strategy) distribution_strategy=distribution_strategy)
for x in result_list] for x in result_list]
# from now on extract the values from the random fields for further # from now on extract the values from the random fields for further
...@@ -609,39 +612,47 @@ class Field(Loggable, Versionable, object): ...@@ -609,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance # correct variance
if preserve_gaussian_variance: if preserve_gaussian_variance:
assert issubclass(val.dtype.type, np.complexfloating),\
"complex input field is needed here"
h *= np.sqrt(2) h *= np.sqrt(2)
a *= np.sqrt(2) a *= np.sqrt(2)
if not issubclass(val.dtype.type, np.complexfloating): # The code below should not be needed in practice, since it would
# in principle one must not correct the variance for the fixed # only ever be called when hermitianizing a purely real field.
# points of the hermitianization. However, for a complex field # However it might be of educational use and keep us from forgetting
# the input field loses half of its power at its fixed points # how these things are done ...
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary! # if not issubclass(val.dtype.type, np.complexfloating):
# => The hermitianization can be done on a space level since # # in principle one must not correct the variance for the fixed
# either nothing must be done (LMSpace) or ALL points need a # # points of the hermitianization. However, for a complex field
# factor of sqrt(2) # # the input field loses half of its power at its fixed points
# => use the preserve_gaussian_variance flag in the # # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# hermitian_decomposition method above. # # also necessary!
# # => The hermitianization can be done on a space level since
# This code is for educational purposes: # # either nothing must be done (LMSpace) or ALL points need a
fixed_points = [domain[i].hermitian_fixed_points() # # factor of sqrt(2)
for i in spaces] # # => use the preserve_gaussian_variance flag in the
fixed_points = [[fp] if fp is None else fp # # hermitian_decomposition method above.
for fp in fixed_points] #
# # This code is for educational purposes:
for product_point in itertools.product(*fixed_points): # fixed_points = [domain[i].hermitian_fixed_points()
slice_object = np.array((slice(None), )*len(val.shape), # for i in spaces]
dtype=np.object) # fixed_points = [[fp] if fp is None else fp
for i, sp in enumerate(spaces): # for fp in fixed_points]
point_component = product_point[i] #
if point_component is None: # for product_point in itertools.product(*fixed_points):
point_component = slice(None) # slice_object = np.array((slice(None), )*len(val.shape),
slice_object[list(domain_axes[sp])] = point_component # dtype=np.object)
# for i, sp in enumerate(spaces):
slice_object = tuple(slice_object) # point_component = product_point[i]
h[slice_object] /= np.sqrt(2) # if point_component is None:
a[slice_object] /= np.sqrt(2) # point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return (h, a) return (h, a)
def _spec_to_rescaler(self, spec, result_list, power_space_index): def _spec_to_rescaler(self, spec, result_list, power_space_index):
...@@ -657,7 +668,7 @@ class Field(Loggable, Versionable, object): ...@@ -657,7 +668,7 @@ class Field(Loggable, Versionable, object):
result_list[0].domain_axes[power_space_index]) result_list[0].domain_axes[power_space_index])
if pindex.distribution_strategy is not local_distribution_strategy: if pindex.distribution_strategy is not local_distribution_strategy:
self.logger.warn( raise AttributeError(
"The distribution_strategy of pindex does not fit the " "The distribution_strategy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field.") "slice_local distribution strategy of the synthesized field.")
...@@ -764,14 +775,14 @@ class Field(Loggable, Versionable, object): ...@@ -764,14 +775,14 @@ class Field(Loggable, Versionable, object):
dim dim
""" """
if not hasattr(self, '_shape'):
shape_tuple = tuple(sp.shape for sp in self.domain) shape_tuple = tuple(sp.shape for sp in self.domain)
try: try:
global_shape = reduce(lambda x, y: x + y, shape_tuple) global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError: except TypeError:
global_shape = () global_shape = ()
self._shape = global_shape
return global_shape return self._shape
@property @property
def dim(self): def dim(self):
......
...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, theta, T, inverter=None, preconditioner=None): def __init__(self, theta, T, inverter=None, preconditioner=None, **kwargs):
self.theta = DiagonalOperator(theta.domain, diagonal=theta) self.theta = DiagonalOperator(theta.domain, diagonal=theta)
self.T = T self.T = T
...@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self._domain = self.theta.domain self._domain = self.theta.domain
super(CriticalPowerCurvature, self).__init__( super(CriticalPowerCurvature, self).__init__(
inverter=inverter, inverter=inverter,
preconditioner=preconditioner)