Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ed3a2d8c
Commit
ed3a2d8c
authored
Mar 20, 2018
by
Martin Reinecke
Browse files
Merge branch 'simplify_htops' into 'NIFTy_4'
put SHTs into their own operator class See merge request ift/NIFTy!234
parents
aff5cebf
1883e5d9
Pipeline
#26302
passed with stages
in 5 minutes and 39 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/harmonic_transform_operator.py
View file @
ed3a2d8c
...
...
@@ -16,23 +16,20 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
.linear_operator
import
LinearOperator
from
..
import
dobj
from
.fft_operator
import
FFTOperator
from
.sht_operator
import
SHTOperator
from
..
import
utilities
from
..field
import
Field
from
..domains.gl_space
import
GLSpace
class
HarmonicTransformOperator
(
LinearOperator
):
"""Transforms between a harmonic domain and a position domain counterpart.
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- an LMSpace and a
LM
Space
- an LMSpace and a
HP
Space
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
...
...
@@ -57,147 +54,29 @@ class HarmonicTransformOperator(LinearOperator):
def
__init__
(
self
,
domain
,
target
=
None
,
space
=
None
):
super
(
HarmonicTransformOperator
,
self
).
__init__
()
# Initialize domain and target
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
domain
=
DomainTuple
.
make
(
domain
)
space
=
utilities
.
infer_space
(
domain
,
space
)
hspc
=
self
.
_
domain
[
self
.
_
space
]
hspc
=
domain
[
space
]
if
not
hspc
.
harmonic
:
raise
TypeError
(
"HarmonicTransformOperator only works on a harmonic space"
)
if
target
is
None
:
target
=
hspc
.
get_default_codomain
()
self
.
_target
=
[
dom
for
dom
in
self
.
_domain
]
self
.
_target
[
self
.
_space
]
=
target
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
hspc
.
check_codomain
(
target
)
target
.
check_codomain
(
hspc
)
if
isinstance
(
hspc
,
RGSpace
):
self
.
_applyfunc
=
self
.
_apply_cartesian
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
self
.
_op
=
FFTOperator
(
domain
,
target
,
space
)
else
:
from
pyHealpix
import
sharpjob_d
self
.
_applyfunc
=
self
.
_apply_spherical
self
.
lmax
=
hspc
.
lmax
self
.
mmax
=
hspc
.
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
target
,
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
target
.
nlat
,
target
.
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
target
.
nside
)
self
.
_op
=
SHTOperator
(
domain
,
target
,
space
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
np
.
issubdtype
(
x
.
dtype
,
np
.
complexfloating
):
return
(
self
.
_applyfunc
(
x
.
real
,
mode
)
+
1j
*
self
.
_applyfunc
(
x
.
imag
,
mode
))
else
:
return
self
.
_applyfunc
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
if
oldax
not
in
axes
:
# straightforward, no redistribution needed
ldat
=
x
.
local_data
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat
,
distaxis
=
oldax
)
elif
len
(
axes
)
<
len
(
x
.
shape
)
or
len
(
axes
)
==
1
:
# we can use one Hartley pass in between the redistributions
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
axes
)
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
# two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
tmp
=
x
.
val
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
if
oldax
!=
0
:
raise
ValueError
(
"bad distribution"
)
ldat2
=
ldat
.
reshape
((
ldat
.
shape
[
0
],
np
.
prod
(
ldat
.
shape
[
1
:])))
shp2d
=
(
x
.
val
.
shape
[
0
],
np
.
prod
(
x
.
val
.
shape
[
1
:]))
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
).
reshape
(
ldat
.
shape
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat2
,
distaxis
=
0
)
Tval
=
Field
(
tdom
,
tmp
)
fct
=
self
.
_domain
[
self
.
_space
].
scalar_dvol
if
fct
!=
1
:
Tval
*=
fct
return
Tval
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
if
len
(
rr
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
if
len
(
res
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
return
self
.
_op
.
apply
(
x
,
mode
)
@
property
def
domain
(
self
):
return
self
.
_domain
return
self
.
_
op
.
domain
@
property
def
target
(
self
):
return
self
.
_target
return
self
.
_
op
.
target
@
property
def
capability
(
self
):
...
...
nifty4/operators/sht_operator.py
0 → 100644
View file @
ed3a2d8c
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
.linear_operator
import
LinearOperator
from
..
import
dobj
from
..
import
utilities
from
..field
import
Field
from
..domains.gl_space
import
GLSpace
from
..domains.lm_space
import
LMSpace
class
SHTOperator
(
LinearOperator
):
"""Transforms between a harmonic domain on the sphere and a position
domain counterpart.
Built-in domain pairs are
- an LMSpace and a HPSpace
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times".
target : Domain, optional
The target domain of the transform operation.
If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional
The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a LMSpace.
"""
def
__init__
(
self
,
domain
,
target
=
None
,
space
=
None
):
super
(
SHTOperator
,
self
).
__init__
()
# Initialize domain and target
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
hspc
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
hspc
,
LMSpace
):
raise
TypeError
(
"SHTOperator only works on a LMSpace domain"
)
if
target
is
None
:
target
=
hspc
.
get_default_codomain
()
self
.
_target
=
[
dom
for
dom
in
self
.
_domain
]
self
.
_target
[
self
.
_space
]
=
target
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
hspc
.
check_codomain
(
target
)
target
.
check_codomain
(
hspc
)
from
pyHealpix
import
sharpjob_d
self
.
lmax
=
hspc
.
lmax
self
.
mmax
=
hspc
.
mmax
self
.
sjob
=
sharpjob_d
()
self
.
sjob
.
set_triangular_alm_info
(
self
.
lmax
,
self
.
mmax
)
if
isinstance
(
target
,
GLSpace
):
self
.
sjob
.
set_Gauss_geometry
(
target
.
nlat
,
target
.
nlon
)
else
:
self
.
sjob
.
set_Healpix_geometry
(
target
.
nside
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
np
.
issubdtype
(
x
.
dtype
,
np
.
complexfloating
):
return
(
self
.
_apply_spherical
(
x
.
real
,
mode
)
+
1j
*
self
.
_apply_spherical
(
x
.
imag
,
mode
))
else
:
return
self
.
_apply_spherical
(
x
,
mode
)
def
_slice_p2h
(
self
,
inp
):
rr
=
self
.
sjob
.
alm2map_adjoint
(
inp
)
if
len
(
rr
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
=
np
.
empty
(
2
*
len
(
rr
)
-
self
.
lmax
-
1
,
dtype
=
rr
[
0
].
real
.
dtype
)
res
[
0
:
self
.
lmax
+
1
]
=
rr
[
0
:
self
.
lmax
+
1
].
real
res
[
self
.
lmax
+
1
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
real
res
[
self
.
lmax
+
2
::
2
]
=
np
.
sqrt
(
2
)
*
rr
[
self
.
lmax
+
1
:].
imag
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_slice_h2p
(
self
,
inp
):
res
=
np
.
empty
((
len
(
inp
)
+
self
.
lmax
+
1
)
//
2
,
dtype
=
(
inp
[
0
]
*
1j
).
dtype
)
if
len
(
res
)
!=
((
self
.
mmax
+
1
)
*
(
self
.
mmax
+
2
))
//
2
+
\
(
self
.
mmax
+
1
)
*
(
self
.
lmax
-
self
.
mmax
):
raise
ValueError
(
"array length mismatch"
)
res
[
0
:
self
.
lmax
+
1
]
=
inp
[
0
:
self
.
lmax
+
1
]
res
[
self
.
lmax
+
1
:]
=
np
.
sqrt
(
0.5
)
*
(
inp
[
self
.
lmax
+
1
::
2
]
+
1j
*
inp
[
self
.
lmax
+
2
::
2
])
res
=
self
.
sjob
.
alm2map
(
res
)
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
tval
=
x
.
val
if
dobj
.
distaxis
(
tval
)
==
axis
:
tval
=
dobj
.
redistribute
(
tval
,
nodist
=
(
axis
,))
distaxis
=
dobj
.
distaxis
(
tval
)
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
idat
=
dobj
.
local_data
(
tval
)
odat
=
np
.
empty
(
dobj
.
local_shape
(
tdom
.
shape
,
distaxis
=
distaxis
),
dtype
=
x
.
dtype
)
for
slice
in
utilities
.
get_slice_list
(
idat
.
shape
,
axes
):
odat
[
slice
]
=
func
(
idat
[
slice
])
odat
=
dobj
.
from_local_data
(
tdom
.
shape
,
odat
,
distaxis
)
if
distaxis
!=
dobj
.
distaxis
(
x
.
val
):
odat
=
dobj
.
redistribute
(
odat
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
tdom
,
odat
)
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_target
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment