Commit 92049714 authored by Martin Reinecke's avatar Martin Reinecke

more adjustments

parent 751ba2c3
......@@ -3,12 +3,9 @@
import nifty as ift
from nifty import plotting
import numpy as np
from keepers import Repository
if __name__ == "__main__":
ift.nifty_configuration['default_distribution_strategy'] = 'fftw'
# Setting up parameters |\label{code:wf_parameters}|
correlation_length_scale = 1. # Typical distance over which the field is correlated
fluctuation_scale = 2. # Variance of field in position space
......@@ -24,7 +21,7 @@ if __name__ == "__main__":
N_pixels = 512 # Grid resolution (pixels per axis)
signal_space = ift.RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
harmonic_space = ift.FFTOperator.get_default_codomain(signal_space)
fft = ift.FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float)
fft = ift.FFTOperator(harmonic_space, target=signal_space)
power_space = ift.PowerSpace(harmonic_space)
# Creating the mock signal |\label{code:wf_mock_signal}|
......@@ -48,25 +45,20 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic)
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl,preconditioner=S.times)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic,inverter=inverter)
m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}|
m = fft(m_k)
# Probing the uncertainty |\label{code:wf_uncertainty_probing}|
class Proby(ift.DiagonalProberMixin, ift.Prober): pass
proby = Proby(signal_space, probe_count=800)
proby = Proby(signal_space, probe_count=20)
proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z)))) #|\label{code:wf_variance_fft_wrap}|
sm = ift.SmoothingOperator.make(signal_space, sigma=0.03)
sm = ift.FFTSmoothingOperator(signal_space, sigma=0.03)
variance = ift.sqrt(sm(proby.diagonal.weight(-1))) #|\label{code:wf_variance_weighting}|
repo = Repository('repo_800.h5')
repo.add(mock_signal, 'mock_signal')
repo.add(data, 'data')
repo.add(m, 'm')
repo.add(variance, 'variance')
repo.commit()
# Plotting #|\label{code:wf_plotting}|
plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
plotter.figure.xaxis = ift.plotting.Axis(label='Pixel Index')
......
......@@ -2,15 +2,13 @@ import numpy as np
from nifty import RGSpace, PowerSpace, Field, FFTOperator, ComposedOperator,\
DiagonalOperator, ResponseOperator, plotting,\
create_power_operator, nifty_configuration
create_power_operator
from nifty.library import WienerFilterCurvature
import nifty as ift
if __name__ == "__main__":
nifty_configuration['default_distribution_strategy'] = 'fftw'
nifty_configuration['harmonic_rg_base'] = 'real'
# Setting up variable parameters
# Typical distance over which the field is correlated
......@@ -46,8 +44,7 @@ if __name__ == "__main__":
mock_power = Field(power_space, val=power_spectrum)
np.random.seed(43)
mock_harmonic = mock_power.power_synthesize(real_signal=True)
if nifty_configuration['harmonic_rg_base'] == 'real':
mock_harmonic = mock_harmonic.real
mock_harmonic = mock_harmonic.real
mock_signal = fft(mock_harmonic)
R = ResponseOperator(signal_space, sigma=(response_sigma,))
......@@ -66,7 +63,9 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic)
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=1e-2)
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
m = wiener_curvature.inverse_times(j)
m_s = fft(m)
......@@ -77,6 +76,6 @@ if __name__ == "__main__":
plotter.path = 'data.html'
plotter(Field(
signal_space,
val=data.val.get_full_data().real.reshape(signal_space.shape)))
val=data.val.real.reshape(signal_space.shape)))
plotter.path = 'map.html'
plotter(m_s.real)
......@@ -28,7 +28,7 @@ __all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
def _math_helper(x, function):
if isinstance(x, Field):
result_val = x.val.apply_scalar_function(function)
result_val = function(x.val)
result = x.copy_empty(dtype=result_val.dtype)
result.val = result_val
else:
......
from ...operators import EndomorphicOperator,\
InvertibleOperatorMixin
class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""The curvature of the WienerFilterEnergy.
......@@ -22,17 +21,14 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""
def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs):
def __init__(self, R, N, S, inverter, **kwargs):
self.R = R
self.N = N
self.S = S
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner,
**kwargs)
@property
......
......@@ -18,7 +18,6 @@
from builtins import object
from ...energies import QuadraticEnergy
from ...minimization import ConjugateGradient
from ...field import Field
......@@ -36,22 +35,12 @@ class InvertibleOperatorMixin(object):
----------
inverter : Inverter
An instance of an Inverter class.
(default: ConjugateGradient)
preconditioner : LinearOperator
Preconditioner that is used by ConjugateGradient if no minimizer was
given.
"""
def __init__(self, inverter=None, preconditioner=None,
def __init__(self, inverter,
forward_x0=None, backward_x0=None, *args, **kwargs):
self.__preconditioner = preconditioner
if inverter is not None:
self.__inverter = inverter
else:
self.__inverter = ConjugateGradient(
preconditioner=self.__preconditioner)
self.__inverter = inverter
self.__forward_x0 = forward_x0
self.__backward_x0 = backward_x0
......@@ -61,8 +50,7 @@ class InvertibleOperatorMixin(object):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
else:
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.inverse_times,
......@@ -74,8 +62,7 @@ class InvertibleOperatorMixin(object):
if self.__backward_x0 is not None:
x0 = self.__backward_x0
else:
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.adjoint_inverse_times,
......@@ -87,8 +74,7 @@ class InvertibleOperatorMixin(object):
if self.__backward_x0 is not None:
x0 = self.__backward_x0
else:
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.times,
......@@ -100,8 +86,7 @@ class InvertibleOperatorMixin(object):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
else:
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.adjoint_times,
......
......@@ -81,20 +81,19 @@ class PlotterBase(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object)
for field in fields]
# create plots
if rank == 0:
plots_list = []
for slice_list in utilities.get_slice_list(data_list[0].shape,
axes):
plots_list += \
[[self.plot.at(self._parse_data(current_data,
field,
spaces))
for (current_data, field) in zip(data_list, fields)]]
plots_list = []
for slice_list in utilities.get_slice_list(data_list[0].shape,
axes):
plots_list += \
[[self.plot.at(self._parse_data(current_data,
field,
spaces))
for (current_data, field) in zip(data_list, fields)]]
figures = [self.figure.at(plots, title=title)
for plots in plots_list]
figures = [self.figure.at(plots, title=title)
for plots in plots_list]
self._finalize_figure(figures, path=path)
self._finalize_figure(figures, path=path)
def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces):
......@@ -104,7 +103,7 @@ class PlotterBase(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object)
"not match the plotters domain.")
# TODO: add data_extractor functionality here
data = field.val.get_full_data(target_rank=0)
data = field.val
return data
@abc.abstractmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment