Commit 401aa3f6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'mt_fixes' into 'master'

Mt fixes

See merge request mtr/cxxbase!5
parents a9a85d97 c0c8afce
......@@ -44,6 +44,9 @@ kmax=13
ncomp=1
separate=True
ncomp2 = ncomp if separate else 1
epsilon = 1e-4
ofactor = 2
nthreads = 0
# get random sky a_lm
# the a_lm arrays follow the same conventions as those in healpy
......@@ -55,8 +58,9 @@ blm = random_alm(lmax, kmax, ncomp)
t0=time.time()
# build interpolator object for slm and blm
foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=1e-4, nthreads=2)
foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
print("setup time: ",time.time()-t0)
print("support:",foo.support())
nth = lmax+1
nph = 2*lmax+1
......@@ -98,7 +102,7 @@ t0=time.time()
bar=foo.interpol(ptg)
print("interpolation time: ", time.time()-t0)
fake = np.random.uniform(0.,1., (ptg.shape[0],ncomp2))
foo2 = pyinterpol_ng.PyInterpolator(lmax, kmax, ncomp2, epsilon=1e-4, nthreads=2)
foo2 = pyinterpol_ng.PyInterpolator(lmax, kmax, ncomp2, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
t0=time.time()
foo2.deinterpol(ptg.reshape((-1,3)), fake)
print("deinterpolation time: ", time.time()-t0)
......
......@@ -26,8 +26,6 @@ namespace detail_interpol_ng {
using namespace std;
constexpr double ofmin=1.5;
template<typename T> class Interpolator
{
protected:
......@@ -141,7 +139,7 @@ template<typename T> class Interpolator
public:
Interpolator(const vector<Alm<complex<T>>> &slm,
const vector<Alm<complex<T>>> &blm,
bool separate, T epsilon, int nthreads_)
bool separate, T epsilon, T ofmin, int nthreads_)
: adjoint(false),
lmax(slm.at(0).Lmax()),
kmax(blm.at(0).Mmax()),
......@@ -252,7 +250,7 @@ template<typename T> class Interpolator
}
}
Interpolator(size_t lmax_, size_t kmax_, size_t ncomp_, T epsilon, int nthreads_)
Interpolator(size_t lmax_, size_t kmax_, size_t ncomp_, T epsilon, T ofmin, int nthreads_)
: adjoint(true),
lmax(lmax_),
kmax(kmax_),
......@@ -338,6 +336,9 @@ template<typename T> class Interpolator
});
}
size_t support() const
{ return supp; }
void deinterpol (const mav<T,2> &ptg, const mav<T,2> &data)
{
MR_assert(adjoint, "can only be called in adjoint mode");
......
......@@ -45,12 +45,15 @@ void makevec_v(py::array &inp, int64_t lmax, int64_t kmax, vector<Alm<complex<T>
}
public:
PyInterpolator(const py::array &slm, const py::array &blm,
bool separate, int64_t lmax, int64_t kmax, T epsilon, int nthreads=0)
bool separate, int64_t lmax, int64_t kmax, T epsilon, T ofactor, int nthreads)
: Interpolator<T>(makevec(slm, lmax, lmax),
makevec(blm, lmax, kmax),
separate, epsilon, nthreads) {}
PyInterpolator(int64_t lmax, int64_t kmax, int64_t ncomp_, T epsilon, int nthreads=0)
: Interpolator<T>(lmax, kmax, ncomp_, epsilon, nthreads) {}
separate, epsilon, ofactor, nthreads) {}
PyInterpolator(int64_t lmax, int64_t kmax, int64_t ncomp_, T epsilon, T ofactor, int nthreads)
: Interpolator<T>(lmax, kmax, ncomp_, epsilon, ofactor, nthreads) {}
using Interpolator<T>::support;
py::array pyinterpol(const py::array &ptg) const
{
auto ptg2 = to_mav<T,2>(ptg);
......@@ -131,6 +134,11 @@ kmax : int
maximum azimuthal moment in the beam coefficients
epsilon : float
desired accuracy for the interpolation; a typical value is 1e-5
ofactor : float
oversampling factor to be used for the interpolation grids.
Should be in the range [1.2; 2], a typical value is 1.5
Increasing this factor makes (adjoint) convolution slower and
increases memory consumption, but speeds up interpolation/deinterpolation.
nthreads : the number of threads to use for computation
)""";
......@@ -148,6 +156,11 @@ ncomp : int
Can be 1 or 3.
epsilon : float
desired accuracy for the interpolation; a typical value is 1e-5
ofactor : float
oversampling factor to be used for the interpolation grids.
Should be in the range [1.2; 2], a typical value is 1.5
Increasing this factor makes (adjoint) convolution slower and
increases memory consumption, but speeds up interpolation/deinterpolation.
nthreads : the number of threads to use for computation
)""";
......@@ -230,14 +243,15 @@ PYBIND11_MODULE(pyinterpol_ng, m)
m.doc() = pyinterpol_ng_DS;
py::class_<PyInterpolator<fptype>> (m, "PyInterpolator", pyinterpolator_DS)
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, fptype, int>(),
initnormal_DS, "sky"_a, "beam"_a, "separate"_a, "lmax"_a, "kmax"_a, "epsilon"_a,
"nthreads"_a)
.def(py::init<int64_t, int64_t, int64_t, fptype, int>(), initadjoint_DS,
"lmax"_a, "kmax"_a, "ncomp"_a, "epsilon"_a, "nthreads"_a)
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, fptype, fptype, int>(),
initnormal_DS, "sky"_a, "beam"_a, "separate"_a, "lmax"_a, "kmax"_a, "epsilon"_a, "ofactor"_a=fptype(1.5),
"nthreads"_a=0)
.def(py::init<int64_t, int64_t, int64_t, fptype, fptype, int>(), initadjoint_DS,
"lmax"_a, "kmax"_a, "ncomp"_a, "epsilon"_a, "ofactor"_a=fptype(1.5), "nthreads"_a=0)
.def ("interpol", &PyInterpolator<fptype>::pyinterpol, interpol_DS, "ptg"_a)
.def ("deinterpol", &PyInterpolator<fptype>::pydeinterpol, deinterpol_DS, "ptg"_a, "data"_a)
.def ("getSlm", &PyInterpolator<fptype>::pygetSlm, getSlm_DS, "beam"_a);
.def ("getSlm", &PyInterpolator<fptype>::pygetSlm, getSlm_DS, "beam"_a)
.def ("support", &PyInterpolator<fptype>::support);
#if 1
m.def("rotate_alm", &pyrotate_alm<fptype>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
"phi"_a);
......
......@@ -113,8 +113,16 @@ template<typename T> class membuf
: ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {}
membuf(const membuf &other)
: ptr(other.ptr), d(other.d), rw(false) {}
#if defined(_MSC_VER)
// MSVC is broken
membuf(membuf &other)
: ptr(other.ptr), d(other.d), rw(other.rw) {}
membuf(membuf &&other)
: ptr(move(other.ptr)), d(move(other.d)), rw(move(other.rw)) {}
#else
membuf(membuf &other) = default;
membuf(membuf &&other) = default;
#endif
template<typename I> T &vraw(I i)
{
......@@ -151,9 +159,16 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: fmav_info(info), membuf<T>(d_, rw_) {}
fmav(const T* d_, const fmav_info &info)
: fmav_info(info), membuf<T>(d_) {}
#if defined(_MSC_VER)
// MSVC is broken
fmav(const fmav &other) : fmav_info(other), membuf<T>(other) {};
fmav(fmav &other) : fmav_info(other), membuf<T>(other) {}
fmav(fmav &&other) : fmav_info(other), membuf<T>(other) {}
#else
fmav(const fmav &other) = default;
fmav(fmav &other) = default;
fmav(fmav &&other) = default;
#endif
fmav(membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
......@@ -318,9 +333,16 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
: mav_info<ndim>(shp_), membuf<T>(d_, rw_) {}
mav(const array<size_t,ndim> &shp_)
: mav_info<ndim>(shp_), membuf<T>(size()) {}
#if defined(_MSC_VER)
// MSVC is broken
mav(const mav &other) : mav_info<ndim>(other), membuf<T>(other) {}
mav(mav &other): mav_info<ndim>(other), membuf<T>(other) {}
mav(mav &&other): mav_info<ndim>(other), membuf<T>(other) {}
#else
mav(const mav &other) = default;
mav(mav &other) = default;
mav(mav &&other) = default;
#endif
mav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &mb)
: mav_info<ndim>(shp_, str_), membuf<T>(d_, mb) {}
mav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &mb)
......
......@@ -26,6 +26,7 @@
*/
#include <cmath>
#include <algorithm>
#include <atomic>
#include <memory>
#include "mr_util/math/math_utils.h"
......
......@@ -25,6 +25,7 @@
* \author Martin Reinecke
*/
#include <algorithm>
#include <cmath>
#include <vector>
#include "mr_util/sharp/sharp_geomhelpers.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment