Commit 6b4f68f3 authored by Martin Reinecke
Browse files

experiment with OpenMP locks

parent 2812ffad
...@@ -29,6 +29,10 @@ ...@@ -29,6 +29,10 @@
#include <vector> #include <vector>
#include <array> #include <array>
#ifdef _OPENMP
#include <omp.h>
#endif
#include "pocketfft_hdronly.h" #include "pocketfft_hdronly.h"
#if defined(__GNUC__) #if defined(__GNUC__)
...@@ -715,6 +719,32 @@ class GridderConfig ...@@ -715,6 +719,32 @@ class GridderConfig
constexpr int logsquare=4; constexpr int logsquare=4;
// Lock: a minimal mutual-exclusion primitive with lock()/unlock() members.
// Two interchangeable definitions are provided so that calling code does not
// need its own #ifdef: a real one backed by an OpenMP simple lock, and a
// no-op stand-in for builds where OpenMP is not enabled.
#ifdef _OPENMP
class Lock
  {
  private:
    omp_lock_t lck;

    // The underlying omp_lock_t is initialised in the constructor and
    // destroyed in the destructor of the same object, so copies are forbidden.
    Lock(const Lock &) = delete;
    Lock& operator=(const Lock &) = delete;

  public:
    // Initialise the underlying OpenMP lock.
    Lock() { omp_init_lock(&lck); }
    // Release the resources of the underlying OpenMP lock.
    ~Lock() { omp_destroy_lock(&lck); }
    // Acquire the lock (forwards to omp_set_lock).
    void lock() { omp_set_lock(&lck); }
    // Release the lock (forwards to omp_unset_lock).
    void unlock() { omp_unset_lock(&lck); }
  };
#else
// Fallback used when OpenMP is not enabled: every operation does nothing.
class Lock
  {
  public:
    Lock() {}
    ~Lock() {}
    void lock() {}
    // The parameter list was missing here ("void unlock {}"), which made
    // every build without OpenMP fail to compile.
    void unlock() {}
  };
#endif
template<typename T, typename T2=complex<T>> class Helper template<typename T, typename T2=complex<T>> class Helper
{ {
private: private:
...@@ -732,6 +762,7 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -732,6 +762,7 @@ template<typename T, typename T2=complex<T>> class Helper
T w0, xdw; T w0, xdw;
size_t nexp; size_t nexp;
size_t nvecs; size_t nvecs;
vector<Lock> &locks;
void dump() const void dump() const
{ {
...@@ -744,11 +775,13 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -744,11 +775,13 @@ template<typename T, typename T2=complex<T>> class Helper
for (int iu=0; iu<su; ++iu) for (int iu=0; iu<su; ++iu)
{ {
int idxv = idxv0; int idxv = idxv0;
locks[idxu].lock();
for (int iv=0; iv<sv; ++iv) for (int iv=0; iv<sv; ++iv)
{ {
grid_w[idxu*nv + idxv] += wbuf[iu*sv + iv]; grid_w[idxu*nv + idxv] += wbuf[iu*sv + iv];
if (++idxv>=nv) idxv=0; if (++idxv>=nv) idxv=0;
} }
locks[idxu].unlock();
if (++idxu>=nu) idxu=0; if (++idxu>=nu) idxu=0;
} }
} }
...@@ -776,7 +809,7 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -776,7 +809,7 @@ template<typename T, typename T2=complex<T>> class Helper
T kernel[64] ALIGNED(64); T kernel[64] ALIGNED(64);
Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_, Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_,
T w0_=-1, T dw_=-1) vector<Lock> &locks_, T w0_=-1, T dw_=-1)
: gconf(gconf_), nu(gconf.Nu()), nv(gconf.Nv()), nsafe(gconf.Nsafe()), : gconf(gconf_), nu(gconf.Nu()), nv(gconf.Nv()), nsafe(gconf.Nsafe()),
supp(gconf.Supp()), beta(gconf.Beta()), grid_r(grid_r_), supp(gconf.Supp()), beta(gconf.Beta()), grid_r(grid_r_),
grid_w(grid_w_), su(2*nsafe+(1<<logsquare)), sv(2*nsafe+(1<<logsquare)), grid_w(grid_w_), su(2*nsafe+(1<<logsquare)), sv(2*nsafe+(1<<logsquare)),
...@@ -787,7 +820,8 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -787,7 +820,8 @@ template<typename T, typename T2=complex<T>> class Helper
w0(w0_), w0(w0_),
xdw(T(1)/dw_), xdw(T(1)/dw_),
nexp(2*supp + do_w_gridding), nexp(2*supp + do_w_gridding),
nvecs(VLEN<T>::val*((nexp+VLEN<T>::val-1)/VLEN<T>::val)) nvecs(VLEN<T>::val*((nexp+VLEN<T>::val-1)/VLEN<T>::val)),
locks(locks_)
{} {}
~Helper() { if (grid_w) dump(); } ~Helper() { if (grid_w) dump(); }
...@@ -921,10 +955,11 @@ template<typename T, typename Serv> void x2grid_c ...@@ -921,10 +955,11 @@ template<typename T, typename Serv> void x2grid_c
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
bool do_w_gridding = dw>0; bool do_w_gridding = dw>0;
vector<Lock> locks(gconf.Nu());
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
{ {
Helper<T> hlp(gconf, nullptr, grid.data(), w0, dw); Helper<T> hlp(gconf, nullptr, grid.data(), locks, w0, dw);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * RESTRICT ku = hlp.kernel; const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp; const T * RESTRICT kv = hlp.kernel+supp;
...@@ -981,11 +1016,12 @@ template<typename T, typename Serv> void grid2x_c ...@@ -981,11 +1016,12 @@ template<typename T, typename Serv> void grid2x_c
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
bool do_w_gridding = dw>0; bool do_w_gridding = dw>0;
vector<Lock> locks(gconf.Nu());
// Loop over sampling points // Loop over sampling points
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
{ {
Helper<T> hlp(gconf, grid.data(), nullptr, w0, dw); Helper<T> hlp(gconf, grid.data(), nullptr, locks, w0, dw);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * RESTRICT ku = hlp.kernel; const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp; const T * RESTRICT kv = hlp.kernel+supp;
...@@ -1044,11 +1080,12 @@ template<typename T> void apply_holo ...@@ -1044,11 +1080,12 @@ template<typename T> void apply_holo
checkShape(ogrid.shape(), grid.shape()); checkShape(ogrid.shape(), grid.shape());
ogrid.fill(0); ogrid.fill(0);
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
vector<Lock> locks(gconf.Nu());
// Loop over sampling points // Loop over sampling points
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
{ {
Helper<T> hlp(gconf, grid.data(), ogrid.data()); Helper<T> hlp(gconf, grid.data(), ogrid.data(), locks);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * RESTRICT ku = hlp.kernel; const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp; const T * RESTRICT kv = hlp.kernel+supp;
...@@ -1114,6 +1151,7 @@ template<typename T> void get_correlations ...@@ -1114,6 +1151,7 @@ template<typename T> void get_correlations
myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp"); myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp");
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
ogrid.fill(0); ogrid.fill(0);
vector<Lock> locks(gconf.Nu());
size_t u0, u1, v0, v1; size_t u0, u1, v0, v1;
if (du>=0) if (du>=0)
...@@ -1128,7 +1166,7 @@ template<typename T> void get_correlations ...@@ -1128,7 +1166,7 @@ template<typename T> void get_correlations
// Loop over sampling points // Loop over sampling points
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
{ {
Helper<T,T> hlp(gconf, nullptr, ogrid.data()); Helper<T,T> hlp(gconf, nullptr, ogrid.data(), locks);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * RESTRICT ku = hlp.kernel; const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp; const T * RESTRICT kv = hlp.kernel+supp;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment