From b80c8722e1b02240f358ecbed8bb2135d5165c7e Mon Sep 17 00:00:00 2001
From: Tobias Winchen <tobias.winchen@rwth-aachen.de>
Date: Thu, 3 Feb 2022 09:17:29 +0000
Subject: [PATCH] Added possibility to skip calculation of one gate.

---
 .../effelsberg/edd/GatedSpectrometer.cuh      |  7 +-
 .../edd/detail/GatedSpectrometer.cu           | 98 +++++++++++--------
 .../edd/src/GatedSpectrometer_cli.cu          |  9 +-
 3 files changed, 68 insertions(+), 46 deletions(-)

diff --git a/psrdada_cpp/effelsberg/edd/GatedSpectrometer.cuh b/psrdada_cpp/effelsberg/edd/GatedSpectrometer.cuh
index f3502675..d6cf309f 100644
--- a/psrdada_cpp/effelsberg/edd/GatedSpectrometer.cuh
+++ b/psrdada_cpp/effelsberg/edd/GatedSpectrometer.cuh
@@ -14,7 +14,7 @@
 
 #include "cublas_v2.h"
 
-
+#include <bitset>
 #include <iostream>
 #include <iomanip>
 #include <cstring>
@@ -295,7 +295,9 @@ struct GatedSpectrometerInputParameters
         size_t fft_length;
         size_t naccumulate;
         unsigned int nbits;
-    //    size_t disable_gate;
+        std::bitset<2> active_gates;
+
+        GatedSpectrometerInputParameters(): selectedSideChannel(0), speadHeapSize(0), nSideChannels(0), selectedBit(0), fft_length(0), naccumulate(0), nbits(0), active_gates(3) {}; // Default both gates active
     };
 
 
@@ -371,6 +373,7 @@ private:
   std::size_t _selectedBit;
   std::size_t _batch;
   std::size_t _nsamps_per_heap;
+  std::bitset<2> _active_gates;
 
   HandlerType &_handler;
   cufftHandle _fft_plan;
diff --git a/psrdada_cpp/effelsberg/edd/detail/GatedSpectrometer.cu b/psrdada_cpp/effelsberg/edd/detail/GatedSpectrometer.cu
index 7daa377a..7a33beba 100644
--- a/psrdada_cpp/effelsberg/edd/detail/GatedSpectrometer.cu
+++ b/psrdada_cpp/effelsberg/edd/detail/GatedSpectrometer.cu
@@ -66,7 +66,7 @@ __global__ void update_baselines(float*  __restrict__ baseLineG0,
 template <class HandlerType, class InputType, class OutputType>
 GatedSpectrometer<HandlerType, InputType, OutputType>::GatedSpectrometer(
     const GatedSpectrometerInputParameters &ip, HandlerType &handler) : _dadaBufferLayout(ip.dadaBufferLayout),
-    _selectedSideChannel(ip.selectedSideChannel), _selectedBit(ip.selectedBit),
+    _selectedSideChannel(ip.selectedSideChannel), _selectedBit(ip.selectedBit), _active_gates(ip.active_gates),
     _fft_length(ip.fft_length), _naccumulate(ip.naccumulate),
     _handler(handler), _fft_plan(0), _call_count(0), _nsamps_per_heap(4096)
 {
@@ -263,23 +263,29 @@ void GatedSpectrometer<HandlerType, InputType, OutputType>::gated_fft(
         _noOfBitSetsIn_G0.size()
             );
 
-  BOOST_LOG_TRIVIAL(debug) << "Performing FFT 1";
-  BOOST_LOG_TRIVIAL(debug) << "Accessing unpacked voltage";
-  UnpackedVoltageType *_unpacked_voltage_ptr =
-      thrust::raw_pointer_cast(_unpacked_voltage_G0.data());
-  BOOST_LOG_TRIVIAL(debug) << "Accessing channelized voltage";
-  ChannelisedVoltageType *_channelised_voltage_ptr =
-      thrust::raw_pointer_cast(data._channelised_voltage_G0.data());
-
-  CUFFT_ERROR_CHECK(cufftExecR2C(_fft_plan, (cufftReal *)_unpacked_voltage_ptr,
-                                 (cufftComplex *)_channelised_voltage_ptr));
+  if (_active_gates[0])
+  {
 
-  BOOST_LOG_TRIVIAL(debug) << "Performing FFT 2";
-  _unpacked_voltage_ptr = thrust::raw_pointer_cast(_unpacked_voltage_G1.data());
-  _channelised_voltage_ptr = thrust::raw_pointer_cast(data._channelised_voltage_G1.data());
-  CUFFT_ERROR_CHECK(cufftExecR2C(_fft_plan, (cufftReal *)_unpacked_voltage_ptr,
-                                 (cufftComplex *)_channelised_voltage_ptr));
+    BOOST_LOG_TRIVIAL(debug) << "Performing FFT 1";
+    BOOST_LOG_TRIVIAL(debug) << "Accessing unpacked voltage";
+    UnpackedVoltageType *_unpacked_voltage_ptr = thrust::raw_pointer_cast(_unpacked_voltage_G0.data());
+    BOOST_LOG_TRIVIAL(debug) << "Accessing channelized voltage";
+    ChannelisedVoltageType *_channelised_voltage_ptr =
+        thrust::raw_pointer_cast(data._channelised_voltage_G0.data());
 
+    CUFFT_ERROR_CHECK(cufftExecR2C(_fft_plan, (cufftReal *)_unpacked_voltage_ptr,
+                                   (cufftComplex *)_channelised_voltage_ptr));
+  }
+  if (_active_gates[1])
+  {
+      BOOST_LOG_TRIVIAL(debug) << "Performing FFT 2";
+      BOOST_LOG_TRIVIAL(debug) << "Accessing unpacked voltage";
+      UnpackedVoltageType *_unpacked_voltage_ptr = thrust::raw_pointer_cast(_unpacked_voltage_G1.data());
+      BOOST_LOG_TRIVIAL(debug) << "Accessing channelized voltage";
+      ChannelisedVoltageType *_channelised_voltage_ptr = thrust::raw_pointer_cast(data._channelised_voltage_G1.data());
+      CUFFT_ERROR_CHECK(cufftExecR2C(_fft_plan, (cufftReal *)_unpacked_voltage_ptr,
+                                     (cufftComplex *)_channelised_voltage_ptr));
+  }
 //  CUDA_ERROR_CHECK(cudaStreamSynchronize(_proc_stream));
 //  BOOST_LOG_TRIVIAL(debug) << "Exit processing";
 } // process
@@ -364,21 +370,27 @@ void GatedSpectrometer<HandlerType, InputType, OutputType>::process(SinglePolari
 {
   gated_fft(*inputDataStream, outputDataStream->G0._noOfBitSets.a(), outputDataStream->G1._noOfBitSets.a());
 
-  kernels::detect_and_accumulate<IntegratedPowerType> <<<1024, 1024, 0, _proc_stream>>>(
-            thrust::raw_pointer_cast(inputDataStream->_channelised_voltage_G0.data()),
-            thrust::raw_pointer_cast(outputDataStream->G0.data.a().data()),
-            _nchans,
-            inputDataStream->_channelised_voltage_G0.size(),
-            _naccumulate / _nBlocks,
-            1, 0., 1, 0);
-
-  kernels::detect_and_accumulate<IntegratedPowerType> <<<1024, 1024, 0, _proc_stream>>>(
-            thrust::raw_pointer_cast(inputDataStream->_channelised_voltage_G1.data()),
-            thrust::raw_pointer_cast(outputDataStream->G1.data.a().data()),
-            _nchans,
-            inputDataStream->_channelised_voltage_G1.size(),
-            _naccumulate / _nBlocks,
-            1, 0., 1, 0);
+  if (_active_gates[0])
+  {
+    kernels::detect_and_accumulate<IntegratedPowerType> <<<1024, 1024, 0, _proc_stream>>>(
+              thrust::raw_pointer_cast(inputDataStream->_channelised_voltage_G0.data()),
+              thrust::raw_pointer_cast(outputDataStream->G0.data.a().data()),
+              _nchans,
+              inputDataStream->_channelised_voltage_G0.size(),
+              _naccumulate / _nBlocks,
+              1, 0., 1, 0);
+  }
+
+  if (_active_gates[1])
+  {
+    kernels::detect_and_accumulate<IntegratedPowerType> <<<1024, 1024, 0, _proc_stream>>>(
+              thrust::raw_pointer_cast(inputDataStream->_channelised_voltage_G1.data()),
+              thrust::raw_pointer_cast(outputDataStream->G1.data.a().data()),
+              _nchans,
+              inputDataStream->_channelised_voltage_G1.size(),
+              _naccumulate / _nBlocks,
+              1, 0., 1, 0);
+  }
 
     // count saturated samples
     for(size_t output_block_number = 0; output_block_number < outputDataStream->G0._noOfOverflowed.size(); output_block_number++)
@@ -429,6 +441,9 @@ void GatedSpectrometer<HandlerType, InputType, OutputType>::process(DualPolariza
       size_t input_offset = output_block_number * inputDataStream->polarization0._channelised_voltage_G0.size() / outputDataStream->G0._noOfBitSets.size();
       size_t output_offset = output_block_number * outputDataStream->G0.I.a().size() / outputDataStream->G0._noOfBitSets.size();
       BOOST_LOG_TRIVIAL(debug) << "Accumulating data for output block " << output_block_number << " with input offset " << input_offset << " and output_offset " << output_offset;
+
+    if (_active_gates[0])
+    {
       stokes_accumulate<<<1024, 1024, 0, _proc_stream>>>(
               thrust::raw_pointer_cast(inputDataStream->polarization0._channelised_voltage_G0.data() + input_offset),
               thrust::raw_pointer_cast(inputDataStream->polarization1._channelised_voltage_G0.data() + input_offset),
@@ -438,16 +453,19 @@ void GatedSpectrometer<HandlerType, InputType, OutputType>::process(DualPolariza
               thrust::raw_pointer_cast(outputDataStream->G0.V.a().data() + output_offset),
               _nchans, _naccumulate / _nBlocks
               );
+    }
 
-      stokes_accumulate<<<1024, 1024, 0, _proc_stream>>>(
-              thrust::raw_pointer_cast(inputDataStream->polarization0._channelised_voltage_G1.data() + input_offset),
-              thrust::raw_pointer_cast(inputDataStream->polarization1._channelised_voltage_G1.data() + input_offset),
-              thrust::raw_pointer_cast(outputDataStream->G1.I.a().data() + output_offset),
-              thrust::raw_pointer_cast(outputDataStream->G1.Q.a().data() + output_offset),
-              thrust::raw_pointer_cast(outputDataStream->G1.U.a().data() + output_offset),
-              thrust::raw_pointer_cast(outputDataStream->G1.V.a().data() + output_offset),
-              _nchans, _naccumulate / _nBlocks
-              );
+    if (_active_gates[1]){
+        stokes_accumulate<<<1024, 1024, 0, _proc_stream>>>(
+                thrust::raw_pointer_cast(inputDataStream->polarization0._channelised_voltage_G1.data() + input_offset),
+                thrust::raw_pointer_cast(inputDataStream->polarization1._channelised_voltage_G1.data() + input_offset),
+                thrust::raw_pointer_cast(outputDataStream->G1.I.a().data() + output_offset),
+                thrust::raw_pointer_cast(outputDataStream->G1.Q.a().data() + output_offset),
+                thrust::raw_pointer_cast(outputDataStream->G1.U.a().data() + output_offset),
+                thrust::raw_pointer_cast(outputDataStream->G1.V.a().data() + output_offset),
+                _nchans, _naccumulate / _nBlocks
+                );
+      }
 
         // count saturated samples
         outputDataStream->G0._noOfOverflowed.a().data()[output_block_number] = 0;
diff --git a/psrdada_cpp/effelsberg/edd/src/GatedSpectrometer_cli.cu b/psrdada_cpp/effelsberg/edd/src/GatedSpectrometer_cli.cu
index 22d31193..1279307b 100644
--- a/psrdada_cpp/effelsberg/edd/src/GatedSpectrometer_cli.cu
+++ b/psrdada_cpp/effelsberg/edd/src/GatedSpectrometer_cli.cu
@@ -177,10 +177,11 @@ int main(int argc, char **argv) {
     desc.add_options()("naccumulate,a",
                        po::value<size_t>(&ip.naccumulate)->required(),
                        "The number of samples to integrate in each channel");
-//    desc.add_options()("disable_gate,d",
-//                       po::value<size_t>(&ip.disable_gate)->default_value(-1),
-//                       "Disable processing of ND state 0,1. Select -1 (default) to process both.");
-//
+    desc.add_options()("disable_gate,d",
+                       po::value<size_t>()->notifier(  // size_t, not uint8_t: a char type is parsed as a character ('0' -> 48)
+                           [&ip](size_t in) { ip.active_gates.set(in, false); }),  // bitset::set throws std::out_of_range unless in is 0 or 1
+                       "Disable processing of ND state 0 or 1.");
+
     desc.add_options()(
         "log_level", po::value<std::string>()->default_value("info")->notifier(
                          [](std::string level) { set_log_level(level); }),
-- 
GitLab