Commit 114d5196 authored by Martin Reinecke's avatar Martin Reinecke

nor more keepers

parent 92049714
......@@ -18,6 +18,7 @@ before_script:
test_min:
stage: test
script:
- ci/install_pyHealpix.sh
- nosetests
- nosetests -x --with-coverage --cover-package=nifty --cover-branches
- >
......
......@@ -4,5 +4,4 @@ nose
parameterized
coverage
git+https://gitlab.mpcdf.mpg.de/ift/mpi_dummy.git
git+https://gitlab.mpcdf.mpg.de/ift/keepers.git
pyfftw
......@@ -4,14 +4,9 @@ import numpy as np
import nifty as ift
from nifty import plotting
from keepers import Repository
if __name__ == "__main__":
ift.nifty_configuration['default_distribution_strategy'] = 'fftw'
signal_to_noise = 1.5 # The signal to noise ratio
# Setting up parameters |\label{code:wf_parameters}|
correlation_length_1 = 1. # Typical distance over which the field is correlated
field_variance_1 = 2. # Variance of field in position space
......@@ -28,12 +23,10 @@ if __name__ == "__main__":
signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1)
harmonic_space_1 = ift.FFTOperator.get_default_codomain(signal_space_1)
fft_1 = ift.FFTOperator(harmonic_space_1, target=signal_space_1,
domain_dtype=np.complex, target_dtype=np.complex)
fft_1 = ift.FFTOperator(harmonic_space_1, target=signal_space_1)
power_space_1 = ift.PowerSpace(harmonic_space_1)
mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1,
distribution_strategy='not')
mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1)
......@@ -53,19 +46,15 @@ if __name__ == "__main__":
signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
harmonic_space_2 = ift.FFTOperator.get_default_codomain(signal_space_2)
fft_2 = ift.FFTOperator(harmonic_space_2, target=signal_space_2,
domain_dtype=np.complex, target_dtype=np.complex)
power_space_2 = ift.PowerSpace(harmonic_space_2, distribution_strategy='not')
fft_2 = ift.FFTOperator(harmonic_space_2, target=signal_space_2)
power_space_2 = ift.PowerSpace(harmonic_space_2)
mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2,
distribution_strategy='not')
mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2)
fft = ift.ComposedOperator((fft_1, fft_2))
mock_power = ift.Field(domain=(power_space_1, power_space_2),
val=np.outer(mock_power_1.val.get_full_data(),
mock_power_2.val.get_full_data()),
distribution_strategy='not')
val=np.outer(mock_power_1.val, mock_power_2.val))
diagonal = mock_power.power_synthesize(spaces=(0, 1), mean=1, std=0,
real_signal=False)**2
......@@ -83,7 +72,7 @@ if __name__ == "__main__":
mask_1.val[N1_10*7:N1_10*9] = 0.
N2_10 = int(N_pixels_2/10)
mask_2 = ift.Field(signal_space_2, val=1., distribution_strategy='not')
mask_2 = ift.Field(signal_space_2, val=1.)
mask_2.val[N2_10*7:N2_10*9] = 0.
R = ift.ResponseOperator((signal_space_1, signal_space_2),
......@@ -102,8 +91,11 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic)
wiener_curvature._InvertibleOperatorMixin__inverter.convergence_tolerance = 1e-3
ctrl = ift.DefaultIterationController(verbose=True,
tol_abs_gradnorm=1.0,
tol_rel_gradnorm=1e-4)
inverter = ift.ConjugateGradient(controller=ctrl,preconditioner=S.times)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}|
m = fft(m_k)
......@@ -116,13 +108,6 @@ if __name__ == "__main__":
# variance = sm(proby.diagonal.weight(-1))
variance = proby.diagonal.weight(-1)
repo = Repository('repo_100.h5')
repo.add(mock_signal, 'mock_signal')
repo.add(data, 'data')
repo.add(m, 'm')
repo.add(variance, 'variance')
repo.commit()
plot_space = ift.RGSpace((N_pixels_1, N_pixels_2))
plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
plotter.figure.xaxis = ift.plotting.Axis(label='Pixel Index')
......@@ -136,6 +121,6 @@ if __name__ == "__main__":
plotter.plot.zmin = np.real(mock_signal.min());
plotter.plot.zmax = np.real(mock_signal.max());
plotter(ift.Field(plot_space, val=mock_signal.val.real), path='mock_signal.html')
plotter(ift.Field(plot_space, val=data.val.get_full_data().real), path = 'data.html')
plotter(ift.Field(plot_space, val=data.val.real), path = 'data.html')
plotter(ift.Field(plot_space, val=m.val.real), path = 'map.html')
......@@ -53,7 +53,7 @@ if __name__ == "__main__":
# Probing the uncertainty |\label{code:wf_uncertainty_probing}|
class Proby(ift.DiagonalProberMixin, ift.Prober): pass
proby = Proby(signal_space, probe_count=20)
proby = Proby(signal_space, probe_count=800)
proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z)))) #|\label{code:wf_variance_fft_wrap}|
sm = ift.FFTSmoothingOperator(signal_space, sigma=0.03)
......
......@@ -20,14 +20,6 @@ from __future__ import division
from .version import __version__
# initialize the logger instance
from keepers import MPILogger
logger = MPILogger()
from .config import nifty_configuration
logger.logger.setLevel(nifty_configuration['loglevel'])
from .field import Field
from .random import Random
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .nifty_config import nifty_configuration
# -*- coding: utf-8 -*-
import os
__all__ = []
try:
import matplotlib
except ImportError:
pass
else:
try:
display = os.environ['DISPLAY']
except KeyError:
matplotlib.use('Agg')
else:
if display == '':
matplotlib.use('Agg')
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from . import matplotlib_init
import os
import numpy as np
import keepers
__all__ = ['nifty_configuration']
# Initialize the variables
variable_loglevel = keepers.Variable(
'loglevel',
[10],
lambda z: np.int(z) == z and 0 <= z <= 50,
genus='int')
nifty_configuration = keepers.get_Configuration(
name='NIFTy',
variables=[variable_loglevel],
file_name='NIFTy.conf',
search_paths=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/",
'./'])
########
try:
nifty_configuration.load()
except:
pass
......@@ -20,12 +20,11 @@ from __future__ import division
import abc
from .nifty_meta import NiftyMeta
from keepers import Loggable
from future.utils import with_metaclass
class DomainObject(with_metaclass(
NiftyMeta, type('NewBase', (Loggable, object), {}))):
NiftyMeta, type('NewBase', (object,), {}))):
"""The abstract class that can be used as a domain for a field.
This holds all the information and functionality a field needs to know
......@@ -42,8 +41,7 @@ class DomainObject(with_metaclass(
"""
def __init__(self):
# _global_id is used in the Versioning module from keepers
self._ignore_for_hash = ['_global_id']
self._ignore_for_hash = []
@abc.abstractmethod
def __repr__(self):
......
......@@ -19,11 +19,10 @@
from ..nifty_meta import NiftyMeta
from .memoization import memo
from keepers import Loggable
from future.utils import with_metaclass
class Energy(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
class Energy(with_metaclass(NiftyMeta, type('NewBase', (object,), {}))):
""" Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its
......
......@@ -23,8 +23,6 @@ from builtins import range
import ast
import numpy as np
from keepers import Loggable
from .domain_object import DomainObject
from .spaces.power_space import PowerSpace
......@@ -34,7 +32,7 @@ from .random import Random
from functools import reduce
class Field(Loggable, object):
class Field(object):
""" The discrete representation of a continuous field over multiple spaces.
In NIFTY, Fields are used to store data arrays and carry all the needed
......@@ -269,7 +267,7 @@ class Field(Loggable, object):
# power_space instances
for sp in self.domain:
if not sp.harmonic and not isinstance(sp, PowerSpace):
self.logger.info(
raise TypeError(
"Field has a space in `domain` which is neither "
"harmonic nor a PowerSpace.")
......
......@@ -83,12 +83,10 @@ class ConjugateGradient(Minimizer):
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq==0.:
self.logger.error("Alpha became infinite! Stopping.")
return energy, controller.ERROR
alpha = previous_gamma/ddotq
if alpha < 0:
self.logger.warn("Positive definiteness of A violated!")
return energy, controller.ERROR
r -= q * alpha
......@@ -105,7 +103,7 @@ class ConjugateGradient(Minimizer):
gamma = r.vdot(s).real
if gamma < 0:
self.logger.warn(
raise RuntimeError(
"Positive definiteness of preconditioner violated!")
if gamma == 0:
return energy, controller.CONVERGED
......
......@@ -56,9 +56,9 @@ class DefaultIterationController(IterationController):
msg += self._name+":"
msg += " Iteration #" + str(self._itcount)
msg += " gradnorm=" + str(energy.gradient_norm)
msg += " convergence level=" + str(self._ccount)
msg += " clvl=" + str(self._ccount)
print (msg)
self.logger.info(msg)
#self.logger.info(msg)
# Are we done?
if self._iteration_limit is not None:
......
......@@ -82,7 +82,6 @@ class DescentMinimizer(Minimizer):
while True:
# check if position is at a flat point
if energy.gradient_norm == 0:
self.logger.info("Reached perfectly flat point. Stopping.")
return energy, controller.CONVERGED
# current position is encoded in energy object
......@@ -96,15 +95,11 @@ class DescentMinimizer(Minimizer):
pk=descent_direction,
f_k_minus_1=f_k_minus_1)
except RuntimeError:
self.logger.warn(
"Stopping because of RuntimeError in line-search")
return energy, controller.ERROR
f_k_minus_1 = energy.value
# check if new energy value is bigger than old energy value
if (new_energy.value - energy.value) > 0:
self.logger.info("Line search algorithm returned a new energy "
"that was larger than the old one. Stopping.")
return energy, controller.ERROR
energy = new_energy
......
......@@ -22,11 +22,10 @@ from ..nifty_meta import NiftyMeta
import numpy as np
from keepers import Loggable
from future.utils import with_metaclass
class IterationController(with_metaclass(NiftyMeta, type('NewBase',
(Loggable, object), {}))):
(object,), {}))):
"""The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a
minimization iteration. At the begin of the minimization, its start()
......
......@@ -18,13 +18,11 @@
import abc
from keepers import Loggable
from ...energies import LineEnergy
from future.utils import with_metaclass
class LineSearch(with_metaclass(abc.ABCMeta, with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object), {})))):
class LineSearch(with_metaclass(abc.ABCMeta, with_metaclass(abc.ABCMeta, type('NewBase', (object,), {})))):
"""Class for determining the optimal step size along some descent direction.
Initialize the line search procedure which can be used by a specific line
......
......@@ -109,8 +109,7 @@ class LineSearchStrongWolfe(LineSearch):
phi_0 = le_0.value
phiprime_0 = le_0.directional_derivative
if phiprime_0 >= 0:
self.logger.error("Input direction must be a descent direction")
raise RuntimeError
raise RuntimeError ("search direction must be a descent direction")
# set alphas
alpha0 = 0.
......@@ -131,7 +130,6 @@ class LineSearchStrongWolfe(LineSearch):
while iteration_number < self.max_iterations:
iteration_number += 1
if alpha1 == 0:
self.logger.warn("Increment size became 0.")
result_energy = le_0.energy
break
......@@ -161,18 +159,13 @@ class LineSearchStrongWolfe(LineSearch):
# update alphas
alpha0, alpha1 = alpha1, min(2*alpha1, self.max_step_size)
if alpha1 == self.max_step_size:
self.logger.info("Reached max step size, bailing out")
return le_alpha1.energy
phi_alpha0 = phi_alpha1
phiprime_alpha0 = phiprime_alpha1
else:
# max_iterations was reached
self.logger.error("The line search algorithm did not converge.")
return le_alpha1.energy
if iteration_number > 1:
self.logger.debug("Finished line-search after %08u steps" %
iteration_number)
return result_energy
def _zoom(self, alpha_lo, alpha_hi, phi_0, phiprime_0,
......@@ -269,8 +262,8 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj)
else:
self.logger.error("The line search algorithm (zoom) did not "
"converge.")
#self.logger.error("The line search algorithm (zoom) did not "
# "converge.")
return le_alphaj
def _cubicmin(self, a, fa, fpa, b, fb, c, fc):
......
......@@ -21,10 +21,9 @@ from ..nifty_meta import NiftyMeta
import numpy as np
from keepers import Loggable
from future.utils import with_metaclass
class Minimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
class Minimizer(with_metaclass(NiftyMeta, type('NewBase', (object,), {}))):
""" A base class used by all minimizers.
"""
......
......@@ -77,15 +77,6 @@ class GLLMTransformation(SlicingTransformation):
lmax = codomain.lmax
mmax = codomain.mmax
if lmax != mmax:
cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if lmax != nlat - 1:
cls.logger.warn("Unrecommended: codomain has lmax != nlat - 1.")
if nlon != 2*nlat - 1:
cls.logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.")
super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
......
......@@ -75,9 +75,6 @@ class HPLMTransformation(SlicingTransformation):
lmax = codomain.lmax
nside = domain.nside
if lmax != 2*nside:
cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
......
......@@ -83,15 +83,6 @@ class LMGLTransformation(SlicingTransformation):
lmax = domain.lmax
mmax = domain.mmax
if lmax != mmax:
cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if nlat != lmax + 1:
cls.logger.warn("Unrecommended: codomain has nlat != lmax + 1.")
if nlon != 2*lmax + 1:
cls.logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.")
super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
......
......@@ -77,9 +77,6 @@ class LMHPTransformation(SlicingTransformation):
nside = codomain.nside
lmax = domain.lmax
if lmax != 2*nside:
cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
......
......@@ -23,13 +23,12 @@ import warnings
import numpy as np
from .... import nifty_utilities as utilities
from keepers import Loggable
from functools import reduce
import pyfftw
class Transform(Loggable, object):
class Transform(object):
"""
A generic fft object without any implementation.
"""
......
......@@ -18,11 +18,10 @@
import abc
from keepers import Loggable
from future.utils import with_metaclass
class Transformation(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object), {}))):
class Transformation(with_metaclass(abc.ABCMeta, type('NewBase', (object,), {}))):
"""
A generic transformation which defines a static check_codomain
method for all transforms.
......
......@@ -20,14 +20,13 @@ from builtins import str
import abc
from ...nifty_meta import NiftyMeta
from keepers import Loggable
from ...field import Field
from ... import nifty_utilities as utilities
from future.utils import with_metaclass
class LinearOperator(with_metaclass(
NiftyMeta, type('NewBase', (Loggable, object), {}))):
NiftyMeta, type('NewBase', (object,), {}))):
"""NIFTY base class for linear operators.
The base NIFTY operator class is an abstract class from which
......
......@@ -55,20 +55,6 @@ class ProjectionOperator(EndomorphicOperator):
Raised if
* if projection_field is not a Field
Notes
-----
Examples
--------
>>> x_space = RGSpace(5)
>>> f1 = Field(x_space, val=3.)
>>> f2 = Field(x_space, val=5.)
>>> P = ProjectionOperator(f1)
>>> res = P.times(f2)
>>> res.val
<distributed_data_object>
array([ 225., 225., 225., 225., 225.])
See Also
--------
......@@ -104,24 +90,9 @@ class ProjectionOperator(EndomorphicOperator):
for space_index in spaces:
active_axes += x.domain_axes[space_index]
axes_local_distribution_strategy = \
x.val.get_axes_local_distribution_strategy(active_axes)
if axes_local_distribution_strategy == \
self._projection_field.distribution_strategy:
local_projection_vector = \
self._projection_field.val.get_local_data(copy=False)
else:
# create an array that is sub-slice compatible
self.logger.warn("The input field is not sub-slice compatible to "
"the distribution strategy of the operator. "
"Performing an probably expensive "
"redistribution.")
redistr_projection_val = self._projection_field.val.copy(
distribution_strategy=axes_local_distribution_strategy)
local_projection_vector = \
redistr_projection_val.get_local_data(copy=False)
local_x = x.val.get_local_data(copy=False)
local_projection_vector = self._projection_field.val
local_x = x.val
l = len(local_projection_vector.shape)
sublist_projector = list(range(l))
......@@ -141,7 +112,7 @@ class ProjectionOperator(EndomorphicOperator):
dotted, sublist_dotted,
sublist_x)
result_field = x.copy_empty(dtype=remultiplied.dtype)
result_field.val.set_local_data(remultiplied, copy=False)
result_field.val=remultiplied
return result_field
def _inverse_times(self, x, spaces):
......
......@@ -9,8 +9,6 @@ import sys
import numpy as np
from keepers import Loggable
from ...spaces.space import Space
from ...field import Field
from ... import nifty_utilities as utilities
......@@ -24,7 +22,7 @@ if plotly is not None and 'IPython' in sys.modules:
plotly.offline.init_notebook_mode()
class PlotterBase(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object), {}))):
class PlotterBase(with_metaclass(abc.ABCMeta, type('NewBase', (object,), {}))):
def __init__(self, interactive=False, path='plot.html', title=""):
if plotly is None:
raise ImportError("The module plotly is needed but not available.")
......
......@@ -33,9 +33,8 @@ setup(name="ift_nifty",
package_dir={"nifty": "nifty"},
zip_safe=False,
include_dirs=[numpy.get_include()],
dependency_links=[
'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.7'],
install_requires=['keepers>=0.3.7'],
dependency_links=[],
install_requires=[],
license="GPLv3",
classifiers=[
"Development Status :: 4 - Beta",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment