Commit e67c236f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge nifty2go

parents 87bae7e9 edca48d9
Pipeline #22305 passed with stage
in 4 minutes and 36 seconds
......@@ -104,14 +104,12 @@ if __name__ == "__main__":
flat_power = ift.Field.full(p_space, 1e-8)
m0 = ift.power_synthesize(flat_power, real_signal=True)
def ps0(k):
return (1./(1.+k)**2)
t0 = ift.Field(p_space,
val=ift.dobj.from_global_data(np.log(1./(1+p_space.k_lengths)**2)))
val=ift.dobj.from_global_data(-7.))
for i in range(500):
S0 = ift.create_power_operator(h_space, power_spectrum=ps0)
S0 = ift.create_power_operator(h_space, power_spectrum=ift.exp(t0))
# Initialize non-linear Wiener Filter energy
ICI = ift.GradientNormController(verbose=False, name="ICI",
......
......@@ -82,3 +82,6 @@ Significant differences between NIFTy nightly and nifty2go
14) A new approach is used for FFTs along axes that are distributed among
MPI tasks. As a consequence, nifty2go works well with the standard version
of pyfftw and does not need the MPI-enabled fork.
15) Arithmetic functions working on Fields have been moved from
basic_arithmetics.py to field.py.
......@@ -73,7 +73,7 @@ class CriticalPowerEnergy(Energy):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
w=self.w, samples=self.samples,
samples=self.samples, w=self.w,
inverter=self._inverter)
@property
......
......@@ -20,10 +20,10 @@ class WienerFilterCurvature(EndomorphicOperator):
"""
def __init__(self, R, N, S):
super(WienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
super(WienerFilterCurvature, self).__init__()
@property
def preconditioner(self):
......
from ..minimization.energy import Energy
from ..utilities import memo
from ..operators.inversion_enabler import InversionEnabler
from .wiener_filter_curvature import WienerFilterCurvature
......@@ -26,41 +25,31 @@ class WienerFilterEnergy(Energy):
def __init__(self, position, d, R, N, S, inverter, _j=None):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
self._curvature = InversionEnabler(WienerFilterCurvature(R, N, S),
inverter=inverter)
self._inverter = inverter
self._jpre = _j
if _j is None:
_j = self.R.adjoint_times(self.N.inverse_times(d))
self._j = _j
Dx = self._curvature(self.position)
self._value = 0.5*self.position.vdot(Dx) - self._j.vdot(self.position)
self._gradient = Dx - self._j
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, inverter=self._inverter, _j=self._jpre)
return self.__class__(position=position, d=None, R=self.R, N=self.N,
S=self.S, inverter=self._inverter, _j=self._j)
@property
@memo
def value(self):
return 0.5*self.position.vdot(self._Dx) - self._j.vdot(self.position)
return self._value
@property
@memo
def gradient(self):
return self._Dx - self._j
return self._gradient
@property
@memo
def curvature(self):
return InversionEnabler(WienerFilterCurvature(R=self.R, N=self.N,
S=self.S),
inverter=self._inverter)
@property
@memo
def _Dx(self):
return self.curvature(self.position)
@property
def _j(self):
if self._jpre is None:
self._jpre = self.R.adjoint_times(self.N.inverse_times(self.d))
return self._jpre
return self._curvature
......@@ -153,12 +153,25 @@ def _get_kw(kwname, kwdefault=None, **kwargs):
def plot(f, **kwargs):
import matplotlib.pyplot as plt
_register_cmaps()
if not isinstance(f, Field):
if isinstance(f, Field):
f = [f]
if not isinstance(f, list):
raise TypeError("incorrect data type")
if len(f.domain) != 1:
raise ValueError("input field must have exactly one domain")
for i, fld in enumerate(f):
if not isinstance(fld, Field):
raise TypeError("incorrect data type")
if i == 0:
dom = fld.domain
if len(dom) != 1:
raise ValueError("input field must have exactly one domain")
else:
if fld.domain != dom:
raise ValueError("domain mismatch")
if not (isinstance(dom[0], PowerSpace) or
(isinstance(dom[0], RGSpace) and len(dom[0].shape)==1)):
raise ValueError("PowerSpace or 1D RGSpace required")
dom = f.domain[0]
dom = dom[0]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
......@@ -174,12 +187,14 @@ def plot(f, **kwargs):
npoints = dom.shape[0]
dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist
ycoord = dobj.to_global_data(f.val)
plt.plot(xcoord, ycoord)
for fld in f:
ycoord = dobj.to_global_data(fld.val)
plt.plot(xcoord, ycoord)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
return
elif len(dom.shape) == 2:
f = f[0]
nx = dom.shape[0]
ny = dom.shape[1]
dx = dom.distances[0]
......@@ -199,16 +214,18 @@ def plot(f, **kwargs):
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, PowerSpace):
xcoord = dom.k_lengths
ycoord = dobj.to_global_data(f.val)
plt.xscale('log')
plt.yscale('log')
plt.title('power')
plt.plot(xcoord, ycoord)
xcoord = dom.k_lengths
for fld in f:
ycoord = dobj.to_global_data(fld.val)
plt.plot(xcoord, ycoord)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, HPSpace):
f = f[0]
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
......@@ -225,6 +242,7 @@ def plot(f, **kwargs):
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, GLSpace):
f = f[0]
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
......
......@@ -30,10 +30,10 @@ setup(name="nifty2go",
packages=["nifty2go"] + ["nifty2go."+p for p in find_packages("nifty")],
zip_safe=False,
dependency_links=[
'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git@setuptools_test#egg=pyHealpix-0.0.1'],
'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git#egg=pyHealpix-0.0.1'],
license="GPLv3",
setup_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
install_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
setup_requires=['future', 'numpy'],
install_requires=['future', 'numpy'],
classifiers=[
"Development Status :: 4 - Beta",
"Topic :: Utilities",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment