Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1883e5d9
Commit
1883e5d9
authored
Mar 20, 2018
by
Martin Reinecke
Browse files
put SHTs into their own operator class
parent
aff5cebf
Pipeline
#26299
passed with stage
in 5 minutes and 12 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/harmonic_transform_operator.py
View file @
1883e5d9
...
@@ -16,23 +16,20 @@
...
@@ -16,23 +16,20 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
..domains.rg_space
import
RGSpace
from
.linear_operator
import
LinearOperator
from
.linear_operator
import
LinearOperator
from
..
import
dobj
from
.fft_operator
import
FFTOperator
from
.sht_operator
import
SHTOperator
from
..
import
utilities
from
..
import
utilities
from
..field
import
Field
from
..domains.gl_space
import
GLSpace
class
HarmonicTransformOperator
(
LinearOperator
):
class
HarmonicTransformOperator
(
LinearOperator
):
"""Transforms between a harmonic domain and a position domain counterpart.
"""Transforms between a harmonic domain and a position domain counterpart.
Built-in domain pairs are
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a harmonic and a non-harmonic RGSpace (with matching distances)
- an LMSpace and a
LM
Space
- an LMSpace and a
HP
Space
- an LMSpace and a GLSpace
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
The supported operations are times() and adjoint_times().
...
@@ -57,147 +54,29 @@ class HarmonicTransformOperator(LinearOperator):
...
@@ -57,147 +54,29 @@ class HarmonicTransformOperator(LinearOperator):
def
__init__
(
self
,
domain
,
target
=
None
,
space
=
None
):
def
__init__
(
self
,
domain
,
target
=
None
,
space
=
None
):
super
(
HarmonicTransformOperator
,
self
).
__init__
()
super
(
HarmonicTransformOperator
,
self
).
__init__
()
# Initialize domain and target
domain
=
DomainTuple
.
make
(
domain
)
self
.
_domain
=
DomainTuple
.
make
(
domain
)
space
=
utilities
.
infer_space
(
domain
,
space
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
hspc
=
self
.
_
domain
[
self
.
_
space
]
hspc
=
domain
[
space
]
if
not
hspc
.
harmonic
:
if
not
hspc
.
harmonic
:
raise
TypeError
(
raise
TypeError
(
"HarmonicTransformOperator only works on a harmonic space"
)
"HarmonicTransformOperator only works on a harmonic space"
)
if
target
is
None
:
target
=
hspc
.
get_default_codomain
()
self
.
_target
=
[
dom
for
dom
in
self
.
_domain
]
self
.
_target
[
self
.
_space
]
=
target
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
hspc
.
check_codomain
(
target
)
target
.
check_codomain
(
hspc
)
if
isinstance
(
hspc
,
RGSpace
):
if
isinstance
(
hspc
,
RGSpace
):
self
.
_applyfunc
=
self
.
_apply_cartesian
self
.
_op
=
FFTOperator
(
domain
,
target
,
space
)
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
else
:
else
:
from
pyHealpix
import
sharpjob_d
self
.
_op
=
SHTOperator
(
domain
,
target
,
space
)
self
.
_applyfunc
=
self
.
_apply_spherical
self
.
lmax
=
hspc
.
lmax
self
.
mmax
=
hspc
.
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
target
,
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
target
.
nlat
,
target
.
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
target
.
nside
)
def
apply
(
self
,
x
,
mode
):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_input
(
x
,
mode
)
if
np
.
issubdtype
(
x
.
dtype
,
np
.
complexfloating
):
return
self
.
_op
.
apply
(
x
,
mode
)
return
(
self
.
_applyfunc
(
x
.
real
,
mode
)
+
1j
*
self
.
_applyfunc
(
x
.
imag
,
mode
))
else
:
return
self
.
_applyfunc
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
if
oldax
not
in
axes
:
# straightforward, no redistribution needed
ldat
=
x
.
local_data
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat
,
distaxis
=
oldax
)
elif
len
(
axes
)
<
len
(
x
.
shape
)
or
len
(
axes
)
==
1
:
# we can use one Hartley pass in between the redistributions
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
axes
)
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
# two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
tmp
=
x
.
val
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
if
oldax
!=
0
:
raise
ValueError
(
"bad distribution"
)
ldat2
=
ldat
.
reshape
((
ldat
.
shape
[
0
],
np
.
prod
(
ldat
.
shape
[
1
:])))
shp2d
=
(
x
.
val
.
shape
[
0
],
np
.
prod
(
x
.
val
.
shape
[
1
:]))
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
).
reshape
(
ldat
.
shape
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat2
,
distaxis
=
0
)
Tval
=
Field
(
tdom
,
tmp
)
fct
=
self
.
_domain
[
self
.
_space
].
scalar_dvol
if
fct
!=
1
:
Tval
*=
fct
return
Tval
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
if
len
(
rr
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
if
len
(
res
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
return
self
.
_domain
return
self
.
_
op
.
domain
@
property
@
property
def
target
(
self
):
def
target
(
self
):
return
self
.
_target
return
self
.
_
op
.
target
@
property
@
property
def
capability
(
self
):
def
capability
(
self
):
...
...
nifty4/operators/sht_operator.py
0 → 100644
View file @
1883e5d9
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
.linear_operator
import
LinearOperator
from
..
import
dobj
from
..
import
utilities
from
..field
import
Field
from
..domains.gl_space
import
GLSpace
from
..domains.lm_space
import
LMSpace
class
SHTOperator
(
LinearOperator
):
"""Transforms between a harmonic domain on the sphere and a position
domain counterpart.
Built-in domain pairs are
- an LMSpace and a HPSpace
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times".
target : Domain, optional
The target domain of the transform operation.
If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional
The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a LMSpace.
"""
def
__init__
(
self
,
domain
,
target
=
None
,
space
=
None
):
super
(
SHTOperator
,
self
).
__init__
()
# Initialize domain and target
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
hspc
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
hspc
,
LMSpace
):
raise
TypeError
(
"SHTOperator only works on a LMSpace domain"
)
if
target
is
None
:
target
=
hspc
.
get_default_codomain
()
self
.
_target
=
[
dom
for
dom
in
self
.
_domain
]
self
.
_target
[
self
.
_space
]
=
target
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
hspc
.
check_codomain
(
target
)
target
.
check_codomain
(
hspc
)
from
pyHealpix
import
sharpjob_d
self
.
lmax
=
hspc
.
lmax
self
.
mmax
=
hspc
.
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
target
,
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
target
.
nlat
,
target
.
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
target
.
nside
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
np
.
issubdtype
(
x
.
dtype
,
np
.
complexfloating
):
return
(
self
.
_apply_spherical
(
x
.
real
,
mode
)
+
1j
*
self
.
_apply_spherical
(
x
.
imag
,
mode
))
else
:
return
self
.
_apply_spherical
(
x
,
mode
)
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
if
len
(
rr
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
if
len
(
res
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_target
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment