Patch summary
-------------
cpp/field.cpp : the field-to-field operator= now delegates its same-size copies
                to the pointer-taking operator= overloads (real data pointer in
                the real-space case, get_cdata() in the Fourier-space case).
cpp/field.hpp : the pointer/value operator= overloads replace a single
                std::copy / std::fill_n over local_size with per-element work:
                RLOOP for real-space data, and an explicit y/z/x loop nest whose
                z range is split with OmpUtils::ForIntervalStart/End inside an
                OpenMP parallel region for complex data.
tests/DNS/test_scaling.py : passes '--niter_part 6' explicitly.

Review notes for this revision
------------------------------
* Complex copy: the x loop now offsets by xindex.  Previously cindex (the start
  of the x line) was used for every iteration, so only the first element of
  each line was copied, subsizes[2] times over.
* Constant fill, CLOOP branch: the component counter is now initialised to 0
  and bounded by `<` instead of `<=`, and the x loop offsets by xindex.
* NOTE(review): `|| true` still forces the RLOOP branch in operator=(value), so
  the CLOOP branch stays unreachable; left as is because enabling it changes
  behaviour -- confirm intent with the author.
* NOTE(review): leading whitespace was reconstructed from a whitespace-collapsed
  copy of this patch; apply with `git apply --ignore-whitespace` if context
  lines do not match.  The `index` hashes are those of the original patch.

