Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Simon Perkins
ducc
Commits
00384bd7
Commit
00384bd7
authored
Mar 08, 2020
by
Martin Reinecke
Browse files
halfway there
parent
ae4c6c2f
Changes
3
Hide whitespace changes
Inline
Side-by-side
interpol_ng/demo.py
View file @
00384bd7
...
...
@@ -4,7 +4,7 @@ import pysharp
import
time
import
matplotlib.pyplot
as
plt
np
.
random
.
seed
(
48
)
np
.
random
.
seed
(
20
)
def
nalm
(
lmax
,
mmax
):
return
((
mmax
+
1
)
*
(
mmax
+
2
))
//
2
+
(
mmax
+
1
)
*
(
lmax
-
mmax
)
...
...
@@ -32,9 +32,9 @@ def convolve(alm1, alm2, lmax):
job
.
set_triangular_alm_info
(
0
,
0
)
return
job
.
map2alm
(
map
)[
0
]
*
np
.
sqrt
(
4
*
np
.
pi
)
lmax
=
1
0
lmax
=
4
0
mmax
=
lmax
kmax
=
lmax
kmax
=
10
# get random sky a_lm
...
...
@@ -42,15 +42,15 @@ kmax=lmax
slmT
=
random_alm
(
lmax
,
mmax
)
# build beam a_lm
blmT
=
random_alm
(
lmax
,
m
max
)
blmT
=
random_alm
(
lmax
,
k
max
)
t0
=
time
.
time
()
# build interpolator object for slmT and blmT
foo
=
interpol_ng
.
PyInterpolator
(
slmT
,
blmT
,
lmax
,
kmax
,
epsilon
=
1e-6
,
nthreads
=
2
)
foo
=
interpol_ng
.
PyInterpolator
(
slmT
,
blmT
,
lmax
,
kmax
,
epsilon
=
1e-6
,
nthreads
=
1
)
print
(
"setup time: "
,
time
.
time
()
-
t0
)
nth
=
2
*
lmax
+
1
nph
=
2
*
mmax
+
1
#exit()
ptg
=
np
.
zeros
((
nth
,
nph
,
3
))
ptg
[:,:,
0
]
=
(
np
.
pi
*
np
.
arange
(
nth
)
/
(
nth
-
1
)).
reshape
((
-
1
,
1
))
ptg
[:,:,
1
]
=
(
2
*
np
.
pi
*
np
.
arange
(
nph
)
/
nph
).
reshape
((
1
,
-
1
))
...
...
@@ -62,12 +62,24 @@ print("interpolation time: ", time.time()-t0)
plt
.
subplot
(
2
,
2
,
1
)
plt
.
imshow
(
bar
.
reshape
((
nth
,
nph
)))
bar2
=
np
.
zeros
((
nth
,
nph
))
blmTfull
=
np
.
zeros
(
slmT
.
size
)
+
0j
blmTfull
[
0
:
blmT
.
size
]
=
blmT
for
ith
in
range
(
nth
):
for
iph
in
range
(
nph
):
rbeam
=
interpol_ng
.
rotate_alm
(
blmT
,
lmax
,
ptg
[
ith
,
iph
,
2
],
ptg
[
ith
,
iph
,
0
],
ptg
[
ith
,
iph
,
1
])
rbeam
=
interpol_ng
.
rotate_alm
(
blmT
full
,
lmax
,
ptg
[
ith
,
iph
,
2
],
ptg
[
ith
,
iph
,
0
],
ptg
[
ith
,
iph
,
1
])
bar2
[
ith
,
iph
]
=
convolve
(
slmT
,
rbeam
,
lmax
).
real
plt
.
subplot
(
2
,
2
,
2
)
plt
.
imshow
(
bar2
)
plt
.
subplot
(
2
,
2
,
3
)
plt
.
imshow
((
bar2
-
bar
.
reshape
((
nth
,
nph
))))
plt
.
show
()
fake
=
np
.
random
.
uniform
(
-
1.
,
1.
,
bar
.
size
)
foo2
=
interpol_ng
.
PyInterpolator
(
lmax
,
kmax
,
epsilon
=
1e-6
,
nthreads
=
2
)
foo2
.
deinterpol
(
ptg
.
reshape
((
-
1
,
3
)),
fake
)
bla
=
foo2
.
getSlm
(
blmT
)
print
(
np
.
vdot
(
slmT
,
bla
))
slmT
[
lmax
+
1
:]
*=
np
.
sqrt
(
2
)
bla
[
lmax
+
1
:]
*=
np
.
sqrt
(
2
)
print
(
np
.
vdot
(
slmT
,
bla
))
print
(
np
.
vdot
(
fake
,
bar
))
interpol_ng/interpol_ng.cc
View file @
00384bd7
...
...
@@ -17,7 +17,7 @@
#include
"alm.h"
#include
"mr_util/math/fft.h"
#include
"mr_util/bindings/pybind_utils.h"
#include
<iostream>
using
namespace
std
;
using
namespace
mr
;
...
...
@@ -42,6 +42,7 @@ template<typename T> class Interpolator
{
double
sfct
=
(
spin
&
1
)
?
-
1
:
1
;
mav
<
T
,
2
>
tmp
({
nphi
,
nphi
});
fmav
<
T
>
ftmp
(
tmp
);
tmp
.
apply
([](
T
&
v
){
v
=
0.
;});
auto
tmp0
=
tmp
.
template
subarray
<
2
>({
0
,
0
},{
nphi0
,
nphi0
});
fmav
<
T
>
ftmp0
(
tmp0
);
...
...
@@ -49,7 +50,7 @@ template<typename T> class Interpolator
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
tmp0
.
v
(
i
,
j
)
=
arr
(
i
,
j
);
// extend to second half
for
(
size_t
i
=
1
,
i2
=
2
*
ntheta
0
-
3
;
i
+
1
<
ntheta0
;
++
i
,
--
i2
)
for
(
size_t
i
=
1
,
i2
=
nphi
0
-
1
;
i
+
1
<
ntheta0
;
++
i
,
--
i2
)
for
(
size_t
j
=
0
,
j2
=
nphi0
/
2
;
j
<
nphi0
;
++
j
,
++
j2
)
{
if
(
j2
>=
nphi0
)
j2
-=
nphi0
;
...
...
@@ -57,19 +58,19 @@ template<typename T> class Interpolator
}
// FFT to frequency domain on minimal grid
r2r_fftpack
(
ftmp0
,
ftmp0
,{
1
,
0
},
true
,
true
,
1.
/
(
nphi0
*
nphi0
),
nthreads
);
for
(
size_t
i
=
0
;
i
<
nphi0
;
++
i
)
{
tmp0
.
v
(
i
,
nphi0
-
1
)
*=
0.5
;
tmp0
.
v
(
nphi0
-
1
,
i
)
*=
0.5
;
}
auto
fct
=
kernel
.
correction_factors
(
nphi
,
nphi0
/
2
+
1
,
nthreads
);
for
(
size_t
i
=
0
;
i
<
nphi0
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
tmp0
.
v
(
i
,
j
)
*=
fct
[(
i
+
1
)
/
2
]
*
fct
[(
j
+
1
)
/
2
];
auto
tmp1
=
tmp
.
template
subarray
<
2
>({
0
,
0
},{
nphi
,
nphi0
});
fmav
<
T
>
ftmp1
(
tmp1
);
// zero-padded FFT in theta direction
r2r_fftpack
(
ftmp1
,
ftmp1
,{
0
},
false
,
false
,
1.
,
nthreads
);
auto
tmp2
=
tmp
.
template
subarray
<
2
>({
0
,
0
},{
ntheta
,
nphi
});
fmav
<
T
>
ftmp2
(
tmp2
);
fmav
<
T
>
farr
(
arr
);
// zero-padded FFT in phi direction
r2r_fftpack
(
ftmp2
,
farr
,{
1
},
false
,
false
,
1.
,
nthreads
);
r2r_fftpack
(
ftmp
,
ftmp
,{
0
,
1
},
false
,
false
,
1.
,
nthreads
);
for
(
size_t
i
=
0
;
i
<
ntheta
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi
;
++
j
)
arr
.
v
(
i
,
j
)
=
tmp
(
i
,
j
);
}
void
decorrect
(
mav
<
T
,
2
>
&
arr
,
int
spin
)
{
...
...
@@ -81,7 +82,7 @@ template<typename T> class Interpolator
for
(
size_t
j
=
0
;
j
<
nphi
;
++
j
)
tmp
.
v
(
i
,
j
)
=
arr
(
i
,
j
);
// extend to second half
for
(
size_t
i
=
1
,
i2
=
2
*
ntheta
-
3
;
i
+
1
<
ntheta
;
++
i
,
--
i2
)
for
(
size_t
i
=
1
,
i2
=
nphi
-
1
;
i
+
1
<
ntheta
;
++
i
,
--
i2
)
for
(
size_t
j
=
0
,
j2
=
nphi
/
2
;
j
<
nphi
;
++
j
,
++
j2
)
{
if
(
j2
>=
nphi
)
j2
-=
nphi
;
...
...
@@ -95,7 +96,8 @@ template<typename T> class Interpolator
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
tmp0
.
v
(
i
,
j
)
*=
fct
[(
i
+
1
)
/
2
]
*
fct
[(
j
+
1
)
/
2
];
// FFT to (theta, phi) domain on minimal grid
r2r_fftpack
(
ftmp0
,
ftmp0
,{
0
,
1
},
false
,
false
,
1.
/
(
nphi0
*
nphi0
),
nthreads
);
r2r_fftpack
(
ftmp0
,
ftmp0
,{
1
,
0
},
false
,
false
,
1.
/
(
nphi0
*
nphi0
),
nthreads
);
arr
.
apply
([](
T
&
v
){
v
=
0.
;});
for
(
size_t
i
=
0
;
i
<
ntheta0
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
arr
.
v
(
i
,
j
)
=
tmp0
(
i
,
j
);
...
...
@@ -394,6 +396,12 @@ template<typename T> class PyInterpolator: public Interpolator<T>
using
Interpolator
<
T
>::
getSlmx
;
using
Interpolator
<
T
>::
lmax
;
using
Interpolator
<
T
>::
kmax
;
using
Interpolator
<
T
>::
nphi
;
using
Interpolator
<
T
>::
ntheta
;
using
Interpolator
<
T
>::
nphi0
;
using
Interpolator
<
T
>::
ntheta0
;
using
Interpolator
<
T
>::
correct
;
using
Interpolator
<
T
>::
decorrect
;
py
::
array
interpol
(
const
py
::
array
&
ptg
)
const
{
auto
ptg2
=
to_mav
<
T
,
2
>
(
ptg
);
...
...
@@ -419,6 +427,39 @@ slmT_.apply([](complex<T> &v){v=0;});
getSlmx
(
blmT
,
slmT
);
return
res
;
}
py
::
array
test_correct
(
const
py
::
array
&
in
,
int
spin
)
{
auto
in2
=
to_mav
<
T
,
2
>
(
in
);
MR_assert
(
in2
.
conformable
({
ntheta0
,
nphi0
}),
"bad input shape"
);
auto
res
=
make_Pyarr
<
T
>
({
ntheta
,
nphi
});
auto
res2
=
to_mav
<
T
,
2
>
(
res
,
true
);
res2
.
apply
([](
T
&
v
){
v
=
0
;});
for
(
size_t
i
=
0
;
i
<
ntheta0
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
res2
.
v
(
i
,
j
)
=
in2
(
i
,
j
);
correct
(
res2
,
spin
);
return
res
;
}
py
::
array
test_decorrect
(
const
py
::
array
&
in
,
int
spin
)
{
auto
in2
=
to_mav
<
T
,
2
>
(
in
);
MR_assert
(
in2
.
conformable
({
ntheta
,
nphi
}),
"bad input shape"
);
auto
tmp
=
mav
<
T
,
2
>
({
ntheta
,
nphi
});
for
(
size_t
i
=
0
;
i
<
ntheta
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi
;
++
j
)
tmp
.
v
(
i
,
j
)
=
in2
(
i
,
j
);
decorrect
(
tmp
,
spin
);
auto
res
=
make_Pyarr
<
T
>
({
ntheta0
,
nphi0
});
auto
res2
=
to_mav
<
T
,
2
>
(
res
,
true
);
for
(
size_t
i
=
0
;
i
<
ntheta0
;
++
i
)
for
(
size_t
j
=
0
;
j
<
nphi0
;
++
j
)
res2
.
v
(
i
,
j
)
=
tmp
(
i
,
j
);
return
res
;
}
int
Nphi0
()
const
{
return
nphi0
;
}
int
Ntheta0
()
const
{
return
ntheta0
;
}
int
Nphi
()
const
{
return
nphi
;
}
int
Ntheta
()
const
{
return
ntheta
;
}
};
#if 1
...
...
@@ -448,7 +489,13 @@ PYBIND11_MODULE(interpol_ng, m)
"lmax"
_a
,
"kmax"
_a
,
"epsilon"
_a
,
"nthreads"
_a
)
.
def
(
"interpol"
,
&
PyInterpolator
<
double
>::
interpol
,
"ptg"
_a
)
.
def
(
"deinterpol"
,
&
PyInterpolator
<
double
>::
deinterpol
,
"ptg"
_a
,
"data"
_a
)
.
def
(
"getSlm"
,
&
PyInterpolator
<
double
>::
getSlm
,
"blmT"
_a
);
.
def
(
"getSlm"
,
&
PyInterpolator
<
double
>::
getSlm
,
"blmT"
_a
)
.
def
(
"test_correct"
,
&
PyInterpolator
<
double
>::
test_correct
,
"in"
_a
,
"spin"
_a
)
.
def
(
"test_decorrect"
,
&
PyInterpolator
<
double
>::
test_decorrect
,
"in"
_a
,
"spin"
_a
)
.
def
(
"Nphi"
,
&
PyInterpolator
<
double
>::
Nphi
)
.
def
(
"Ntheta"
,
&
PyInterpolator
<
double
>::
Ntheta
)
.
def
(
"Nphi0"
,
&
PyInterpolator
<
double
>::
Nphi0
)
.
def
(
"Ntheta0"
,
&
PyInterpolator
<
double
>::
Ntheta0
);
#if 1
m
.
def
(
"rotate_alm"
,
&
pyrotate_alm
<
double
>
,
"alm"
_a
,
"lmax"
_a
,
"psi"
_a
,
"theta"
_a
,
"phi"
_a
);
...
...
src/mr_util/infra/mav.h
View file @
00384bd7
...
...
@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info
}
bool
conformable
(
const
mav_info
&
other
)
const
{
return
shp
==
other
.
shp
;
}
bool
conformable
(
const
shape_t
&
other
)
const
{
return
shp
==
other
;
}
template
<
typename
...
Ns
>
ptrdiff_t
idx
(
Ns
...
ns
)
const
{
static_assert
(
ndim
==
sizeof
...(
ns
),
"incorrect number of indices"
);
...
...
@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using
membuf
<
T
>::
ptr
;
using
mav_info
<
ndim
>::
shp
;
using
mav_info
<
ndim
>::
str
;
using
mav_info
<
ndim
>::
conformable
;
using
membuf
<
T
>::
rw
;
using
membuf
<
T
>::
vraw
;
...
...
@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using
mav_info
<
ndim
>::
contiguous
;
using
mav_info
<
ndim
>::
size
;
using
mav_info
<
ndim
>::
idx
;
using
mav_info
<
ndim
>::
conformable
;
mav
(
const
T
*
d_
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
)
:
mav_info
<
ndim
>
(
shp_
,
str_
),
membuf
<
T
>
(
d_
)
{}
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment