00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00038 #define OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00039
#include "ompl/datastructures/NearestNeighbors.h"
#include "ompl/datastructures/GreedyKCenters.h"
#include "ompl/util/Exception.h"
#include <boost/unordered_set.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <limits>
#include <queue>
#include <utility>
#include <vector>
00046
00047 namespace ompl
00048 {
00049
00058 template<typename _T>
00059 class NearestNeighborsGNAT : public NearestNeighbors<_T>
00060 {
00061 protected:
00062
00063
00064 typedef std::pair<const _T*,double> DataDist;
00065 struct DataDistCompare
00066 {
00067 bool operator()(const DataDist& d0, const DataDist& d1)
00068 {
00069 return d0.second < d1.second;
00070 }
00071 };
00072 typedef std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare> NearQueue;
00073
00074
00075
00076 class Node;
00077 typedef std::pair<Node*,double> NodeDist;
00078 struct NodeDistCompare
00079 {
00080 bool operator()(const NodeDist& n0, const NodeDist& n1) const
00081 {
00082 return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
00083 }
00084 };
00085 typedef std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare> NodeQueue;
00086
00087
00088 public:
00089 NearestNeighborsGNAT(unsigned int degree = 4, unsigned int minDegree = 2,
00090 unsigned int maxDegree = 6, unsigned int maxNumPtsPerLeaf = 50,
00091 unsigned int removedCacheSize = 50)
00092 : NearestNeighbors<_T>(), tree_(NULL), degree_(degree),
00093 minDegree_(std::min(degree,minDegree)), maxDegree_(std::max(maxDegree,degree)),
00094 maxNumPtsPerLeaf_(maxNumPtsPerLeaf), size_(0), removedCacheSize_(removedCacheSize)
00095 {
00096 }
00097
00098 virtual ~NearestNeighborsGNAT(void)
00099 {
00100 if (tree_)
00101 delete tree_;
00102 }
00103
00105 virtual void setDistanceFunction(const typename NearestNeighbors<_T>::DistanceFunction &distFun)
00106 {
00107 NearestNeighbors<_T>::setDistanceFunction(distFun);
00108 pivotSelector_.setDistanceFunction(distFun);
00109 }
00110
00111 virtual void clear(void)
00112 {
00113 if (tree_)
00114 {
00115 delete tree_;
00116 tree_ = NULL;
00117 }
00118 size_ = 0;
00119 removed_.clear();
00120 }
00121
00122 virtual void add(const _T &data)
00123 {
00124 if (tree_)
00125 tree_->add(*this, data);
00126 else
00127 {
00128 tree_ = new Node(NULL, degree_, maxNumPtsPerLeaf_, data);
00129 size_ = 1;
00130 }
00131 }
00132
00134 virtual void add(const std::vector<_T> &data)
00135 {
00136 if (tree_)
00137 NearestNeighbors<_T>::add(data);
00138 else if (data.size()>0)
00139 {
00140 tree_ = new Node(NULL, degree_, maxNumPtsPerLeaf_, data[0]);
00141 for (unsigned int i=1; i<data.size(); ++i)
00142 tree_->data_.push_back(data[i]);
00143 if (tree_->needToSplit(*this))
00144 tree_->split(*this);
00145 }
00146 size_ += data.size();
00147 }
00149 void rebuildDataStructure()
00150 {
00151 std::vector<_T> lst;
00152 list(lst);
00153 clear();
00154 add(lst);
00155 }
00156 virtual bool remove(const _T &data)
00157 {
00158 if (!tree_) return false;
00159 NearQueue nbhQueue;
00160
00161 bool isPivot = nearestKInternal(data, 1, nbhQueue);
00162 if (*nbhQueue.top().first != data)
00163 return false;
00164 removed_.insert(nbhQueue.top().first);
00165 size_--;
00166
00167
00168 if (isPivot || removed_.size()>=removedCacheSize_)
00169 rebuildDataStructure();
00170 return true;
00171 }
00172 virtual _T nearest(const _T &data) const
00173 {
00174 if (tree_)
00175 {
00176 std::vector<_T> nbh;
00177 nearestK(data, 1, nbh);
00178 if (!nbh.empty()) return nbh[0];
00179 }
00180 throw Exception("No elements found");
00181 }
00182
00183 virtual void nearestK(const _T &data, std::size_t k, std::vector<_T> &nbh) const
00184 {
00185 nbh.clear();
00186 if (k == 0) return;
00187 if (tree_)
00188 {
00189 NearQueue nbhQueue;
00190 nearestKInternal(data, k, nbhQueue);
00191 postprocessNearest(nbhQueue, nbh, k);
00192 }
00193 }
00194
00195 virtual void nearestR(const _T &data, double radius, std::vector<_T> &nbh) const
00196 {
00197 nbh.clear();
00198 if (tree_)
00199 {
00200 NearQueue nbhQueue;
00201 nearestRInternal(data, radius, nbhQueue);
00202 postprocessNearest(nbhQueue, nbh);
00203 }
00204 }
00205
00206 virtual std::size_t size(void) const
00207 {
00208 return size_;
00209 }
00210
00211 virtual void list(std::vector<_T> &data) const
00212 {
00213 data.clear();
00214 data.reserve(size());
00215 if (tree_)
00216 tree_->list(*this, data);
00217 }
00218
00219 friend std::ostream& operator<<(std::ostream& out, const NearestNeighborsGNAT<_T>& gnat)
00220 {
00221 if (gnat.tree_)
00222 {
00223 out << *gnat.tree_;
00224 if (!gnat.removed_.empty())
00225 {
00226 out << "Elements marked for removal:\n";
00227 for (typename boost::unordered_set<const _T*>::const_iterator it = gnat.removed_.begin();
00228 it != gnat.removed_.end(); it++)
00229 out << **it << '\t';
00230 out << std::endl;
00231 }
00232 }
00233 return out;
00234 }
00235
00236
00237 void integrityCheck()
00238 {
00239 std::vector<_T> lst;
00240 boost::unordered_set<const _T*> tmp;
00241
00242 removed_.swap(tmp);
00243 list(lst);
00244
00245 for (typename boost::unordered_set<const _T*>::iterator it=tmp.begin(); it!=tmp.end(); it++)
00246 {
00247 unsigned int i;
00248 for (i=0; i<lst.size(); ++i)
00249 if (lst[i]==**it)
00250 break;
00251 if (i == lst.size())
00252 {
00253
00254 std::cout << "***** FAIL!! ******\n" << *this << '\n';
00255 for (unsigned int j=0; j<lst.size(); ++j) std::cout<<lst[j]<<'\t';
00256 std::cout<<std::endl;
00257 }
00258 assert(i != lst.size());
00259 }
00260
00261 removed_.swap(tmp);
00262
00263 list(lst);
00264 if (lst.size() != size_)
00265 std::cout << "#########################################\n" << *this << std::endl;
00266 assert(lst.size() == size_);
00267 }
00268 protected:
00269 typedef NearestNeighborsGNAT<_T> GNAT;
00270
00271 bool isRemoved(const _T& data) const
00272 {
00273 return !removed_.empty() && removed_.find(&data) != removed_.end();
00274 }
00275
00276
00277 bool nearestKInternal(const _T &data, std::size_t k, NearQueue& nbhQueue) const
00278 {
00279 bool isPivot;
00280 double dist;
00281 NodeDist nodeDist;
00282 NodeQueue nodeQueue;
00283
00284 isPivot = tree_->insertNeighborK(nbhQueue, k, tree_->pivot_, data,
00285 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00286 tree_->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
00287 while (nodeQueue.size() > 0)
00288 {
00289 dist = nbhQueue.top().second;
00290 nodeDist = nodeQueue.top();
00291 nodeQueue.pop();
00292 if (nbhQueue.size() == k &&
00293 (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00294 nodeDist.second < nodeDist.first->minRadius_ - dist))
00295 break;
00296 nodeDist.first->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
00297 }
00298 return isPivot;
00299 }
00300 void nearestRInternal(const _T &data, double radius, NearQueue& nbhQueue) const
00301 {
00302 double dist = radius;
00303 NodeQueue nodeQueue;
00304 NodeDist nodeDist;
00305
00306 tree_->insertNeighborR(nbhQueue, radius, tree_->pivot_,
00307 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00308 tree_->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00309 while (nodeQueue.size() > 0)
00310 {
00311 nodeDist = nodeQueue.top();
00312 nodeQueue.pop();
00313 if (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00314 nodeDist.second < nodeDist.first->minRadius_ - dist)
00315 break;
00316 nodeDist.first->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00317 }
00318 }
00319 void postprocessNearest(NearQueue& nbhQueue, std::vector<_T> &nbh,
00320 unsigned int k=std::numeric_limits<unsigned int>::max()) const
00321 {
00322 typename std::vector<_T>::reverse_iterator it;
00323 nbh.resize(nbhQueue.size());
00324 for (it=nbh.rbegin(); it!=nbh.rend(); it++, nbhQueue.pop())
00325 *it = *nbhQueue.top().first;
00326 }
00327
00328 class Node
00329 {
00330 public:
00331 Node(const Node* parent, int degree, int capacity, const _T& pivot)
00332 : degree_(degree), pivot_(pivot),
00333 minRadius_(std::numeric_limits<double>::infinity()),
00334 maxRadius_(-minRadius_), minRange_(degree, minRadius_),
00335 maxRange_(degree, maxRadius_)
00336 {
00337
00338 data_.reserve(capacity+1);
00339 }
00340
00341 ~Node()
00342 {
00343 for (unsigned int i=0; i<children_.size(); ++i)
00344 delete children_[i];
00345 }
00346
00347 void add(GNAT& gnat, const _T& data)
00348 {
00349 if (children_.size()==0)
00350 {
00351 data_.push_back(data);
00352 gnat.size_++;
00353 if (needToSplit(gnat))
00354 {
00355 if (gnat.removed_.size() > 0)
00356 gnat.rebuildDataStructure();
00357 else
00358 split(gnat);
00359 }
00360 }
00361 else
00362 {
00363 std::vector<double> dist(children_.size());
00364 double minDist = std::numeric_limits<double>::infinity();
00365 int minInd = -1;
00366
00367 for (unsigned int i=0; i<children_.size(); ++i)
00368 if ((dist[i] = gnat.distFun_(data, children_[i]->pivot_)) < minDist)
00369 {
00370 minDist = dist[i];
00371 minInd = i;
00372 }
00373 for (unsigned int i=0; i<children_.size(); ++i)
00374 {
00375 if (children_[i]->minRange_[minInd] > dist[i])
00376 children_[i]->minRange_[minInd] = dist[i];
00377 if (children_[i]->maxRange_[minInd] < dist[i])
00378 children_[i]->maxRange_[minInd] = dist[i];
00379 }
00380 if (minDist < children_[minInd]->minRadius_)
00381 children_[minInd]->minRadius_ = minDist;
00382 if (minDist > children_[minInd]->maxRadius_)
00383 children_[minInd]->maxRadius_ = minDist;
00384
00385 children_[minInd]->add(gnat, data);
00386 }
00387 }
00388
00389 bool needToSplit(const GNAT& gnat) const
00390 {
00391 unsigned int sz = data_.size();
00392 return sz > gnat.maxNumPtsPerLeaf_ && sz > degree_;
00393 }
00394 void split(GNAT& gnat)
00395 {
00396 std::vector<std::vector<double> > dists;
00397 std::vector<unsigned int> pivots;
00398
00399 children_.reserve(degree_);
00400 gnat.pivotSelector_.kcenters(data_, degree_, pivots, dists);
00401 for(unsigned int i=0; i<pivots.size(); i++)
00402 children_.push_back(new Node(this, degree_, gnat.maxNumPtsPerLeaf_, data_[pivots[i]]));
00403 degree_ = pivots.size();
00404 for (unsigned int j=0; j<data_.size(); ++j)
00405 {
00406 unsigned int k = 0;
00407 for (unsigned int i=1; i<degree_; ++i)
00408 if (dists[j][i] < dists[j][k])
00409 k = i;
00410 Node* child = children_[k];
00411 if (j != pivots[k])
00412 {
00413 child->data_.push_back(data_[j]);
00414 if (dists[j][k] > child->maxRadius_)
00415 child->maxRadius_ = dists[j][k];
00416 if (dists[j][k] < child->minRadius_)
00417 child->minRadius_ = dists[j][k];
00418 }
00419 for (unsigned int i=0; i<degree_; ++i)
00420 {
00421 if (children_[i]->minRange_[k] > dists[j][i])
00422 children_[i]->minRange_[k] = dists[j][i];
00423 if (children_[i]->maxRange_[k] < dists[j][i])
00424 children_[i]->maxRange_[k] = dists[j][i];
00425 }
00426 }
00427
00428 for (unsigned int i=0; i<degree_; ++i)
00429 {
00430
00431 children_[i]->degree_ = std::min(std::max(
00432 degree_ * (unsigned int)(children_[i]->data_.size() / data_.size()),
00433 gnat.minDegree_), gnat.maxDegree_);
00434
00435 if (children_[i]->minRadius_ == std::numeric_limits<double>::infinity())
00436 children_[i]->minRadius_ = children_[i]->maxRadius_ = 0.;
00437 }
00438
00439 std::vector<_T> tmp;
00440 data_.swap(tmp);
00441
00442 for (unsigned int i=0; i<degree_; ++i)
00443 if (children_[i]->needToSplit(gnat))
00444 children_[i]->split(gnat);
00445 }
00446
00447
00448 bool insertNeighborK(NearQueue& nbh, std::size_t k, const _T& data, const _T& key, double dist) const
00449 {
00450 if (nbh.size() < k)
00451 {
00452 nbh.push(std::make_pair(&data, dist));
00453 return true;
00454 }
00455 else if (dist < nbh.top().second ||
00456 (dist < std::numeric_limits<double>::epsilon() && data==key))
00457 {
00458 nbh.pop();
00459 nbh.push(std::make_pair(&data, dist));
00460 return true;
00461 }
00462 return false;
00463 }
00464
00465
00466 void nearestK(const GNAT& gnat, const _T &data, std::size_t k,
00467 NearQueue& nbh, NodeQueue& nodeQueue, bool& isPivot) const
00468 {
00469 for (unsigned int i=0; i<data_.size(); ++i)
00470 if (!gnat.isRemoved(data_[i]))
00471 {
00472 if (insertNeighborK(nbh, k, data_[i], data, gnat.distFun_(data, data_[i])))
00473 isPivot = false;
00474 }
00475 if (children_.size() > 0)
00476 {
00477 double dist;
00478 Node* child;
00479 std::vector<double> distToPivot(children_.size());
00480 std::vector<int> permutation(children_.size());
00481
00482 for (unsigned int i=0; i<permutation.size(); ++i)
00483 permutation[i] = i;
00484 std::random_shuffle(permutation.begin(), permutation.end());
00485
00486 for (unsigned int i=0; i<children_.size(); ++i)
00487 if (permutation[i] >= 0)
00488 {
00489 child = children_[permutation[i]];
00490 distToPivot[permutation[i]] = gnat.distFun_(data, child->pivot_);
00491 if (insertNeighborK(nbh, k, child->pivot_, data, distToPivot[permutation[i]]))
00492 isPivot = true;
00493 if (nbh.size()==k)
00494 {
00495 dist = nbh.top().second;
00496 for (unsigned int j=0; j<children_.size(); ++j)
00497 if (permutation[j] >=0 && i != j &&
00498 (distToPivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
00499 distToPivot[permutation[i]] + dist < child->minRange_[permutation[j]]))
00500 permutation[j] = -1;
00501 }
00502 }
00503
00504 dist = nbh.top().second;
00505 for (unsigned int i=0; i<children_.size(); ++i)
00506 if (permutation[i] >= 0)
00507 {
00508 child = children_[permutation[i]];
00509 if (nbh.size()<k ||
00510 (distToPivot[permutation[i]] - dist <= child->maxRadius_ &&
00511 distToPivot[permutation[i]] + dist >= child->minRadius_))
00512 nodeQueue.push(std::make_pair(child, distToPivot[permutation[i]]));
00513 }
00514 }
00515 }
00516
00517 void insertNeighborR(NearQueue& nbh, double r, const _T& data, double dist) const
00518 {
00519 if (dist <= r)
00520 nbh.push(std::make_pair(&data, dist));
00521 }
00522
00523 void nearestR(const GNAT& gnat, const _T &data, double r, NearQueue& nbh, NodeQueue& nodeQueue) const
00524 {
00525 double dist = r;
00526
00527 for (unsigned int i=0; i<data_.size(); ++i)
00528 if (!gnat.isRemoved(data_[i]))
00529 insertNeighborR(nbh, r, data_[i], gnat.distFun_(data, data_[i]));
00530 if (children_.size() > 0)
00531 {
00532 Node* child;
00533 std::vector<double> distToPivot(children_.size());
00534 std::vector<int> permutation(children_.size());
00535
00536 for (unsigned int i=0; i<permutation.size(); ++i)
00537 permutation[i] = i;
00538 std::random_shuffle(permutation.begin(), permutation.end());
00539
00540 for (unsigned int i=0; i<children_.size(); ++i)
00541 if (permutation[i] >= 0)
00542 {
00543 child = children_[permutation[i]];
00544 distToPivot[i] = gnat.distFun_(data, child->pivot_);
00545 insertNeighborR(nbh, r, child->pivot_, distToPivot[i]);
00546 for (unsigned int j=0; j<children_.size(); ++j)
00547 if (permutation[j] >=0 && i != j &&
00548 (distToPivot[i] - dist > child->maxRange_[permutation[j]] ||
00549 distToPivot[i] + dist < child->minRange_[permutation[j]]))
00550 permutation[j] = -1;
00551 }
00552
00553 for (unsigned int i=0; i<children_.size(); ++i)
00554 if (permutation[i] >= 0)
00555 {
00556 child = children_[permutation[i]];
00557 if (distToPivot[i] - dist <= child->maxRadius_ &&
00558 distToPivot[i] + dist >= child->minRadius_)
00559 nodeQueue.push(std::make_pair(child, distToPivot[i]));
00560 }
00561 }
00562 }
00563
00564 void list(const GNAT& gnat, std::vector<_T> &data) const
00565 {
00566 if (!gnat.isRemoved(pivot_))
00567 data.push_back(pivot_);
00568 for (unsigned int i=0; i<data_.size(); ++i)
00569 if(!gnat.isRemoved(data_[i]))
00570 data.push_back(data_[i]);
00571 for (unsigned int i=0; i<children_.size(); ++i)
00572 children_[i]->list(gnat, data);
00573 }
00574
00575 friend std::ostream& operator<<(std::ostream& out, const Node& node)
00576 {
00577 out << "\ndegree:\t" << node.degree_;
00578 out << "\nminRadius:\t" << node.minRadius_;
00579 out << "\nmaxRadius:\t" << node.maxRadius_;
00580 out << "\nminRange:\t";
00581 for (unsigned int i=0; i<node.minRange_.size(); ++i)
00582 out << node.minRange_[i] << '\t';
00583 out << "\nmaxRange: ";
00584 for (unsigned int i=0; i<node.maxRange_.size(); ++i)
00585 out << node.maxRange_[i] << '\t';
00586 out << "\npivot:\t" << node.pivot_;
00587 out << "\ndata: ";
00588 for (unsigned int i=0; i<node.data_.size(); ++i)
00589 out << node.data_[i] << '\t';
00590 out << "\nthis:\t" << &node;
00591 out << "\nchildren:\n";
00592 for (unsigned int i=0; i<node.children_.size(); ++i)
00593 out << node.children_[i] << '\t';
00594 out << '\n';
00595 for (unsigned int i=0; i<node.children_.size(); ++i)
00596 out << *node.children_[i] << '\n';
00597 return out;
00598 }
00599
00600 unsigned int degree_;
00601 const _T pivot_;
00602 double minRadius_;
00603 double maxRadius_;
00604 std::vector<double> minRange_;
00605 std::vector<double> maxRange_;
00606 std::vector<_T> data_;
00607 std::vector<Node*> children_;
00608 };
00609
00610
00612 Node *tree_;
00613
00614 unsigned int degree_;
00615 unsigned int minDegree_;
00616 unsigned int maxDegree_;
00617 unsigned int maxNumPtsPerLeaf_;
00618 std::size_t size_;
00619 std::size_t removedCacheSize_;
00620
00622 GreedyKCenters<_T> pivotSelector_;
00623
00625 boost::unordered_set<const _T*> removed_;
00626 };
00627
00628 }
00629
00630 #endif