/** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
 *
 * This library is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation; either version 3.0 of the License, or (at your option)
 * any later version.
 *
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
18 namespace SourceXtractor {
20 template<typename T, size_t N, size_t S>
21 class KdTree<T, N, S>::Node {
23 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const = 0;
24 virtual ~Node() = default;
27 template<typename T, size_t N, size_t S>
28 class KdTree<T, N, S>::Leaf : public KdTree::Node {
30 explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
31 virtual ~Leaf() = default;
33 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
34 std::vector<T> selection;
35 for (auto& entry : m_data) {
36 double square_dist = 0.0;
37 for (size_t i =0; i < N; i++) {
38 double delta = Traits::getCoord(entry, i) - coord.coord[i];
39 square_dist += delta * delta;
41 if (square_dist < radius*radius) {
42 selection.push_back(entry);
49 const std::vector<T> m_data;
52 template<typename T, size_t N, size_t S>
53 class KdTree<T, N, S>::Split : public KdTree::Node {
55 virtual ~Split() = default;
56 explicit Split(std::vector<T> data, size_t axis) : m_axis(axis) {
57 std::sort(data.begin(), data.end(), [axis](const T& a, const T& b) -> bool {
58 return Traits::getCoord(a, axis) < Traits::getCoord(b, axis);
61 double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
62 double b = Traits::getCoord(data.at(data.size() / 2), axis);
65 // avoid a possible rounding issue
68 m_split_value = (a + b) / 2.0;
71 std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
72 std::vector<T> right(data.begin() + data.size() / 2, data.end());
74 if (left.size() > S) {
75 m_left_child = std::make_shared<Split>(std::move(left), (axis+1) % N);
77 m_left_child = std::make_shared<Leaf>(std::move(left));
79 if (right.size() > S) {
80 m_right_child = std::make_shared<Split>(std::move(right), (axis+1) % N);
82 m_right_child = std::make_shared<Leaf>(std::move(right));
86 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
87 if (coord.coord[m_axis] + radius < m_split_value) {
88 return m_left_child->findPointsWithinRadius(coord, radius);
89 } else if (coord.coord[m_axis] - radius > m_split_value) {
90 return m_right_child->findPointsWithinRadius(coord, radius);
92 auto left = m_left_child->findPointsWithinRadius(coord, radius);
93 auto right = m_right_child->findPointsWithinRadius(coord, radius);
96 merge.reserve(left.size() + right.size());
97 merge.insert(merge.end(), left.begin(), left.end());
98 merge.insert(merge.end(), right.begin(), right.end());
106 double m_split_value;
108 std::shared_ptr<Node> m_left_child;
109 std::shared_ptr<Node> m_right_child;
112 template<typename T, size_t N, size_t S>
113 KdTree<T, N, S>::KdTree(const std::vector<T>& data) {
114 if (data.size() > S) {
115 m_root = std::make_shared<Split>(data, 0);
117 std::vector<T> data_copy(data);
118 m_root = std::make_shared<Leaf>(std::move(data_copy));
122 template<typename T, size_t N, size_t S>
123 std::vector<T> KdTree<T, N, S>::findPointsWithinRadius(Coord coord, double radius) const {
124 return m_root->findPointsWithinRadius(coord, radius);