Commit 316398dc authored by Martin Reinecke

rework

parent 4af2d1d2
@@ -2343,15 +2343,14 @@ template<size_t N> class multi_iter
       }
   public:
-    multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
+    multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_,
+      size_t nshares, size_t myshare)
       : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
         str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
         idim(idim_), rem(iarr.size()/iarr.shape(idim))
       {
-      auto nshares = num_threads();
       if (nshares==1) return;
       if (nshares==0) throw runtime_error("can't run with zero threads");
-      auto myshare = thread_num();
       if (myshare>=nshares) throw runtime_error("impossible share requested");
       size_t nbase = rem/nshares;
       size_t additional = rem%nshares;
@@ -2596,11 +2595,11 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
   execParallel(
     util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
-    [&] {
+    [&](Scheduler &sched) {
     constexpr auto vlen = VLEN<T0>::val;
     auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
     const auto &tin(iax==0? in : out);
-    multi_iter<vlen> it(tin, out, axes[iax]);
+    multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
     if (vlen>1)
       while (it.remaining()>=vlen)
@@ -2703,10 +2702,10 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
   size_t len=in.shape(axis);
   execParallel(
     util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
-    [&] {
+    [&](Scheduler &sched) {
     constexpr auto vlen = VLEN<T>::val;
     auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
-    multi_iter<vlen> it(in, out, axis);
+    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
     if (vlen>1)
       while (it.remaining()>=vlen)
@@ -2758,10 +2757,10 @@ template<typename T> MRUTIL_NOINLINE void general_c2r(
   size_t len=out.shape(axis);
   execParallel(
     util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
-    [&] {
+    [&](Scheduler &sched) {
     constexpr auto vlen = VLEN<T>::val;
     auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
-    multi_iter<vlen> it(in, out, axis);
+    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
     if (vlen>1)
       while (it.remaining()>=vlen)
...
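For context: the multi_iter constructor above now receives the number of work shares and the index of the current share explicitly (the FFT call sites pass sched.num_threads() and sched.thread_num()) instead of reading thread-local state. The standalone sketch below, which is not part of the commit, illustrates the nbase/additional arithmetic from the hunk above under the common convention that the first rem % nshares shares each take one extra iteration; the rest of the constructor is not shown in the hunk, so treat the exact assignment as an assumption.

// Illustration only: splitting `rem` iterations over `nshares` workers using
// the nbase/additional values computed above. Giving the leftover iterations
// to the first shares is an assumption, not taken from the commit.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>

std::pair<std::size_t, std::size_t> share_range(std::size_t rem,
  std::size_t nshares, std::size_t myshare)
  {
  std::size_t nbase = rem/nshares;       // iterations every share gets
  std::size_t additional = rem%nshares;  // leftover iterations
  std::size_t lo = myshare*nbase + std::min(myshare, additional);
  std::size_t hi = lo + nbase + (myshare<additional ? 1 : 0);
  return {lo, hi};
  }

int main()
  {
  // 10 iterations over 4 shares -> [0,3) [3,6) [6,8) [8,10)
  for (std::size_t s=0; s<4; ++s)
    {
    auto r = share_range(10, 4, s);
    std::cout << "share " << s << ": [" << r.first << ", " << r.second << ")\n";
    }
  }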
@@ -35,7 +35,7 @@ using namespace std;
 class GL_Integrator
   {
   private:
-    int m;
+    int n_;
     vector<double> x, w;
     static inline double one_minus_x2 (double x)
@@ -43,11 +43,12 @@ class GL_Integrator
   public:
     GL_Integrator(int n, size_t nthreads=1)
+      : n_(n)
       {
       MR_assert(n>=1, "number of points must be at least 1");
       constexpr double pi = 3.141592653589793238462643383279502884197;
       constexpr double eps = 3e-14;
-      m = (n+1)>>1;
+      int m = (n+1)>>1;
       x.resize(m);
       w.resize(m);
@@ -99,7 +100,7 @@ class GL_Integrator
     template<typename Func> double integrate(Func f)
       {
       double res=0, istart=0;
-      if (x[0]==0.)
+      if (n_&1)
         {
         res = f(x[0])*w[0];
         istart=1;
@@ -113,11 +114,30 @@ class GL_Integrator
       {
       using T = decltype(f(0.));
       T res=f(x[0])*w[0];
-      if (x[0]==0.) res *= 0.5;
+      if (n_&1) res *= 0.5;
       for (size_t i=1; i<x.size(); ++i)
         res += f(x[i])*w[i];
       return res*2;
       }
+    vector<double> coords() const
+      {
+      vector<double> res(n_);
+      for (size_t i=0; i<x.size(); ++i)
+        {
+        res[i]=-x[x.size()-1-i];
+        res[n_-1-i] = x[x.size()-1-i];
+        }
+      return res;
+      }
+    vector<double> weights() const
+      {
+      vector<double> res(n_);
+      for (size_t i=0; i<w.size(); ++i)
+        res[i]=res[n_-1-i]=w[w.size()-1-i];
+      return res;
+      }
   };
 }
...
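For context: GL_Integrator stores only the nonnegative half of the Gauss-Legendre nodes and weights (m = (n+1)/2 of them), and the new coords()/weights() accessors mirror them into full arrays of n_ points. The sketch below is not part of the commit; it reproduces the same index pattern using approximate 5-point Gauss-Legendre values as stand-in data and assumes the stored nodes are ascending, which the x[0]==0 / n_&1 logic above suggests.

// Illustration only: expand half-size node/weight arrays into the full
// symmetric point set, mirroring like coords()/weights() above.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
  {
  std::size_t n = 5;                                  // hypothetical n
  std::vector<double> x = {0.0, 0.538469, 0.906180};  // nonnegative nodes (approx.)
  std::vector<double> w = {0.568889, 0.478629, 0.236927};

  std::vector<double> coords(n), weights(n);
  for (std::size_t i=0; i<x.size(); ++i)
    {
    coords[i] = -x[x.size()-1-i];      // mirrored node at the front
    coords[n-1-i] = x[x.size()-1-i];   // original node at the back
    weights[i] = weights[n-1-i] = w[w.size()-1-i];
    }

  for (std::size_t i=0; i<n; ++i)
    std::cout << coords[i] << "  " << weights[i] << '\n';
  }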
@@ -43,8 +43,6 @@ namespace detail_threading {
 using namespace std;
 #ifndef MRUTIL_NO_THREADING
-thread_local size_t thread_id = 0;
-thread_local size_t num_threads_ = 1;
 static const size_t max_threads_ = max(1u, thread::hardware_concurrency());
 class latch
@@ -174,7 +172,7 @@ class thread_pool
     }
   };
-thread_pool & get_pool()
+inline thread_pool &get_pool()
   {
   static thread_pool pool;
 #if __has_include(<pthread.h>)
@@ -192,41 +190,15 @@ thread_pool & get_pool()
   return pool;
   }
-/** Map a function f over nthreads */
-template <typename Func>
-void thread_map(size_t nthreads, Func f)
-  {
-  if (nthreads == 0)
-    nthreads = max_threads_;
-  if (nthreads == 1)
-    { f(); return; }
-  auto & pool = get_pool();
-  latch counter(nthreads);
-  exception_ptr ex;
-  mutex ex_mut;
-  for (size_t i=0; i<nthreads; ++i)
-    {
-    pool.submit(
-      [&f, &counter, &ex, &ex_mut, i, nthreads] {
-      thread_id = i;
-      num_threads_ = nthreads;
-      try { f(); }
-      catch (...)
-        {
-        lock_guard<mutex> lock(ex_mut);
-        ex = current_exception();
-        }
-      counter.count_down();
-      });
-    }
-  counter.wait();
-  if (ex)
-    rethrow_exception(ex);
-  }
+struct Range
+  {
+  size_t lo, hi;
+  Range() : lo(0), hi(0) {}
+  Range(size_t lo_, size_t hi_) : lo(lo_), hi(hi_) {}
+  operator bool() const { return hi>lo; }
+  };
-class Scheduler
+class Distribution
   {
   private:
     size_t nthreads_;
@@ -239,24 +211,19 @@ class Scheduler
     typedef enum { SINGLE, STATIC, DYNAMIC } SchedMode;
     SchedMode mode;
     bool single_done;
-    struct Range
-      {
-      size_t lo, hi;
-      Range() : lo(0), hi(0) {}
-      Range(size_t lo_, size_t hi_) : lo(lo_), hi(hi_) {}
-      operator bool() const { return hi>lo; }
-      };
+    template <typename Func> void thread_map(Func f);
   public:
     size_t nthreads() const { return nthreads_; }
-    mutex &mut() { return mut_; }
     template<typename Func> void execSingle(size_t nwork, Func f)
       {
       mode = SINGLE;
       single_done = false;
       nwork_ = nwork;
-      f(*this);
+      nthreads_ = 1;
+      thread_map(move(f));
       }
     template<typename Func> void execStatic(size_t nwork,
       size_t nthreads, size_t chunksize, Func f)
@@ -270,7 +237,7 @@ class Scheduler
       nextstart.resize(nthreads_);
       for (size_t i=0; i<nextstart.size(); ++i)
         nextstart[i] = i*chunksize_;
-      thread_map(nthreads_, [&]() {f(*this);});
+      thread_map(move(f));
       }
     template<typename Func> void execDynamic(size_t nwork,
       size_t nthreads, size_t chunksize_min, double fact_max, Func f)
@@ -283,9 +250,17 @@ class Scheduler
         return execStatic(nwork, nthreads, 0, move(f));
       fact_max_ = fact_max;
       cur_ = 0;
-      thread_map(nthreads_, [&]() {f(*this);});
+      thread_map(move(f));
       }
-    Range getNext()
+    template<typename Func> void execParallel(size_t nthreads, Func f)
+      {
+      mode = STATIC;
+      nthreads_ = (nthreads==0) ? max_threads_ : nthreads;
+      nwork_ = nthreads_;
+      chunksize_ = 1;
+      thread_map(move(f));
+      }
+    Range getNext(size_t thread_id)
       {
       switch (mode)
         {
@@ -320,38 +295,64 @@ class Scheduler
       }
   };
-template<typename Func> void execParallel(size_t nthreads, Func f)
+class Scheduler
   {
-  nthreads = (nthreads==0) ? max_threads_ : nthreads;
-  thread_map(nthreads, move(f));
+  private:
+    Distribution &dist_;
+    size_t ithread_;
+  public:
+    Scheduler(Distribution &dist, size_t ithread)
+      : dist_(dist), ithread_(ithread) {}
+    size_t num_threads() const { return dist_.nthreads(); }
+    size_t thread_num() const { return ithread_; }
+    Range getNext() { return dist_.getNext(ithread_); }
+  };
+template <typename Func> void Distribution::thread_map(Func f)
+  {
+  auto & pool = get_pool();
+  latch counter(nthreads_);
+  exception_ptr ex;
+  mutex ex_mut;
+  for (size_t i=0; i<nthreads_; ++i)
+    {
+    pool.submit(
+      [this, &f, i, &counter, &ex, &ex_mut] {
+      try
+        {
+        Scheduler sched(*this, i);
+        f(sched);
+        }
+      catch (...)
+        {
+        lock_guard<mutex> lock(ex_mut);
+        ex = current_exception();
+        }
+      counter.count_down();
+      });
+    }
+  counter.wait();
+  if (ex)
+    rethrow_exception(ex);
   }
 #else
-constexpr size_t thread_id = 0;
-constexpr size_t num_threads_ = 1;
 constexpr size_t max_threads_ = 1;
-class Scheduler
+class Sched0
   {
   private:
     size_t nwork_;
-    struct Range
-      {
-      size_t lo, hi;
-      Range() : lo(0), hi(0) {}
-      Range(size_t lo_, size_t hi_) : lo(lo_), hi(hi_) {}
-      operator bool() const { return hi>lo; }
-      };
   public:
     size_t nthreads() const { return 1; }
-    // mutex &mut() { return mut_; }
     template<typename Func> void execSingle(size_t nwork, Func f)
      {
      nwork_ = nwork;
-     f(*this);
+     f(Scheduler(*this,0));
      }
    template<typename Func> void execStatic(size_t nwork,
      size_t /*nthreads*/, size_t /*chunksize*/, Func f)
@@ -362,8 +363,9 @@ class Scheduler
       { execSingle(nwork, move(f)); }
     Range getNext()
       {
-      return Range(0, nwork_);
+      Range res(0, nwork_);
       nwork_=0;
+      return res;
       }
   };
@@ -372,36 +374,43 @@ template<typename Func> void execParallel(size_t /*nthreads*/, Func f)
 #endif
+namespace {
 template<typename Func> void execSingle(size_t nwork, Func f)
   {
-  Scheduler sched;
-  sched.execSingle(nwork, move(f));
+  Distribution dist;
+  dist.execSingle(nwork, move(f));
   }
 template<typename Func> void execStatic(size_t nwork,
   size_t nthreads, size_t chunksize, Func f)
   {
-  Scheduler sched;
-  sched.execStatic(nwork, nthreads, chunksize, move(f));
+  Distribution dist;
+  dist.execStatic(nwork, nthreads, chunksize, move(f));
   }
 template<typename Func> void execDynamic(size_t nwork,
   size_t nthreads, size_t chunksize_min, Func f)
   {
-  Scheduler sched;
-  sched.execDynamic(nwork, nthreads, chunksize_min, 0., move(f));
+  Distribution dist;
+  dist.execDynamic(nwork, nthreads, chunksize_min, 0., move(f));
   }
 template<typename Func> void execGuided(size_t nwork,
   size_t nthreads, size_t chunksize_min, double fact_max, Func f)
   {
-  Scheduler sched;
-  sched.execDynamic(nwork, nthreads, chunksize_min, fact_max, move(f));
+  Distribution dist;
+  dist.execDynamic(nwork, nthreads, chunksize_min, fact_max, move(f));
   }
-size_t num_threads()
-  { return num_threads_; }
-size_t thread_num()
-  { return thread_id; }
-size_t max_threads()
+template<typename Func> static void execParallel(size_t nthreads, Func f)
+  {
+  Distribution dist;
+  dist.execParallel(nthreads, move(f));
+  }
+inline size_t max_threads()
   { return max_threads_; }
+}
 } // end of namespace detail
 using detail_threading::Scheduler;
@@ -409,8 +418,6 @@ using detail_threading::execSingle;
 using detail_threading::execStatic;
 using detail_threading::execDynamic;
 using detail_threading::execGuided;
-using detail_threading::num_threads;
-using detail_threading::thread_num;
 using detail_threading::max_threads;
 using detail_threading::execParallel;
...
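For context: after this rework, worker lambdas receive a Scheduler& that carries the thread index and hands out work ranges, replacing the removed thread-local num_threads()/thread_num() helpers, while the Distribution object owns the work description and fans the lambda out over the pool. The sketch below is a usage illustration, not taken from the commit; the header path and enclosing namespace are assumptions and would need to be adjusted to the actual project layout.

// Illustration only: calling the reworked interface.
#include <cstddef>
#include <cstdio>
#include <vector>
#include "mr_util/threading.h"          // assumed header path
using namespace mr::detail_threading;   // assumed enclosing namespace

int main()
  {
  std::vector<double> v(1000, 1.0);

  // static distribution: each worker repeatedly pulls a [lo, hi) chunk
  execStatic(v.size(), 4, 0, [&](Scheduler &sched) {
    while (auto rng = sched.getNext())
      for (std::size_t i=rng.lo; i<rng.hi; ++i)
        v[i] *= 2.0;
    });

  // one body invocation per thread, the pattern the FFT code above relies on
  execParallel(4, [](Scheduler &sched) {
    std::printf("thread %zu of %zu\n", sched.thread_num(), sched.num_threads());
    });
  }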