Commit c2854663 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

experiment with std::any

parent 35d64442
......@@ -32,6 +32,7 @@
#include <cstddef>
#include <vector>
#include <memory>
#include <any>
class sharp_geom_info
{
......@@ -51,12 +52,9 @@ class sharp_geom_info
virtual double phi0(size_t iring) const = 0;
virtual Tpair pair(size_t ipair) const = 0;
virtual void clear_map(double *map) const = 0;
virtual void clear_map(float *map) const = 0;
virtual void get_ring(bool weighted, size_t iring, const double *map, double *ringtmp) const = 0;
virtual void get_ring(bool weighted, size_t iring, const float *map, double *ringtmp) const = 0;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, double *map) const = 0;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, float *map) const = 0;
virtual void clear_map(std::any map) const = 0;
virtual void get_ring(bool weighted, size_t iring, std::any map, double *ringtmp) const = 0;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, std::any map) const = 0;
};
/*! \defgroup almgroup Helpers for dealing with a_lm */
......@@ -70,12 +68,9 @@ class sharp_alm_info
virtual size_t mmax() const = 0;
virtual size_t nm() const = 0;
virtual size_t mval(size_t i) const = 0;
virtual void clear_alm(std::complex<double> *alm) const = 0;
virtual void clear_alm(std::complex<float> *alm) const = 0;
virtual void get_alm(size_t mi, const std::complex<double> *alm, std::complex<double> *almtmp, size_t nalm) const = 0;
virtual void get_alm(size_t mi, const std::complex<float> *alm, std::complex<double> *almtmp, size_t nalm) const = 0;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::complex<double> *alm, size_t nalm) const = 0;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::complex<float> *alm, size_t nalm) const = 0;
virtual void clear_alm(std::any alm) const = 0;
virtual void get_alm(size_t mi, std::any alm, std::complex<double> *almtmp, size_t nalm) const = 0;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::any alm, size_t nalm) const = 0;
};
/*! \} */
......
......@@ -54,38 +54,43 @@ sharp_standard_alm_info::sharp_standard_alm_info (size_t lmax__, size_t mmax_, p
}
}
void sharp_standard_alm_info::clear_alm (dcmplx *alm) const
template<typename T> void sharp_standard_alm_info::tclear (T *alm) const
{
for (size_t mi=0;mi<mval_.size();++mi)
for (size_t l=mval_[mi];l<=lmax_;++l)
reinterpret_cast<dcmplx *>(alm)[mvstart[mi]+l*stride]=0.;
reinterpret_cast<T *>(alm)[mvstart[mi]+l*stride]=0.;
}
void sharp_standard_alm_info::clear_alm (fcmplx *alm) const
void sharp_standard_alm_info::clear_alm(std::any alm) const
{
for (size_t mi=0;mi<mval_.size();++mi)
for (size_t l=mval_[mi];l<=lmax_;++l)
reinterpret_cast<fcmplx *>(alm)[mvstart[mi]+l*stride]=0.;
if (alm.type()==typeid(dcmplx *)) tclear(any_cast<dcmplx *>(alm));
else if (alm.type()==typeid(fcmplx *)) tclear(any_cast<fcmplx *>(alm));
else MR_fail("bad a_lm data type");
}
void sharp_standard_alm_info::get_alm(size_t mi, const dcmplx *alm, dcmplx *almtmp, size_t nalm) const
template<typename T> void sharp_standard_alm_info::tget(size_t mi, const T *alm, dcmplx *almtmp, size_t nalm) const
{
for (auto l=mval_[mi]; l<=lmax_; ++l)
almtmp[nalm*l] = alm[mvstart[mi]+l*stride];
}
void sharp_standard_alm_info::get_alm(size_t mi, const fcmplx *alm, dcmplx *almtmp, size_t nalm) const
void sharp_standard_alm_info::get_alm(size_t mi, any alm, dcmplx *almtmp, size_t nalm) const
{
for (auto l=mval_[mi]; l<=lmax_; ++l)
almtmp[nalm*l] = alm[mvstart[mi]+l*stride];
if (alm.type()==typeid(dcmplx *)) tget(mi, any_cast<dcmplx *>(alm), almtmp, nalm);
else if (alm.type()==typeid(const dcmplx *)) tget(mi, any_cast<const dcmplx *>(alm), almtmp, nalm);
else if (alm.type()==typeid(fcmplx *)) tget(mi, any_cast<fcmplx *>(alm), almtmp, nalm);
else if (alm.type()==typeid(const fcmplx *)) tget(mi, any_cast<const fcmplx *>(alm), almtmp, nalm);
else MR_fail("bad a_lm data type");
}
void sharp_standard_alm_info::add_alm(size_t mi, const dcmplx *almtmp, dcmplx *alm, size_t nalm) const
template<typename T> void sharp_standard_alm_info::tadd(size_t mi, const dcmplx *almtmp, T *alm, size_t nalm) const
{
for (auto l=mval_[mi]; l<=lmax_; ++l)
alm[mvstart[mi]+l*stride] += almtmp[nalm*l];
alm[mvstart[mi]+l*stride] += T(almtmp[nalm*l]);
}
void sharp_standard_alm_info::add_alm(size_t mi, const dcmplx *almtmp, fcmplx *alm, size_t nalm) const
void sharp_standard_alm_info::add_alm(size_t mi, const dcmplx *almtmp, any alm, size_t nalm) const
{
for (auto l=mval_[mi]; l<=lmax_; ++l)
alm[mvstart[mi]+l*stride] += fcmplx(almtmp[nalm*l]);
if (alm.type()==typeid(dcmplx *)) tadd(mi, almtmp, any_cast<dcmplx *>(alm), nalm);
else if (alm.type()==typeid(fcmplx *)) tadd(mi, almtmp, any_cast<fcmplx *>(alm), nalm);
else MR_fail("bad a_lm data type");
}
ptrdiff_t sharp_standard_alm_info::index (int l, int mi)
{
return mvstart[mi]+stride*l;
......
......@@ -46,6 +46,9 @@ class sharp_standard_alm_info: public sharp_alm_info
std::vector<ptrdiff_t> mvstart;
/*! Stride between a_lm and a_(l+1),m */
ptrdiff_t stride;
template<typename T> void tclear (T *alm) const;
template<typename T> void tget (size_t mi, const T *alm, std::complex<double> *almtmp, size_t nalm) const;
template<typename T> void tadd (size_t mi, const std::complex<double> *almtmp, T *alm, size_t nalm) const;
public:
/*! Creates an a_lm data structure from the following parameters:
......@@ -80,12 +83,9 @@ class sharp_standard_alm_info: public sharp_alm_info
virtual size_t mmax() const;
virtual size_t nm() const { return mval_.size(); }
virtual size_t mval(size_t i) const { return mval_[i]; }
virtual void clear_alm(std::complex<double> *alm) const;
virtual void clear_alm(std::complex<float> *alm) const;
virtual void get_alm(size_t mi, const std::complex<double> *alm, std::complex<double> *almtmp, size_t nalm) const;
virtual void get_alm(size_t mi, const std::complex<float> *alm, std::complex<double> *almtmp, size_t nalm) const;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::complex<double> *alm, size_t nalm) const;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::complex<float> *alm, size_t nalm) const;
virtual void clear_alm(std::any alm) const;
virtual void get_alm(size_t mi, std::any alm, std::complex<double> *almtmp, size_t nalm) const;
virtual void add_alm(size_t mi, const std::complex<double> *almtmp, std::any alm, size_t nalm) const;
};
/*! Initialises an a_lm data structure according to the scheme used by
......
......@@ -86,62 +86,56 @@ sharp_standard_geom_info::sharp_standard_geom_info(size_t nrings, const size_t *
return ring[a.r1].nph<ring[b.r1].nph;
});
}
void sharp_standard_geom_info::clear_map (double *map) const
template<typename T> bool can_cast(any val)
{ return val.type()==typeid(T); }
template<typename T> void sharp_standard_geom_info::tclear(T *map) const
{
for (const auto &r: ring)
{
if (stride==1)
memset(&map[r.ofs],0,r.nph*sizeof(double));
memset(&map[r.ofs],0,r.nph*sizeof(T));
else
for (size_t i=0;i<r.nph;++i)
map[r.ofs+i*stride]=0;
map[r.ofs+i*stride]=T(0);
}
}
void sharp_standard_geom_info::clear_map (float *map) const
void sharp_standard_geom_info::clear_map (any map) const
{
for (const auto &r: ring)
{
if (stride==1)
memset(&map[r.ofs],0,r.nph*sizeof(float));
else
for (size_t i=0;i<r.nph;++i)
map[r.ofs+i*stride]=0;
}
if (can_cast<double *>(map)) tclear(any_cast<double *>(map));
else if (can_cast<float *>(map)) tclear(any_cast<float *>(map));
else MR_fail("bad map data type");
}
//virtual
void sharp_standard_geom_info::add_ring(bool weighted, size_t iring, const double *ringtmp, double *map) const
template<typename T> void sharp_standard_geom_info::tadd(bool weighted, size_t iring, const double *ringtmp, T *map) const
{
double *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
T *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
double wgt = weighted ? ring[iring].weight : 1.;
for (size_t m=0; m<ring[iring].nph; ++m)
p1[m*stride] += ringtmp[m]*wgt;
p1[m*stride] += T(ringtmp[m]*wgt);
}
//virtual
void sharp_standard_geom_info::add_ring(bool weighted, size_t iring, const double *ringtmp, float *map) const
void sharp_standard_geom_info::add_ring(bool weighted, size_t iring, const double *ringtmp, any map) const
{
float *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
double wgt = weighted ? ring[iring].weight : 1.;
for (size_t m=0; m<ring[iring].nph; ++m)
p1[m*stride] += float(ringtmp[m]*wgt);
if (can_cast<double *>(map)) tadd(weighted, iring, ringtmp, any_cast<double *>(map));
else if (can_cast<float *>(map)) tadd(weighted, iring, ringtmp, any_cast<float *>(map));
else MR_fail("bad map data type");
}
//virtual
void sharp_standard_geom_info::get_ring(bool weighted, size_t iring, const double *map, double *ringtmp) const
template<typename T> void sharp_standard_geom_info::tget(bool weighted, size_t iring, const T *map, double *ringtmp) const
{
const double *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
const T *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
double wgt = weighted ? ring[iring].weight : 1.;
for (size_t m=0; m<ring[iring].nph; ++m)
ringtmp[m] = p1[m*stride]*wgt;
}
//virtual
void sharp_standard_geom_info::get_ring(bool weighted, size_t iring, const float *map, double *ringtmp) const
void sharp_standard_geom_info::get_ring(bool weighted, size_t iring, any map, double *ringtmp) const
{
const float *MRUTIL_RESTRICT p1=&map[ring[iring].ofs];
double wgt = weighted ? ring[iring].weight : 1.;
for (size_t m=0; m<ring[iring].nph; ++m)
ringtmp[m] = p1[m*stride]*wgt;
if (can_cast<const double *>(map)) tget(weighted, iring, any_cast<const double *>(map), ringtmp);
else if (can_cast<double *>(map)) tget(weighted, iring, any_cast<double *>(map), ringtmp);
else if (can_cast<const float *>(map)) tget(weighted, iring, any_cast<const float *>(map), ringtmp);
else if (can_cast<float *>(map)) tget(weighted, iring, any_cast<float *>(map), ringtmp);
else MR_assert(false,"bad map data type",map.type().name());
}
unique_ptr<sharp_geom_info> sharp_make_subset_healpix_geom_info (size_t nside, ptrdiff_t stride, size_t nrings,
......
......@@ -44,6 +44,9 @@ class sharp_standard_geom_info: public sharp_geom_info
std::vector<Tpair> pair_;
ptrdiff_t stride;
size_t nphmax_;
template<typename T> void tclear (T *map) const;
template<typename T> void tget (bool weighted, size_t iring, const T *map, double *ringtmp) const;
template<typename T> void tadd (bool weighted, size_t iring, const double *ringtmp, T *map) const;
public:
/*! Creates a geometry information from a set of ring descriptions.
......@@ -69,12 +72,9 @@ class sharp_standard_geom_info: public sharp_geom_info
virtual double sth(size_t iring) const { return ring[iring].sth; }
virtual double phi0(size_t iring) const { return ring[iring].phi0; }
virtual Tpair pair(size_t ipair) const { return pair_[ipair]; }
virtual void clear_map(double *map) const;
virtual void clear_map(float *map) const;
virtual void get_ring(bool weighted, size_t iring, const double *map, double *ringtmp) const;
virtual void get_ring(bool weighted, size_t iring, const float *map, double *ringtmp) const;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, double *map) const;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, float *map) const;
virtual void clear_map(std::any map) const;
virtual void get_ring(bool weighted, size_t iring, std::any map, double *ringtmp) const;
virtual void add_ring(bool weighted, size_t iring, const double *ringtmp, std::any map) const;
};
/*! Creates a geometry information describing a HEALPix map with an
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment