pybind_utils.h 3.67 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
#ifndef DUCC0_PYBIND_UTILS_H
#define DUCC0_PYBIND_UTILS_H
3
4
5
6

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

7
#include "mr_util/infra/mav.h"
8
9
10
11
12

namespace mr {

namespace detail_pybind {

Martin Reinecke's avatar
Martin Reinecke committed
13
14
15
using shape_t=fmav_info::shape_t;
using stride_t=fmav_info::stride_t;

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
namespace py = pybind11;

template<typename T> bool isPyarr(const py::object &obj)
  { return py::isinstance<py::array_t<T>>(obj); }

template<typename T> py::array_t<T> toPyarr(const py::object &obj)
  {
  auto tmp = obj.cast<py::array_t<T>>();
  MR_assert(tmp.is(obj), "error during array conversion");
  return tmp;
  }

shape_t copy_shape(const py::array &arr)
  {
  shape_t res(size_t(arr.ndim()));
  for (size_t i=0; i<res.size(); ++i)
    res[i] = size_t(arr.shape(int(i)));
  return res;
  }

template<typename T> stride_t copy_strides(const py::array &arr)
  {
  stride_t res(size_t(arr.ndim()));
  constexpr auto st = ptrdiff_t(sizeof(T));
  for (size_t i=0; i<res.size(); ++i)
    {
    auto tmp = arr.strides(int(i));
    MR_assert((tmp/st)*st==tmp, "bad stride");
    res[i] = tmp/st;
    }
  return res;
  }

template<size_t ndim> std::array<size_t, ndim> copy_fixshape(const py::array &arr)
  {
  MR_assert(size_t(arr.ndim())==ndim, "incorrect number of dimensions");
  std::array<size_t, ndim> res;
  for (size_t i=0; i<ndim; ++i)
    res[i] = size_t(arr.shape(int(i)));
  return res;
  }

template<typename T, size_t ndim> std::array<ptrdiff_t, ndim> copy_fixstrides(const py::array &arr)
  {
  MR_assert(size_t(arr.ndim())==ndim, "incorrect number of dimensions");
  std::array<ptrdiff_t, ndim> res;
  constexpr auto st = ptrdiff_t(sizeof(T));
  for (size_t i=0; i<ndim; ++i)
    {
    auto tmp = arr.strides(int(i));
    MR_assert((tmp/st)*st==tmp, "bad stride");
    res[i] = tmp/st;
    }
  return res;
  }

template<typename T> py::array_t<T> make_Pyarr(const shape_t &dims)
  { return py::array_t<T>(dims); }

template<typename T> py::array_t<T> get_optional_Pyarr(py::object &arr_,
  const shape_t &dims)
  {
  if (arr_.is_none()) return py::array_t<T>(dims);
  MR_assert(isPyarr<T>(arr_), "incorrect data type");
  auto tmp = toPyarr<T>(arr_);
  MR_assert(dims.size()==size_t(tmp.ndim()), "dimension mismatch");
  for (size_t i=0; i<dims.size(); ++i)
    MR_assert(dims[i]==size_t(tmp.shape(int(i))), "dimension mismatch");
  return tmp;
  }

template<typename T> py::array_t<T> get_optional_const_Pyarr(
  const py::object &arr_, const shape_t &dims)
  {
  if (arr_.is_none()) return py::array_t<T>(shape_t(dims.size(), 0));
  MR_assert(isPyarr<T>(arr_), "incorrect data type");
  auto tmp = toPyarr<T>(arr_);
  MR_assert(dims.size()==size_t(tmp.ndim()), "dimension mismatch");
  for (size_t i=0; i<dims.size(); ++i)
    MR_assert(dims[i]==size_t(tmp.shape(int(i))), "dimension mismatch");
  return tmp;
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
99
template<typename T> fmav<T> to_fmav(const py::object &obj, bool rw=false)
100
101
  {
  auto arr = toPyarr<T>(obj);
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
102
103
104
105
  if (rw)
    return fmav<T>(reinterpret_cast<T *>(arr.mutable_data()),
      copy_shape(arr), copy_strides<T>(arr), true);
  return fmav<T>(reinterpret_cast<const T *>(arr.data()),
106
107
108
    copy_shape(arr), copy_strides<T>(arr));
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
109
template<typename T, size_t ndim> mav<T,ndim> to_mav(const py::array &obj, bool rw=false)
110
111
  {
  auto arr = toPyarr<T>(obj);
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
112
113
114
115
  if (rw)
    return mav<T,ndim>(reinterpret_cast<T *>(arr.mutable_data()),
      copy_fixshape<ndim>(arr), copy_fixstrides<T,ndim>(arr), true);
  return mav<T,ndim>(reinterpret_cast<const T *>(arr.data()),
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    copy_fixshape<ndim>(arr), copy_fixstrides<T,ndim>(arr));
  }

}

using detail_pybind::isPyarr;
using detail_pybind::make_Pyarr;
using detail_pybind::get_optional_Pyarr;
using detail_pybind::get_optional_const_Pyarr;
using detail_pybind::to_fmav;
using detail_pybind::to_mav;

}

#endif