1 #ifndef GRIDINTERPOLATION_IMPL
2 #error Please, include "MathUtils/interpolation/GridInterpolation.h"
5 #include "AlexandriaKernel/Tuples.h"
6 #include "MathUtils/interpolation/interpolation.h"
// Primary template of the per-axis interpolation trait. It is only declared on
// these lines; the behaviour comes from the partial specializations below, which
// are selected through the Enable parameter (std::enable_if on the type T).
11 template <typename T, typename Enable = void>
12 struct InterpolationImpl;
15 * Trait for continuous types
// Specialization selected when T is a floating point type (std::is_floating_point).
18 struct InterpolationImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
// 1-D case: the knots and their values are handed directly to simple_interpolation().
19 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values, bool extrapolate) {
20 return simple_interpolation(x, knots, values, extrapolate);
// N-D case: there is one nested interpolator per knot of this axis. The nested
// interpolators of the two knots around x are evaluated on the remaining
// coordinates (rest...), and the two results are then interpolated at x.
23 template <typename... Rest>
24 static double interpolate(const T x, const std::vector<T>& knots,
25 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool extrapolate,
27 // If no extrapolation, and the value is out-of-bounds, just clip at 0
28 if ((x < knots.front() || x > knots.back()) && !extrapolate) {
// A single knot leaves nothing to interpolate along this axis: forward the
// remaining coordinates to its only nested interpolator.
32 if (knots.size() == 1) {
33 return (*interpolators[0])(rest...);
// Index of the first knot that is not less than x (knots.size() if there is none).
// std::lower_bound requires the knots to be sorted, which checkOrder() verifies.
36 std::size_t x2i = std::lower_bound(knots.begin(), knots.end(), x) - knots.begin();
// x2i == knots.size() means that x is greater than every knot.
39 } else if (x2i == knots.size()) {
// Knot immediately before x2i.
42 std::size_t x1i = x2i - 1;
// Values of the nested interpolators at both knots for the remaining coordinates.
44 double y1 = (*interpolators[x1i])(rest...);
45 double y2 = (*interpolators[x2i])(rest...);
// Interpolate at x between the two points (knots[x1i], y1) and (knots[x2i], y2).
47 return simple_interpolation(x, {knots[x1i], knots[x2i]}, {y1, y2}, extrapolate);
// Throws InterpolationException unless the knots are sorted in non-descending order.
50 static void checkOrder(const std::vector<T>& knots) {
51 if (!std::is_sorted(knots.begin(), knots.end())) {
52 throw InterpolationException("coordinates must be sorted");
58 * Trait for discrete types
// Specialization selected for every T that is not a floating point type.
// Lookups are by exact equality with a knot; the extrapolate flag is an unnamed
// (unused) parameter in both interpolate() overloads.
61 struct InterpolationImpl<T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
// 1-D case.
62 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values, bool /*extrapolate*/) {
// Position of the first knot equal to x, or knots.size() when std::find reaches the end.
63 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
// Guard for the case where no knot is exactly equal to x.
64 if (i >= knots.size() || knots[i] != x)
// N-D case.
69 template <typename... Rest>
70 static double interpolate(const T x, const std::vector<T>& knots,
71 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool, const Rest... rest) {
// Same exact-match lookup as in the 1-D overload.
72 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
73 if (i >= knots.size() || knots[i] != x)
// Exact match: evaluate the nested interpolator of that knot on the remaining coordinates.
75 return (*interpolators[i])(rest...);
// The knots are found with a linear search, so no ordering is verified here.
78 static void checkOrder(const std::vector<T>&) {
79 // Discrete axes do not need to be in order
84 * Specialization (and end of the recursion) for a 1-dimensional interpolation.
92 * A 1-dimensional grid
97 InterpN(const std::tuple<std::vector<T>>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
// Copies the only axis of the grid and all the values before validating them.
98 : m_knots(std::get<0>(grid)), m_values(values.begin(), values.end()), m_extrapolate(extrapolate) {
// NOTE(review): unlike the N-dimensional constructor, these lines do not call
// InterpolationImpl<T>::checkOrder(m_knots) - confirm whether the ordering of a
// continuous axis is validated elsewhere.
// The values must be one-dimensional.
99 if (values.shape().size() != 1) {
100 throw InterpolationException() << "values and coordinates dimensionalities must match: " << values.shape().size() << " != 1";
// There must be exactly one value per knot.
102 if (m_knots.size() != values.size()) {
103 throw InterpolationException() << "The size of the grid and the size of the values do not match: " << m_knots.size()
104 << " != " << m_values.size();
// Evaluates the 1-D interpolator at x by delegating to the InterpolationImpl
// specialization that matches T, using the stored knots, values and extrapolation flag.
115 double operator()(const T x) const {
116 return InterpolationImpl<T>::interpolate(x, m_knots, m_values, m_extrapolate);
// Compiler-generated copy constructor (member-wise copy).
120 InterpN(const InterpN&) = default;
// Compiler-generated move constructor (member-wise move).
123 InterpN(InterpN&&) = default;
// Knots of the single axis, copied from the grid by the constructor.
126 std::vector<T> m_knots;
// Value attached to each knot; the constructor enforces the same size as m_knots.
127 std::vector<double> m_values;
132 * Recursive specialization of an N-Dimensional interpolator
133 * The dimensionality is N = sizeof...(Rest) + 1 (N > 1)
134 * @tparam T Type of the knots of the first axis
135 * @tparam Rest Types of the knots of the remaining axes
137 template <typename T, typename... Rest>
138 class InterpN<T, Rest...> {
147 InterpN(const std::tuple<std::vector<T>, std::vector<Rest>...>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
148 : m_extrapolate(extrapolate) {
// Number of axes described by the grid: the first one plus the remaining ones.
149 constexpr std::size_t N = sizeof...(Rest) + 1;
// The values must have as many dimensions as the grid has axes.
151 if (values.shape().size() != N) {
152 throw InterpolationException() << "values and coordinates dimensionality must match: " << values.shape().size()
// Keep the knots of the first axis; a continuous axis must be sorted, while
// checkOrder() does nothing for a discrete one.
155 m_knots = std::get<0>(grid);
156 InterpolationImpl<T>::checkOrder(m_knots);
// The first axis of the grid is matched against the last dimension of the values.
157 if (m_knots.size() != values.shape().back()) {
158 throw InterpolationException("coordinates and value sizes must match");
160 // Build nested interpolators
// NOTE(review): grid is a const reference, so std::move(grid) only produces a
// const rvalue; presumably Tuple::Tail returns a copy of the remaining axes - confirm.
161 auto subgrid = Tuple::Tail(std::move(grid));
// One nested interpolator, over the remaining axes, per knot of this axis.
162 m_interpolators.resize(m_knots.size());
163 for (size_t i = 0; i < m_knots.size(); ++i) {
// Presumably the slice of the values at index i of the last dimension - confirm against NdArray::rslice.
164 auto subvalues = values.rslice(i);
165 m_interpolators[i].reset(new InterpN<Rest...>(subgrid, subvalues, extrapolate));
171 * @param x Value for the axis for the first dimension
172 * @param rest Values for the next set of axes
173 * @return The interpolated value
175 * rest... expands into the (N-1) coordinates of the remaining axes
176 * x is used to find the interpolators for x1 and x2 s.t. x1 <= x <= x2
177 * Those two interpolators are used to compute y1 for x1, and y2 for x2 (based on the rest of the parameters)
178 * A final linear interpolator is used to get the value of y at the position x
// Evaluates the interpolator at (x, rest...) by delegating to the InterpolationImpl
// specialization that matches T, which receives the nested interpolators and
// forwards the remaining coordinates to them.
180 double operator()(T x, Rest... rest) const {
181 return InterpolationImpl<T>::interpolate(x, m_knots, m_interpolators, m_extrapolate, rest...);
// Deep copy: the nested interpolators are held through std::unique_ptr, which
// cannot be copied, so each one is cloned with its own copy constructor.
185 InterpN(const InterpN& other) : m_knots(other.m_knots), m_extrapolate(other.m_extrapolate) {
// One slot per knot, as in the regular constructor.
186 m_interpolators.resize(m_knots.size());
187 for (size_t i = 0; i < m_interpolators.size(); ++i) {
188 m_interpolators[i].reset(new InterpN<Rest...>(*other.m_interpolators[i]));
// Knots of the first axis of the grid.
193 std::vector<T> m_knots;
// Owned nested interpolators over the remaining axes, one per knot of m_knots.
194 std::vector<std::unique_ptr<InterpN<Rest...>>> m_interpolators;
198 } // namespace MathUtils
199 } // namespace Euclid