diff --git a/cpp/field.cpp b/cpp/field.cpp
index 72124ba4c290143d4588795b3359b1d284b7959d..69d930497e04229fea138aa4ebd0de65b2373e75 100644
--- a/cpp/field.cpp
+++ b/cpp/field.cpp
@@ -2289,10 +2289,7 @@ field<rnumber, be, fc> &field<rnumber, be, fc>::operator=(
         assert(this->get_nx() == src.get_nx());
         assert(this->get_ny() == src.get_ny());
         assert(this->get_nz() == src.get_nz());
-        this->real_space_representation = true;
-        std::copy(src.data,
-                src.data + this->rmemlayout->local_size,
-                this->data);
+        this->operator=(src.data);
     }
     else
     {
@@ -2302,9 +2299,7 @@ field<rnumber, be, fc> &field<rnumber, be, fc>::operator=(
                 this->get_ny() == src.get_ny() &&
                 this->get_nz() == src.get_nz())
         {
-            std::copy(src.data,
-                    src.data + this->rmemlayout->local_size,
-                    this->data);
+            this->operator=(src.get_cdata());
         }
         // complicated resize
         else
diff --git a/cpp/field.hpp b/cpp/field.hpp
index b8775f219caabbca80edd6289957b8aa38001338..0ac502ebe6e989316fb167c7279d2dfff115aeb5 100644
--- a/cpp/field.hpp
+++ b/cpp/field.hpp
@@ -227,31 +227,86 @@ class field
 
         inline field<rnumber, be, fc>& operator=(const typename fftw_interface<rnumber>::complex *__restrict__ source)
         {
-            std::copy((rnumber*)source,
-                    (rnumber*)(source + this->clayout->local_size),
-                    this->data);
+            // use CLOOP pattern, because we want the array to be arranged in memory
+            // for optimal access by FFTW
+            #pragma omp parallel
+            {
+                const hsize_t start = OmpUtils::ForIntervalStart(this->clayout->subsizes[1]);
+                const hsize_t end = OmpUtils::ForIntervalEnd(this->clayout->subsizes[1]);
+
+                for (hsize_t yindex = 0; yindex < this->clayout->subsizes[0]; yindex++){
+                    for (hsize_t zindex = start; zindex < end; zindex++){
+                        const ptrdiff_t cindex = (
+                                yindex*this->clayout->subsizes[1]*this->clayout->subsizes[2] +
+                                zindex*this->clayout->subsizes[2]);
+                        for (hsize_t xindex = 0; xindex < this->clayout->subsizes[2]; xindex++)
+                        {
+                            std::copy((rnumber*)(source + (cindex+xindex)*ncomp(fc)),
+                                      (rnumber*)(source + (cindex+xindex+1)*ncomp(fc)),
+                                      this->data+((cindex+xindex)*ncomp(fc))*2);
+                        }
+                    }
+                }
+            }
             this->real_space_representation = false;
             return *this;
         }
 
         inline field<rnumber, be, fc>& operator=(const rnumber *__restrict__ source)
         {
-            std::copy(source,
-                    source + this->rmemlayout->local_size,
-                    this->data);
+            // use RLOOP, such that memory caching per thread stuff is not messed up
+            this->RLOOP(
+                    [&](const ptrdiff_t rindex,
+                        const ptrdiff_t xindex,
+                        const ptrdiff_t yindex,
+                        const ptrdiff_t zindex)
+                    {
+                    std::copy(source + rindex*ncomp(fc),
+                              source + (rindex+1)*ncomp(fc),
+                              this->data + rindex*ncomp(fc));
+                    });
             this->real_space_representation = true;
             return *this;
         }
 
         inline field<rnumber, be, fc>& operator=(const rnumber value)
         {
-            #pragma omp parallel
+            if (this->real_space_representation || true)
+                // use RLOOP, such that memory caching per thread stuff is not messed up
+                this->RLOOP(
+                        [&](const ptrdiff_t rindex,
+                            const ptrdiff_t xindex,
+                            const ptrdiff_t yindex,
+                            const ptrdiff_t zindex)
+                        {
+                        std::fill_n(this->data + rindex*ncomp(fc),
+                                    ncomp(fc),
+                                    value);
+                        });
+            else
             {
-                const hsize_t start = OmpUtils::ForIntervalStart(this->rmemlayout->local_size);
-                const hsize_t end = OmpUtils::ForIntervalEnd(this->rmemlayout->local_size);
-                std::fill_n(this->data + start,
-                        end - start,
-                        value);
+                // use CLOOP, such that memory caching per thread stuff is not messed up
+                #pragma omp parallel
+                {
+                    const hsize_t start = OmpUtils::ForIntervalStart(this->clayout->subsizes[1]);
+                    const hsize_t end = OmpUtils::ForIntervalEnd(this->clayout->subsizes[1]);
+
+                    for (hsize_t yindex = 0; yindex < this->clayout->subsizes[0]; yindex++){
+                        for (hsize_t zindex = start; zindex < end; zindex++){
+                            const ptrdiff_t cindex = (
+                                    yindex*this->clayout->subsizes[1]*this->clayout->subsizes[2] +
+                                    zindex*this->clayout->subsizes[2]);
+                            for (hsize_t xindex = 0; xindex < this->clayout->subsizes[2]; xindex++)
+                            {
+                                for (unsigned int cc = 0; cc < ncomp(fc); cc++)
+                                {
+                                    *(this->data + 2*((cindex+xindex)*ncomp(fc) + cc)) = value;
+                                    *(this->data + 2*((cindex+xindex)*ncomp(fc) + cc)+1) = 0.0;
+                                }
+                            }
+                        }
+                    }
+                }
             }
             return *this;
         }
diff --git a/tests/DNS/test_scaling.py b/tests/DNS/test_scaling.py
index b6e025b5dc31fb6a5c0ee6336e7b038d4ec0ed21..d1917b9b0512bf0bf03ee42869c376dc6d766d97 100644
--- a/tests/DNS/test_scaling.py
+++ b/tests/DNS/test_scaling.py
@@ -88,6 +88,7 @@ def get_DNS_parameters(
     DNS_parameters += [
             '--tracers0_neighbours', '{0}'.format(nneighbours),
             '--tracers0_smoothness', '{0}'.format(smoothness),
+            '--niter_part', '6',
             '--cpp_random_particles', '2']
     if no_submit:
         DNS_parameters += ['--no-submit']