Commit 7142a65b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'spectra_partial_merge' into 'NIFTy_5'

uncontroversial changes from operator_spectra

See merge request !345
parents 12c5065b e6a88836
......@@ -2,6 +2,7 @@
# custom
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2019 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty5 as ift
import numpy as np
def plot_test():
rg_space1 = ift.makeDomain(ift.RGSpace((100,)))
rg_space2 = ift.makeDomain(ift.RGSpace((80, 60), distances=1))
hp_space = ift.makeDomain(ift.HPSpace(64))
gl_space = ift.makeDomain(ift.GLSpace(128))
fft = ift.FFTOperator(rg_space2)
field_rg1_1 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg1_2 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg2 = ift.Field.from_global_data(
rg_space2, np.random.randn(80*60).reshape((80, 60)))
field_hp = ift.Field.from_global_data(hp_space, np.random.randn(12*64**2))
field_gl = ift.Field.from_global_data(gl_space, np.random.randn(32640))
field_ps = ift.power_analyze(fft.times(field_rg2))
# Start various plotting tests
plot = ift.Plot()
plot.add(field_rg1_1, title='Single plot')
plot = ift.Plot()
plot.add(field_rg2, title='2d rg')
plot.add([field_rg1_1, field_rg1_2], title='list 1d rg', label=['1', '2'])
plot.add(field_rg1_2, title='1d rg, xmin, ymin', xmin=0.5, ymin=0.,
xlabel='xmin=0.5', ylabel='ymin=0')
plot.output(title='Three plots')
plot = ift.Plot()
plot.add(field_hp, title='HP planck-color', colormap='Planck-like')
plot.add(field_rg1_2, title='1d rg')
plot.add(field_gl, title='GL')
plot.add(field_rg2, title='2d rg')
plot.output(nx=2, ny=3, title='Five plots')
if __name__ == '__main__':
......@@ -17,7 +17,7 @@ NIFTy-related publications
date = {2018-04-05},
author = {{Selig}, M. and {Bell}, M.~R. and {Junklewitz}, H. and {Oppermann}, N. and {Reinecke}, M. and {Greiner}, M. and {Pachajoa}, C. and {En{\ss}lin}, T.~A.},
title = "{NIFTY - Numerical Information Field Theory. A versatile PYTHON library for signal inference}",
journal = {\aap},
......@@ -35,7 +35,7 @@ NIFTy-related publications
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
author = {{Steininger}, T. and {Dixit}, J. and {Frank}, P. and {Greiner}, M. and {Hutschenreuter}, S. and {Knollm{\"u}ller}, J. and {Leike}, R. and {Porqueres}, N. and {Pumpe}, D. and {Reinecke}, M. and {{\v S}raml}, M. and {Varady}, C. and {En{\ss}lin}, T.},
title = "{NIFTy 3 - Numerical Information Field Theory - A Python framework for multicomponent signal inference on HPC clusters}",
journal = {ArXiv e-prints},
......@@ -70,6 +70,9 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
def _check_linearity(op, domain_dtype, atol, rtol):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
fld1 = from_random("normal", op.domain, dtype=domain_dtype)
fld2 = from_random("normal", op.domain, dtype=domain_dtype)
alpha = np.random.random() # FIXME: this can break badly with MPI!
......@@ -121,6 +124,9 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
raise TypeError('This test tests only linear operators.')
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
_check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol,
_full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
......@@ -26,7 +26,7 @@ def nthreads():
def set_nthreads(nthr):
global _nthreads
_nthreads = nthr
_nthreads = int(nthr)
def fftn(a, axes=None):
......@@ -15,12 +15,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from import RGSpace
from import UnstructuredDomain
from ..operators.linear_operator import LinearOperator
from ..sugar import from_global_data, makeDomain
import numpy as np
class GridderMaker(object):
......@@ -74,27 +74,27 @@ class ConjugateGradient(Minimizer):
if previous_gamma == 0:
return energy, controller.CONVERGED
iter = 0
ii = 0
while True:
q = energy.apply_metric(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
curv = d.vdot(q).real
if curv == 0.:
logger.error("Error: ConjugateGradient: curv==0.")
return energy, controller.ERROR
alpha = previous_gamma/ddotq
alpha = previous_gamma/curv
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR
iter += 1
if iter < self._nreset:
ii += 1
if ii < self._nreset:
r = r - q*alpha
energy = energy.at_with_grad(energy.position - alpha*d, r)
energy = - alpha*d)
r = energy.gradient
iter = 0
ii = 0
s = r if preconditioner is None else preconditioner(r)
......@@ -15,6 +15,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
......@@ -46,11 +48,10 @@ class BlockDiagonalOperator(EndomorphicOperator):
for op, v in zip(self._ops, x.values()))
return MultiField(self._domain, val)
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# dtype = MultiField.build_dtype(dtype, self._domain)
# val = tuple(op.draw_sample(from_inverse, dtype)
# for op in self._op)
# return MultiField(self._domain, val)
def draw_sample(self, from_inverse=False, dtype=np.float64):
val = tuple(op.draw_sample(from_inverse, dtype)
if op is not None else None for op in self._ops)
return MultiField(self._domain, val)
def _combine_chain(self, op):
if self._domain != op._domain:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment