Commit 93d2e2d6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge master

parents 4f5692dd 05e24a79
Pipeline #21184 passed with stage
in 4 minutes and 16 seconds
...@@ -7,7 +7,6 @@ from .diagonal_operator import DiagonalOperator ...@@ -7,7 +7,6 @@ from .diagonal_operator import DiagonalOperator
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from .fft_smoothing_operator import FFTSmoothingOperator from .fft_smoothing_operator import FFTSmoothingOperator
from .direct_smoothing_operator import DirectSmoothingOperator
from .fft_operator import FFTOperator from .fft_operator import FFTOperator
......
from __future__ import division
from builtins import range
import numpy as np
from .endomorphic_operator import EndomorphicOperator
from ..spaces import PowerSpace
from .. import nifty_utilities as utilities
from .. import Field, DomainTuple
class DirectSmoothingOperator(EndomorphicOperator):
def __init__(self, domain, sigma, log_distances=False,
space=None):
super(DirectSmoothingOperator, self).__init__()
self._domain = DomainTuple.make(domain)
if space is None:
if len(self._domain.domains) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if (space<0) or space>=len(self._domain.domains):
raise ValueError("space index out of range")
if not isinstance(self._domain[space],PowerSpace):
raise TypeError("PowerSpace needed")
self._space = space
self._sigma = float(sigma)
self._log_distances = log_distances
self._effective_smoothing_width = 3.01
def _times(self, x):
if self._sigma == 0:
return x.copy()
return self._smooth(x)
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
@property
def self_adjoint(self):
return True
@property
def unitary(self):
return False
# ---Added properties and methods---
def _precompute(self, x):
""" Does precomputations for Gaussian smoothing on a 1D irregular grid.
Parameters
----------
x: 1D floating point array or list containing the individual grid
positions. Points must be given in ascending order.
Returns
-------
ibegin: integer array of the same size as x
ibegin[i] is the minimum grid index to consider when computing the
smoothed value at grid index i
nval: integer array of the same size as x
nval[i] is the number of indices to consider when computing the
smoothed value at grid index i.
wgt: list with the same number of entries as x
wgt[i] is an array with nval[i] entries containing the
normalized smoothing weights.
"""
dxmax = self._effective_smoothing_width*self._sigma
x = np.asarray(x)
ibegin = np.searchsorted(x, x-dxmax)
nval = np.searchsorted(x, x+dxmax) - ibegin
wgt = []
expfac = 1. / (2. * self._sigma*self._sigma)
for i in range(x.size):
if nval[i] > 0:
t = x[ibegin[i]:ibegin[i]+nval[i]]-x[i]
t = np.exp(-t*t*expfac)
t *= 1./np.sum(t)
wgt.append(t)
else:
wgt.append(np.array([]))
return ibegin, nval, wgt
def _smooth(self, x):
# infer affected axes
affected_axes = x.domain.axes[self._space]
axis = affected_axes[0]
distances = x.domain[self._space].k_lengths
if self._log_distances:
distances = np.log(np.maximum(distances, 1e-15))
ibegin, nval, wgt = self._precompute(distances)
res = Field.empty_like(x)
for sl in utilities.get_slice_list(x.val.shape, (axis,)):
inp = x.val[sl]
out = np.zeros(inp.shape[0], dtype=inp.dtype)
for i in range(inp.shape[0]):
out[ibegin[i]:ibegin[i]+nval[i]] += inp[i] * wgt[i][:]
res.val[sl] = out
return res
...@@ -60,13 +60,21 @@ class RGRGTransformation(Transformation): ...@@ -60,13 +60,21 @@ class RGRGTransformation(Transformation):
""" """
axes = x.domain.axes[self.space] axes = x.domain.axes[self.space]
p2h = x.domain == self.pdom p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
if dobj.dist_axis(x.val) in axes: if dobj.dist_axis(x.val) in axes:
raise NotImplementedError tmpax = (dobj.dist_axis(x.val),)
ldat = dobj.local_data(x.val) tmp = dobj.redistribute(x.val, nodist=tmpax)
if p2h: ldat = dobj.local_data(tmp)
Tval = Field(self.hdom, dobj.create_from_template(x.val,hartley(ldat, axes),dtype=x.val.dtype)) tmp = dobj.from_local_data(tmp.shape,hartley(ldat,tmpax),dist_axis=dobj.dist_axis(tmp))
tmp = dobj.redistribute(tmp, dist=tmpax[0])
tmpax = tuple (i for i in axes if i not in tmpax)
ldat = dobj.local_data(tmp)
tmp = dobj.from_local_data(tmp.shape,hartley(ldat,tmpax),dist_axis=dobj.dist_axis(tmp))
Tval = Field(tdom, tmp)
else: else:
Tval = Field(self.pdom, dobj.create_from_template(x.val,hartley(ldat, axes),dtype=x.val.dtype)) ldat = dobj.local_data(x.val)
tmp = dobj.from_local_data(x.val.shape,hartley(ldat,axes),dist_axis=dobj.dist_axis(x.val))
Tval = Field(tdom,tmp)
fct = self.fct_p2h if p2h else self.fct_h2p fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1: if fct != 1:
Tval *= fct Tval *= fct
......
...@@ -82,29 +82,3 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -82,29 +82,3 @@ class SmoothingOperator_Tests(unittest.TestCase):
mean=4, dtype=tp) mean=4, dtype=tp)
out = smo(inp) out = smo(inp)
assert_allclose(inp.sum(), out.sum(), rtol=tol, atol=tol) assert_allclose(inp.sum(), out.sum(), rtol=tol, atol=tol)
@expand(product([100, 200], [False, True], [0., 1., 3.7],
[np.float64, np.complex128]))
def test_smooth_irregular1(self, sz, log, sigma, tp):
tol = _get_rtol(tp)
sp = ift.RGSpace(sz, harmonic=True)
bb = ift.PowerSpace.useful_binbounds(sp, logarithmic=log)
ps = ift.PowerSpace(sp, binbounds=bb)
smo = ift.DirectSmoothingOperator(ps, sigma=sigma)
inp = ift.Field.from_random(domain=ps, random_type='normal', std=1,
mean=4, dtype=tp)
out = smo(inp)
assert_allclose(inp.sum(), out.sum(), rtol=tol, atol=tol)
@expand(product([10, 15], [7, 10], [False, True], [0., 1., 3.7],
[np.float64, np.complex128]))
def test_smooth_irregular2(self, sz1, sz2, log, sigma, tp):
tol = _get_rtol(tp)
sp = ift.RGSpace([sz1, sz2], harmonic=True)
bb = ift.PowerSpace.useful_binbounds(sp, logarithmic=log)
ps = ift.PowerSpace(sp, binbounds=bb)
smo = ift.DirectSmoothingOperator(ps, sigma=sigma)
inp = ift.Field.from_random(domain=ps, random_type='normal', std=1,
mean=4, dtype=tp)
out = smo(inp)
assert_allclose(inp.sum(), out.sum(), rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment