Commit 561b9d72 authored by Martin Reinecke

small tweak

parent da64e8d9
......@@ -1595,17 +1595,26 @@ void fillIdx(const Baselines &baselines,
constexpr int side=1<<logsquare;
size_t nbu = (gconf.Nu()+1+side-1) >> logsquare,
nbv = (gconf.Nv()+1+side-1) >> logsquare;
vector<idx_t> acc(nbu*nbv+1, 0);
vector<idx_t> tmp(nrow*(chend-chbegin),~idx_t(0));
vector<vector<idx_t>> acc;
vector<idx_t> tmp(nrow*(chend-chbegin));
#pragma omp parallel num_threads(gconf.Nthreads())
{
auto nthr = my_num_threads();
auto id = my_thread_num();
#pragma omp single
acc.resize(nthr);
#pragma omp barrier
size_t lo, hi;
calc_share(my_num_threads(), my_thread_num(), nrow, lo, hi);
vector<idx_t> lacc(nbu*nbv+1, 0);
calc_share(nthr, id, nrow, lo, hi);
vector<idx_t> &lacc(acc[id]);
lacc.resize(nbu*nbv+1, 0);
for (idx_t irow=lo, idx=lo*(chend-chbegin); irow<hi; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan)
for (int ichan=chbegin; ichan<chend; ++ichan, ++idx)
{
tmp[idx] = ~idx_t(0);
if (!flags(irow, ichan))
{
auto uvw = baselines.effectiveCoord(RowChan{irow,idx_t(ichan)});
......@@ -1618,28 +1627,32 @@ void fillIdx(const Baselines &baselines,
iu0 = (iu0+nsafe)>>logsquare;
iv0 = (iv0+nsafe)>>logsquare;
++lacc[nbv*iu0 + iv0 + 1];
tmp[idx++] = nbv*iu0 + iv0;
tmp[idx] = nbv*iu0 + iv0;
}
}
}
#pragma omp barrier
#pragma omp critical(xyz)
for (size_t i=0; i<acc.size(); ++i)
acc[i]+=lacc[i];
size_t lo2, hi2;
calc_share(nthr, id, nbu*nbv, lo2, hi2);
for (size_t i=lo2+1; i<hi2+1; ++i)
{
idx_t sum=0;
for (int j=0; j<nthr; ++j)
sum += acc[j][i];
acc[0][i]=sum;
}
}
for (size_t i=1; i<acc.size(); ++i)
acc[i] += acc[i-1];
myassert(res.shape(0)==acc.back(), "array size mismatch");
for (idx_t irow=0, idx=0; irow<nrow; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags(irow, ichan))
{
while (tmp[idx]==idx_t(~0)) ++idx;
auto w = abs(baselines.effectiveCoord(RowChan{irow,idx_t(ichan)}).w);
if ((w>=wmin) && (w<wmax))
res[acc[tmp[idx++]]++] = baselines.getIdx(irow, ichan);
}
auto &acc0(acc[0]);
for (size_t i=1; i<acc0.size(); ++i)
acc0[i] += acc0[i-1];
myassert(res.shape(0)==acc0.back(), "array size mismatch");
for (size_t irow=0, idx=0; irow<nrow; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan, ++idx)
if (tmp[idx]!=(~idx_t(0)))
res[acc0[tmp[idx]]++] = baselines.getIdx(irow, ichan);
}
template<typename T> vector<idx_t> getWgtIndices(const Baselines &baselines,
......@@ -1653,17 +1666,23 @@ template<typename T> vector<idx_t> getWgtIndices(const Baselines &baselines,
constexpr int side=1<<logsquare;
size_t nbu = (gconf.Nu()+1+side-1) >> logsquare,
nbv = (gconf.Nv()+1+side-1) >> logsquare;
vector<idx_t> acc(nbu*nbv+1, 0);
vector<idx_t> tmp(nrow*nchan,~idx_t(0));
vector<vector<idx_t>> acc;
vector<idx_t> tmp(nrow*nchan);
#pragma omp parallel num_threads(gconf.Nthreads())
{
size_t lo, hi;
calc_share(my_num_threads(), my_thread_num(), nrow, lo, hi);
vector<idx_t> lacc(nbu*nbv+1, 0);
auto nthr = my_num_threads();
auto id = my_thread_num();
#pragma omp single
acc.resize(nthr);
#pragma omp barrier
size_t lo, hi;
calc_share(nthr, id, nrow, lo, hi);
vector<idx_t> &lacc(acc[id]);
lacc.resize(nbu*nbv+1, 0);
for (idx_t irow=lo, idx=lo*nchan; irow<hi; ++irow)
for (idx_t ichan=0; ichan<nchan; ++ichan)
for (idx_t ichan=0; ichan<nchan; ++ichan, ++idx)
if ((!have_wgt) || (wgt(irow,ichan)!=0))
{
auto uvw = baselines.effectiveCoord(RowChan{irow,idx_t(ichan)});
......@@ -1674,26 +1693,32 @@ template<typename T> vector<idx_t> getWgtIndices(const Baselines &baselines,
iu0 = (iu0+nsafe)>>logsquare;
iv0 = (iv0+nsafe)>>logsquare;
++lacc[nbv*iu0 + iv0 + 1];
tmp[idx++] = nbv*iu0 + iv0;
tmp[idx] = nbv*iu0 + iv0;
}
else
tmp[idx] = ~idx_t(0);
#pragma omp barrier
#pragma omp critical(xyz)
for (size_t i=0; i<acc.size(); ++i)
acc[i]+=lacc[i];
size_t lo2, hi2;
calc_share(nthr, id, nbu*nbv, lo2, hi2);
for (size_t i=lo2+1; i<hi2+1; ++i)
{
idx_t sum=0;
for (int j=0; j<nthr; ++j)
sum += acc[j][i];
acc[0][i]=sum;
}
}
for (size_t i=1; i<acc.size(); ++i)
acc[i] += acc[i-1];
auto &acc0(acc[0]);
for (size_t i=1; i<acc0.size(); ++i)
acc0[i] += acc0[i-1];
vector<idx_t> res(acc.back());
vector<idx_t> res(acc0.back());
for (size_t irow=0, idx=0; irow<nrow; ++irow)
for (size_t ichan=0; ichan<nchan; ++ichan)
{
while (tmp[idx]==idx_t(~0)) ++idx;
if ((!have_wgt) || (wgt(irow,ichan)!=0))
res[acc[tmp[idx++]]++] = baselines.getIdx(irow, ichan);
}
for (size_t ichan=0; ichan<nchan; ++ichan, ++idx)
if (tmp[idx]!=(~idx_t(0)))
res[acc0[tmp[idx]]++] = baselines.getIdx(irow, ichan);
return res;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment