Commit 9d324324 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

get_diagonal->get_correlations; dx,dy->du,dv; allow negative values for du,dv

parent 9097edbb
......@@ -23,6 +23,8 @@
#include <pybind11/numpy.h>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#include "pocketfft_hdronly.h"
......@@ -74,7 +76,7 @@ nthreads: int
)""";
void set_nthreads(size_t nthreads_)
{
if (nthreads_<1) throw runtime_error("nthreads must be >= 1");
myassert(nthreads_>=1, "nthreads must be >= 1");
nthreads = nthreads_;
}
......@@ -140,7 +142,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w)
if (dobreak) break;
if (abs(dx)<=eps) dobreak=1;
if (++j>=100) throw runtime_error("convergence problem");
myassert(++j<100, "convergence problem");
}
x[m-i] = x0;
......@@ -1273,14 +1275,14 @@ template<typename T> pyarr_c<complex<T>> apply_holo(
return res;
}
template<typename T> pyarr_c<T> get_diagonal(
template<typename T> pyarr_c<T> get_correlations(
const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const pyarr<uint32_t> &idx_, size_t dx, size_t dy)
const pyarr<uint32_t> &idx_, int du, int dv)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
size_t w = gconf.W();
if (dx>=w) throw runtime_error("dx must be smaller than W");
if (dy>=w) throw runtime_error("dy must be smaller than W");
myassert(size_t(abs(du))<w, "|du| must be smaller than W");
myassert(size_t(abs(dv))<w, "|dv| must be smaller than W");
checkArray(idx_, "idx", {0});
size_t nvis = size_t(idx_.shape(0));
auto idx = idx_.template unchecked<1>();
......@@ -1292,6 +1294,16 @@ template<typename T> pyarr_c<T> get_diagonal(
T beta = gconf.Beta();
for (size_t i=0; i<nu*nv; ++i) ogrid[i] = 0.;
size_t u0, u1, v0, v1;
if (du>=0)
{ u0=0; u1=w-du; }
else
{ u0=-du; u1=w; }
if (dv>=0)
{ v0=0; v1=w-dv; }
else
{ v0=-dv; v1=w; }
// Loop over sampling points
#pragma omp parallel num_threads(nthreads)
{
......@@ -1306,12 +1318,12 @@ template<typename T> pyarr_c<T> get_diagonal(
{
UVW<T> coord = baselines.effectiveCoord(idx(ipart));
hlp.prep(coord.u, coord.v);
auto * RESTRICT wptr = hlp.p0w;
for (size_t cu=0; cu<w-dx; ++cu)
auto * RESTRICT wptr = hlp.p0w + u0*jump;
for (size_t cu=u0; cu<u1; ++cu)
{
auto f1=ku[cu]*ku[cu+dx]*emb*emb;
for (size_t cv=0; cv<w-dy; ++cv)
wptr[cv] += f1*kv[cv]*kv[cv+dy];
auto f1=ku[cu]*ku[cu+du]*emb*emb;
for (size_t cv=v0; cv<v1; ++cv)
wptr[cv] += f1*kv[cv]*kv[cv+dv];
wptr += jump;
}
}
......@@ -1446,8 +1458,7 @@ PYBIND11_MODULE(nifty_gridder, m)
},
// __setstate__
[](py::tuple t) {
if(t.size() != 5)
{ throw std::runtime_error("Invalid state"); }
myassert(t.size()==5,"Invalid state");
// Reconstruct from tuple
return GridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(),
......@@ -1484,8 +1495,8 @@ PYBIND11_MODULE(nifty_gridder, m)
"idx"_a, "grid"_a, "wgt"_a, "ms_in"_a=None);
m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a);
m.def("get_diagonal", &get_diagonal<double>, "baselines"_a, "gconf"_a,
"idx"_a, "dx"_a, "dy"_a);
m.def("get_correlations", &get_correlations<double>, "baselines"_a, "gconf"_a,
"idx"_a, "du"_a, "dv"_a);
m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment