Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
e8f33125
Commit
e8f33125
authored
Jan 15, 2018
by
Martin Reinecke
Browse files
tweak FFT operator
parent
922878ed
Pipeline
#23731
failed with stage
in 4 minutes and 4 seconds
Changes
7
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter_via_curvature.py
View file @
e8f33125
...
...
@@ -73,6 +73,7 @@ if __name__ == "__main__":
j
=
R_harmonic
.
adjoint_times
(
N
.
inverse_times
(
data
))
print
"xx"
,
j
.
val
[
0
]
*
nu
.
K
*
(
nu
.
m
**
dimensionality
)
exit
()
ctrl
=
ift
.
GradientNormController
(
verbose
=
True
,
tol_abs_gradnorm
=
1e-40
/
(
nu
.
K
*
(
nu
.
m
**
dimensionality
)))
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
...
...
nifty/operators/fft_operator.py
View file @
e8f33125
...
...
@@ -21,18 +21,26 @@ from .. import DomainTuple
from
..spaces
import
RGSpace
from
..utilities
import
infer_space
from
.linear_operator
import
LinearOperator
from
.fft_operator_support
import
RGRGTransformation
,
SphericalTransformation
from
..
import
dobj
from
..
import
utilities
from
..field
import
Field
from
..spaces.gl_space
import
GLSpace
class
FFTOperator
(
LinearOperator
):
"""Transforms between a pair of
harmonic and position
domains.
"""Transforms between a pair of
position and harmonic
domains.
Built-in domain pairs are
- harmonic RGSpace / nonharmonic RGSpace (with matching distances)
- LMSpace / HPSpace
- LMSpace / GLSpace
The times() operation always transforms from the harmonic to the
position domain.
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace
- a GLSpace and a LMSpace
Within a domain pair, both orderings are possible.
For RGSpaces, the operator provides the full set of operations.
For the sphere-related domains, it only supports the transform from
harmonic to position space and its adjoint; if the operator domain is
harmonic, this will be times() and adjoint_times(), otherwise
inverse_times() and adjoint_inverse_times()
Parameters
----------
...
...
@@ -58,33 +66,158 @@ class FFTOperator(LinearOperator):
# Initialize domain and target
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
infer_space
(
self
.
_domain
,
space
)
if
not
self
.
_domain
[
self
.
_space
].
harmonic
:
raise
TypeError
(
"H2POperator must work on a harmonic domain"
)
adom
=
self
.
domain
[
self
.
_space
]
adom
=
self
.
_
domain
[
self
.
_space
]
if
target
is
None
:
target
=
adom
.
get_default_codomain
()
self
.
_target
=
[
dom
for
dom
in
self
.
domain
]
self
.
_target
=
[
dom
for
dom
in
self
.
_
domain
]
self
.
_target
[
self
.
_space
]
=
target
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
hdom
,
pdom
=
(
self
.
_domain
,
self
.
_target
)
if
isinstance
(
pdom
[
self
.
_space
],
RGSpace
):
self
.
_trafo
=
RGRGTransformation
(
hdom
,
pdom
,
self
.
_space
)
if
isinstance
(
adom
,
RGSpace
):
self
.
_applyfunc
=
self
.
_apply_cartesian
self
.
_capability
=
self
.
_all_ops
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
else
:
self
.
_trafo
=
SphericalTransformation
(
hdom
,
pdom
,
self
.
_space
)
from
pyHealpix
import
sharpjob_d
self
.
_applyfunc
=
self
.
_apply_spherical
hspc
=
adom
if
adom
.
harmonic
else
target
pspc
=
target
if
adom
.
harmonic
else
adom
self
.
lmax
=
hspc
.
lmax
self
.
mmax
=
hspc
.
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
pspc
,
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
pspc
.
nlat
,
pspc
.
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
pspc
.
nside
)
if
adom
.
harmonic
:
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
else
:
self
.
_capability
=
(
self
.
INVERSE_TIMES
|
self
.
INVERSE_ADJOINT_TIMES
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
np
.
issubdtype
(
x
.
dtype
,
np
.
complexfloating
):
re
s
=
(
self
.
_
trafo
.
apply
(
x
.
real
,
mode
)
+
1j
*
self
.
_
trafo
.
apply
(
x
.
imag
,
mode
))
re
turn
(
self
.
_apply
func
(
x
.
real
,
mode
)
+
1j
*
self
.
_apply
func
(
x
.
imag
,
mode
))
else
:
res
=
self
.
_trafo
.
apply
(
x
,
mode
)
return
res
return
self
.
_applyfunc
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
"""
RG -> RG transform method.
Parameters
----------
x : Field
The field to be transformed
"""
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
if
oldax
not
in
axes
:
# straightforward, no redistribution needed
ldat
=
dobj
.
local_data
(
x
.
val
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat
,
distaxis
=
oldax
)
elif
len
(
axes
)
<
len
(
x
.
shape
)
or
len
(
axes
)
==
1
:
# we can use one Hartley pass in between the redistributions
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
axes
)
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
# two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
tmp
=
x
.
val
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
if
oldax
!=
0
:
raise
ValueError
(
"bad distribution"
)
ldat2
=
ldat
.
reshape
((
ldat
.
shape
[
0
],
np
.
prod
(
ldat
.
shape
[
1
:])))
shp2d
=
(
x
.
val
.
shape
[
0
],
np
.
prod
(
x
.
val
.
shape
[
1
:]))
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
).
reshape
(
ldat
.
shape
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat2
,
distaxis
=
0
)
Tval
=
Field
(
tdom
,
tmp
)
if
x
.
domain
[
self
.
_space
].
harmonic
:
if
(
mode
==
LinearOperator
.
TIMES
or
mode
==
LinearOperator
.
ADJOINT_TIMES
):
fct
=
self
.
_domain
[
self
.
_space
].
scalar_dvol
()
else
:
fct
=
1.
/
(
self
.
_domain
[
self
.
_space
].
scalar_dvol
()
*
self
.
_domain
[
self
.
_space
].
dim
)
else
:
if
(
mode
==
LinearOperator
.
TIMES
or
mode
==
LinearOperator
.
ADJOINT_TIMES
):
fct
=
1.
/
(
self
.
_target
[
self
.
_space
].
scalar_dvol
()
*
self
.
_target
[
self
.
_space
].
dim
)
else
:
fct
=
self
.
_target
[
self
.
_space
].
scalar_dvol
()
if
fct
!=
1
:
Tval
*=
fct
return
Tval
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
assert
len
(
rr
)
==
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
assert
len
(
res
)
==
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
@
property
def
domain
(
self
):
...
...
@@ -96,7 +229,4 @@ class FFTOperator(LinearOperator):
@
property
def
capability
(
self
):
res
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
if
self
.
_trafo
.
unitary
:
res
|=
self
.
INVERSE_TIMES
|
self
.
ADJOINT_INVERSE_TIMES
return
res
return
self
.
_capability
nifty/operators/fft_operator_support.py
deleted
100644 → 0
View file @
922878ed
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
numpy
as
np
from
..
import
utilities
from
..
import
dobj
from
..field
import
Field
from
..spaces.gl_space
import
GLSpace
from
.linear_operator
import
LinearOperator
class
Transformation
(
object
):
def
__init__
(
self
,
hdom
,
pdom
,
space
):
self
.
hdom
=
hdom
self
.
pdom
=
pdom
self
.
space
=
space
class
RGRGTransformation
(
Transformation
):
def
__init__
(
self
,
hdom
,
pdom
,
space
):
import
pyfftw
super
(
RGRGTransformation
,
self
).
__init__
(
hdom
,
pdom
,
space
)
pyfftw
.
interfaces
.
cache
.
enable
()
self
.
fct_noninverse
=
hdom
[
space
].
scalar_dvol
()
self
.
fct_inverse
=
1.
/
(
hdom
[
space
].
scalar_dvol
()
*
hdom
[
space
].
dim
)
@
property
def
unitary
(
self
):
return
True
def
apply
(
self
,
x
,
mode
):
"""
RG -> RG transform method.
Parameters
----------
x : Field
The field to be transformed
"""
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
space
]
p2h
=
x
.
domain
==
self
.
pdom
tdom
=
self
.
hdom
if
p2h
else
self
.
pdom
oldax
=
dobj
.
distaxis
(
x
.
val
)
if
oldax
not
in
axes
:
# straightforward, no redistribution needed
ldat
=
dobj
.
local_data
(
x
.
val
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat
,
distaxis
=
oldax
)
elif
len
(
axes
)
<
len
(
x
.
shape
)
or
len
(
axes
)
==
1
:
# we can use one Hartley pass in between the redistributions
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
axes
)
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
# two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
if
True
:
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
tmp
=
x
.
val
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
# new, experimental code
if
True
:
if
oldax
!=
0
:
raise
ValueError
(
"bad distribution"
)
ldat2
=
ldat
.
reshape
((
ldat
.
shape
[
0
],
np
.
prod
(
ldat
.
shape
[
1
:])))
shp2d
=
(
x
.
val
.
shape
[
0
],
np
.
prod
(
x
.
val
.
shape
[
1
:]))
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
).
reshape
(
ldat
.
shape
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat2
,
distaxis
=
0
)
else
:
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
oldax
)
tmp
=
dobj
.
redistribute
(
tmp
,
nodist
=
(
oldax
,))
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
fftn
(
ldat
,
axes
=
(
oldax
,))
ldat
=
ldat
.
real
+
ldat
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
(
oldax
,))
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
fftn
(
ldat
,
axes
=
(
oldax
,))
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
fftn
(
ldat
,
axes
=
rem_axes
)
ldat
=
ldat
.
real
+
ldat
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
oldax
)
Tval
=
Field
(
tdom
,
tmp
)
if
(
mode
==
LinearOperator
.
TIMES
or
mode
==
LinearOperator
.
ADJOINT_TIMES
):
fct
=
self
.
fct_noninverse
else
:
fct
=
self
.
fct_inverse
if
fct
!=
1
:
Tval
*=
fct
return
Tval
class
SphericalTransformation
(
Transformation
):
def
__init__
(
self
,
hdom
,
pdom
,
space
):
super
(
SphericalTransformation
,
self
).
__init__
(
hdom
,
pdom
,
space
)
from
pyHealpix
import
sharpjob_d
self
.
lmax
=
self
.
hdom
[
self
.
space
].
lmax
self
.
mmax
=
self
.
hdom
[
self
.
space
].
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
self
.
pdom
[
self
.
space
],
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
self
.
pdom
[
self
.
space
].
nlat
,
self
.
pdom
[
self
.
space
].
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
self
.
pdom
[
self
.
space
].
nside
)
@
property
def
unitary
(
self
):
return
False
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
assert
len
(
rr
)
==
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
assert
len
(
res
)
==
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
apply
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
x
.
domain
==
self
.
pdom
tdom
=
self
.
hdom
if
p2h
else
self
.
pdom
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
nifty/operators/fft_smoothing_operator.py
View file @
e8f33125
...
...
@@ -16,13 +16,12 @@ def FFTSmoothingOperator(domain, sigma, space=None):
space
=
infer_space
(
domain
,
space
)
if
domain
[
space
].
harmonic
:
raise
TypeError
(
"domain must not be harmonic"
)
fftdom
=
list
(
domain
)
codomain
=
domain
[
space
].
get_default_codomain
()
fftdom
[
space
]
=
codomain
fftdom
=
DomainTuple
.
make
(
fftdom
)
FFT
=
FFTOperator
(
fftdom
,
domain
[
space
],
space
=
space
)
FFT
=
FFTOperator
(
domain
,
space
=
space
)
codomain
=
FFT
.
domain
[
space
].
get_default_codomain
()
kernel
=
codomain
.
get_k_length_array
()
smoother
=
codomain
.
get_fft_smoothing_kernel_function
(
sigma
)
kernel
=
smoother
(
kernel
)
diag
=
DiagonalOperator
(
kernel
,
fftdom
,
space
)
return
FFT
*
diag
*
FFT
.
inverse
ddom
=
list
(
domain
)
ddom
[
space
]
=
codomain
diag
=
DiagonalOperator
(
kernel
,
ddom
,
space
)
return
FFT
.
inverse
*
diag
*
FFT
test/test_operators/test_adjoint.py
View file @
e8f33125
...
...
@@ -39,7 +39,5 @@ class Adjointness_Tests(unittest.TestCase):
@
expand
(
product
(
_harmonic_spaces
+
_position_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testFFT
(
self
,
sp
,
dtype
):
if
not
sp
.
harmonic
:
sp
=
sp
.
get_default_codomain
()
op
=
ift
.
FFTOperator
(
sp
)
_check_adjointness
(
op
,
dtype
)
test/test_operators/test_fft_operator.py
View file @
e8f33125
...
...
@@ -37,8 +37,8 @@ class FFTOperatorTests(unittest.TestCase):
[
np
.
float64
,
np
.
float32
,
np
.
complex64
,
np
.
complex128
]))
def
test_fft1D
(
self
,
dim1
,
d
,
itp
):
tol
=
_get_rtol
(
itp
)
b
=
ift
.
RGSpace
(
dim1
,
distances
=
d
)
a
=
ift
.
RGSpace
(
dim1
,
distances
=
1.
/
(
dim1
*
d
),
harmonic
=
True
)
a
=
ift
.
RGSpace
(
dim1
,
distances
=
d
)
b
=
ift
.
RGSpace
(
dim1
,
distances
=
1.
/
(
dim1
*
d
),
harmonic
=
True
)
fft
=
ift
.
FFTOperator
(
domain
=
a
,
target
=
b
)
np
.
random
.
seed
(
16
)
...
...
@@ -53,8 +53,8 @@ class FFTOperatorTests(unittest.TestCase):
[
np
.
float64
,
np
.
float32
,
np
.
complex64
,
np
.
complex128
]))
def
test_fft2D
(
self
,
dim1
,
dim2
,
d1
,
d2
,
itp
):
tol
=
_get_rtol
(
itp
)
b
=
ift
.
RGSpace
([
dim1
,
dim2
],
distances
=
[
d1
,
d2
])
a
=
ift
.
RGSpace
([
dim1
,
dim2
],
a
=
ift
.
RGSpace
([
dim1
,
dim2
],
distances
=
[
d1
,
d2
])
b
=
ift
.
RGSpace
([
dim1
,
dim2
],
distances
=
[
1.
/
(
dim1
*
d1
),
1.
/
(
dim2
*
d2
)],
harmonic
=
True
)
fft
=
ift
.
FFTOperator
(
domain
=
a
,
target
=
b
)
...
...
@@ -78,8 +78,8 @@ class FFTOperatorTests(unittest.TestCase):
assert_allclose
(
ift
.
dobj
.
to_global_data
(
inp
.
val
),
ift
.
dobj
.
to_global_data
(
out
.
val
),
rtol
=
tol
,
atol
=
tol
)
@
expand
(
product
([
0
,
3
,
6
,
11
,
30
],
[
np
.
float64
,
np
.
float32
,
np
.
complex64
,
np
.
complex128
]))
#
@expand(product([0, 3, 6, 11, 30],
#
[np.float64, np.float32, np.complex64, np.complex128]))
#def test_sht(self, lm, tp):
# tol = _get_rtol(tp)
# a = ift.LMSpace(lmax=lm)
...
...
@@ -130,3 +130,15 @@ class FFTOperatorTests(unittest.TestCase):
v1
=
np
.
sqrt
(
out
.
vdot
(
out
))
v2
=
np
.
sqrt
(
inp
.
vdot
(
fft
.
adjoint_times
(
out
)))
assert_allclose
(
v1
,
v2
,
rtol
=
tol
,
atol
=
tol
)
@
expand
(
product
([
ift
.
RGSpace
(
128
,
distances
=
3.76
,
harmonic
=
True
),
ift
.
LMSpace
(
lmax
=
30
,
mmax
=
25
)],
[
np
.
float64
,
np
.
float32
,
np
.
complex64
,
np
.
complex128
]))
def
test_normalisation
(
self
,
space
,
tp
):
tol
=
10
*
_get_rtol
(
tp
)
fft
=
ift
.
FFTOperator
(
space
)
inp
=
ift
.
Field
.
from_random
(
domain
=
space
,
random_type
=
'normal'
,
std
=
1
,
mean
=
2
,
dtype
=
tp
)
out
=
fft
.
times
(
inp
)
assert_allclose
(
ift
.
dobj
.
to_global_data
(
inp
.
val
)[
0
],
out
.
integrate
(),
rtol
=
tol
,
atol
=
tol
)
test/test_operators/test_response_operator.py
View file @
e8f33125
...
...
@@ -9,18 +9,18 @@ class ResponseOperator_Tests(unittest.TestCase):
spaces
=
[
ift
.
RGSpace
(
128
),
ift
.
GLSpace
(
nlat
=
37
)]
@
expand
(
product
(
spaces
,
[
0.
,
5.
,
1.
],
[
0.
,
1.
,
.
33
]))
def
test_property
(
self
,
space
,
sigma
,
exposure
):
def
test_property
(
self
,
space
,
sigma
,
sensitivity
):
op
=
ift
.
ResponseOperator
(
space
,
sigma
=
[
sigma
],
exposure
=
[
exposure
])
sensitivity
=
[
sensitivity
])
if
op
.
domain
[
0
]
!=
space
:
raise
TypeError
@
expand
(
product
(
spaces
,
[
0.
,
5.
,
1.
],
[
0.
,
1.
,
.
33
]))
def
test_times_adjoint_times
(
self
,
space
,
sigma
,
exposure
):
def
test_times_adjoint_times
(
self
,
space
,
sigma
,
sensitivity
):
if
not
isinstance
(
space
,
ift
.
RGSpace
):
# no smoothing supported
sigma
=
0.
op
=
ift
.
ResponseOperator
(
space
,
sigma
=
[
sigma
],
exposure
=
[
exposure
])
sensitivity
=
[
sensitivity
])
rand1
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
space
)
rand2
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
op
.
target
[
0
])
tt1
=
rand2
.
vdot
(
op
.
times
(
rand1
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment