All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator
BallTreeRRTstar.cpp
00001 /*********************************************************************
00002 * Software License Agreement (BSD License)
00003 *
00004 *  Copyright (c) 2011, Rice University
00005 *  All rights reserved.
00006 *
00007 *  Redistribution and use in source and binary forms, with or without
00008 *  modification, are permitted provided that the following conditions
00009 *  are met:
00010 *
00011 *   * Redistributions of source code must retain the above copyright
00012 *     notice, this list of conditions and the following disclaimer.
00013 *   * Redistributions in binary form must reproduce the above
00014 *     copyright notice, this list of conditions and the following
00015 *     disclaimer in the documentation and/or other materials provided
00016 *     with the distribution.
00017 *   * Neither the name of the Rice University nor the names of its
00018 *     contributors may be used to endorse or promote products derived
00019 *     from this software without specific prior written permission.
00020 *
00021 *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00022 *  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00023 *  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00024 *  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00025 *  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00026 *  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00027 *  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00028 *  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00029 *  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00030 *  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00031 *  ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00032 *  POSSIBILITY OF SUCH DAMAGE.
00033 *********************************************************************/
00034 
00035 /* Authors: Alejandro Perez, Sertac Karaman, Ioan Sucan */
00036 
#include "ompl/contrib/rrt_star/BallTreeRRTstar.h"
#include "ompl/base/GoalSampleableRegion.h"
#include "ompl/datastructures/NearestNeighborsSqrtApprox.h"
#include "ompl/tools/config/SelfConfig.h"
#include <algorithm>
#include <limits>
#include <map>
#include <utility>
#include <vector>
00044 
00045 void ompl::geometric::BallTreeRRTstar::setup(void)
00046 {
00047     Planner::setup();
00048     SelfConfig sc(si_, getName());
00049     sc.configurePlannerRange(maxDistance_);
00050 
00051     ballRadiusMax_ = si_->getMaximumExtent();
00052     ballRadiusConst_ = maxDistance_ * sqrt(si_->getStateSpace()->getDimension());
00053 
00054     delayCC_ = true;
00055 
00056     if (!nn_)
00057         nn_.reset(new NearestNeighborsSqrtApprox<Motion*>());
00058     nn_->setDistanceFunction(boost::bind(&BallTreeRRTstar::distanceFunction, this, _1, _2));
00059 }
00060 
00061 void ompl::geometric::BallTreeRRTstar::clear(void)
00062 {
00063     Planner::clear();
00064     sampler_.reset();
00065     motions_.clear();
00066     freeMemory();
00067     if (nn_)
00068         nn_->clear();
00069 }
00070 
00071 bool ompl::geometric::BallTreeRRTstar::solve(const base::PlannerTerminationCondition &ptc)
00072 {
00073     checkValidity();
00074     base::Goal                 *goal   = pdef_->getGoal().get();
00075     base::GoalSampleableRegion *goal_s = dynamic_cast<base::GoalSampleableRegion*>(goal);
00076 
00077     if (!goal)
00078     {
00079         msg_.error("Goal undefined");
00080         return false;
00081     }
00082 
00083     while (const base::State *st = pis_.nextStart())
00084     {
00085         Motion *motion = new Motion(si_, rO_);
00086         si_->copyState(motion->state, st);
00087         addMotion(motion);
00088     }
00089 
00090     if (nn_->size() == 0)
00091     {
00092         msg_.error("There are no valid initial states!");
00093         return false;
00094     }
00095 
00096     if (!sampler_)
00097         sampler_ = si_->allocStateSampler();
00098 
00099     msg_.inform("Starting with %u states", nn_->size());
00100 
00101     Motion *solution     = NULL;
00102     Motion *approxsol    = NULL;
00103     double  approxdif    = std::numeric_limits<double>::infinity();
00104     bool    approxsolved = false;
00105 
00106     Motion *rmotion   = new Motion(si_, rO_);
00107     Motion *toTrim    = NULL;
00108     base::State *rstate = rmotion->state;
00109     base::State *xstate = si_->allocState();
00110     base::State *tstate = si_->allocState();
00111     std::vector<Motion*> solCheck;
00112     std::vector<Motion*> nbh;
00113     std::vector<double>  dists;
00114     std::vector<int>     valid;
00115     long unsigned int    rewireTest = 0;
00116 
00117     std::pair<base::State*,double> lastValid(tstate, 0.0);
00118 
00119     while (ptc() == false)
00120     {
00121         bool rejected = false;
00122 
00123         /* sample until a state not within any of the existing volumes is found */
00124         do
00125         {
00126             /* sample random state (with goal biasing) */
00127             if (goal_s && rng_.uniform01() < goalBias_ && goal_s->canSample())
00128                 goal_s->sampleGoal(rstate);
00129             else
00130                 sampler_->sampleUniform(rstate);
00131 
00132             /* check to see if it is inside an existing volume */
00133             if (inVolume(rstate))
00134             {
00135                 rejected = true;
00136 
00137                 /* see if the state is valid */
00138                 if(!si_->isValid(rstate))
00139                 {
00140                     /* if it's not, reduce the size of the nearest volume to the distance
00141                        between its center and the rejected state */
00142                     toTrim = nn_->nearest(rmotion);
00143                     double newRad = si_->distance(toTrim->state, rstate);
00144                     if (newRad < toTrim->volRadius)
00145                         toTrim->volRadius = newRad;
00146                 }
00147 
00148             }
00149             else
00150 
00151                 rejected = false;
00152 
00153         }
00154         while (rejected);
00155 
00156         /* find closest state in the tree */
00157         Motion *nmotion = nn_->nearest(rmotion);
00158 
00159         base::State *dstate = rstate;
00160 
00161         /* find state to add */
00162         double d = si_->distance(nmotion->state, rstate);
00163         if (d > maxDistance_)
00164         {
00165             si_->getStateSpace()->interpolate(nmotion->state, rstate, maxDistance_ / d, xstate);
00166             dstate = xstate;
00167         }
00168 
00169         if (si_->checkMotion(nmotion->state, dstate, lastValid))
00170         {
00171             /* create a motion */
00172             double distN = si_->distance(dstate, nmotion->state);
00173             Motion *motion = new Motion(si_, rO_);
00174             si_->copyState(motion->state, dstate);
00175             motion->parent = nmotion;
00176             motion->cost = nmotion->cost + distN;
00177 
00178             /* find nearby neighbors */
00179             double r = std::min(ballRadiusConst_ * (sqrt(log((double)(1 + nn_->size())) / ((double)(nn_->size())))),
00180                                 ballRadiusMax_);
00181 
00182             nn_->nearestR(motion, r, nbh);
00183             rewireTest += nbh.size();
00184 
00185             // cache for distance computations
00186             dists.resize(nbh.size());
00187             // cache for motion validity
00188             valid.resize(nbh.size());
00189             std::fill(valid.begin(), valid.end(), 0);
00190 
00191             if (delayCC_)
00192             {
00193                 // calculate all costs and distances
00194                 for (unsigned int i = 0 ; i < nbh.size() ; ++i)
00195                     if (nbh[i] != nmotion)
00196                     {
00197                         double c = nbh[i]->cost + si_->distance(nbh[i]->state, dstate);
00198                         nbh[i]->cost = c;
00199                     }
00200 
00201                 // sort the nodes
00202                 std::sort(nbh.begin(), nbh.end(), compareMotion);
00203 
00204                 for (unsigned int i = 0 ; i < nbh.size() ; ++i)
00205                     if (nbh[i] != nmotion)
00206                     {
00207                         dists[i] = si_->distance(nbh[i]->state, dstate);
00208                         nbh[i]->cost -= dists[i];
00209                     }
00210                 // collision check until a valid motion is found
00211                 for (unsigned int i = 0 ; i < nbh.size() ; ++i)
00212                     if (nbh[i] != nmotion)
00213                     {
00214 
00215                         dists[i] = si_->distance(nbh[i]->state, dstate);
00216                         double c = nbh[i]->cost + dists[i];
00217                         if (c < motion->cost)
00218                         {
00219                             if (si_->checkMotion(nbh[i]->state, dstate, lastValid))
00220                             {
00221                                 motion->cost = c;
00222                                 motion->parent = nbh[i];
00223                                 valid[i] = 1;
00224                                 break;
00225                             }
00226                             else
00227                             {
00228                                 valid[i] = -1;
00229                                 /* if a collision is found, trim radius to distance from motion to last valid state */
00230                                 double nR = si_->distance(nbh[i]->state, lastValid.first);
00231                                 if (nR < nbh[i]->volRadius)
00232                                     nbh[i]->volRadius = nR;
00233                             }
00234                         }
00235                     }
00236                     else
00237                     {
00238                         valid[i] = 1;
00239                         dists[i] = distN;
00240                         break;
00241                     }
00242 
00243             }
00244             else{
00245                 /* find which one we connect the new state to*/
00246                 for (unsigned int i = 0 ; i < nbh.size() ; ++i)
00247                     if (nbh[i] != nmotion)
00248                     {
00249 
00250                         dists[i] = si_->distance(nbh[i]->state, dstate);
00251                         double c = nbh[i]->cost + dists[i];
00252                         if (c < motion->cost)
00253                         {
00254                             if (si_->checkMotion(nbh[i]->state, dstate, lastValid))
00255                             {
00256                                 motion->cost = c;
00257                                 motion->parent = nbh[i];
00258                                 valid[i] = 1;
00259                             }
00260                             else
00261                             {
00262                                 valid[i] = -1;
00263                                 /* if a collision is found, trim radius to distance from motion to last valid state */
00264                                 double newR = si_->distance(nbh[i]->state, lastValid.first);
00265                                 if (newR < nbh[i]->volRadius)
00266                                     nbh[i]->volRadius = newR;
00267 
00268                             }
00269                         }
00270                     }
00271                     else
00272                     {
00273                         valid[i] = 1;
00274                         dists[i] = distN;
00275                     }
00276             }
00277 
00278             /* add motion to tree */
00279             addMotion(motion);
00280 
00281             solCheck.resize(1);
00282             solCheck[0] = motion;
00283 
00284             /* rewire tree if needed */
00285             for (unsigned int i = 0 ; i < nbh.size() ; ++i)
00286                 if (nbh[i] != motion->parent)
00287                 {
00288                     double c = motion->cost + dists[i];
00289                     if (c < nbh[i]->cost)
00290                     {
00291                         bool v = false;
00292                         if (valid[i] == 0)
00293                         {
00294                             if(!si_->checkMotion(nbh[i]->state, dstate, lastValid))
00295                             {
00296                                 /* if a collision is found, trim radius to distance from motion to last valid state */
00297                                 double R =  si_->distance(nbh[i]->state, lastValid.first);
00298                                 if (R < nbh[i]->volRadius)
00299                                     nbh[i]->volRadius = R;
00300                             }
00301                             else
00302                             {
00303                                 v = true;
00304                             }
00305                         }
00306                         if (valid[i] == 1)
00307                             v = true;
00308 
00309                         if (v)
00310                         {
00311                             nbh[i]->parent = motion;
00312                             nbh[i]->cost = c;
00313                             solCheck.push_back(nbh[i]);
00314                         }
00315                     }
00316                 }
00317 
00318             /* check if  we found a solution */
00319             for (unsigned int i = 0 ; i < solCheck.size() ; ++i)
00320             {
00321                 double dist = 0.0;
00322                 bool solved = goal->isSatisfied(solCheck[i]->state, &dist);
00323                 bool sufficientlyShort = solved ? goal->isPathLengthSatisfied(solCheck[i]->cost) : false;
00324 
00325                 if (solved)
00326                 {
00327                     if (sufficientlyShort)
00328                     {
00329                         solution = solCheck[i];
00330                         break;
00331                     }
00332                     else
00333                     {
00334                         if (approxsolved)
00335                         {
00336                             if (dist < approxdif)
00337                             {
00338                                 approxdif = dist;
00339                                 approxsol = solCheck[i];
00340                             }
00341                         }
00342                         else
00343                         {
00344                             approxsolved = true;
00345                             approxdif = dist;
00346                             approxsol = solCheck[i];
00347                         }
00348                     }
00349                 }
00350                 else
00351                     if (!approxsolved && dist < approxdif)
00352                     {
00353                         approxdif = dist;
00354                         approxsol = solCheck[i];
00355                     }
00356             }
00357 
00358             /* terminate if a solution was found */
00359             if (solution != NULL)
00360                 break;
00361         }
00362         else
00363         {
00364             /* if a collision is found, trim radius to distance from motion to last valid state */
00365             toTrim = nn_->nearest(nmotion);
00366             double newRadius =  si_->distance(toTrim->state, lastValid.first);
00367             if (newRadius < toTrim->volRadius)
00368                 toTrim->volRadius = newRadius;
00369         }
00370     }
00371 
00372     bool approximate = false;
00373     if (solution == NULL)
00374     {
00375         solution = approxsol;
00376         approximate = true;
00377     }
00378 
00379     if (solution != NULL)
00380     {
00381         /* construct the solution path */
00382         std::vector<Motion*> mpath;
00383         while (solution != NULL)
00384         {
00385             mpath.push_back(solution);
00386             solution = solution->parent;
00387         }
00388 
00389         /* set the solution path */
00390         PathGeometric *path = new PathGeometric(si_);
00391         for (int i = mpath.size() - 1 ; i >= 0 ; --i)
00392             path->states.push_back(si_->cloneState(mpath[i]->state));
00393         goal->setDifference(approxdif);
00394         goal->setSolutionPath(base::PathPtr(path), approximate);
00395 
00396         if (approximate)
00397             msg_.warn("Found approximate solution");
00398     }
00399 
00400     si_->freeState(xstate);
00401     if (rmotion->state)
00402         si_->freeState(rmotion->state);
00403     delete rmotion;
00404 
00405     msg_.inform("Created %u states. Checked %lu rewire options.", nn_->size(), rewireTest);
00406 
00407     return goal->isAchieved();
00408 }
00409 
00410 void ompl::geometric::BallTreeRRTstar::freeMemory(void)
00411 {
00412     if (nn_)
00413     {
00414         std::vector<Motion*> motions;
00415         nn_->list(motions);
00416         for (unsigned int i = 0 ; i < motions.size() ; ++i)
00417         {
00418             if (motions[i]->state)
00419                 si_->freeState(motions[i]->state);
00420             delete motions[i];
00421         }
00422     }
00423 }
00424 
00425 void ompl::geometric::BallTreeRRTstar::getPlannerData(base::PlannerData &data) const
00426 {
00427     Planner::getPlannerData(data);
00428 
00429     std::vector<Motion*> motions;
00430     if (nn_)
00431         nn_->list(motions);
00432 
00433     for (unsigned int i = 0 ; i < motions.size() ; ++i)
00434         data.recordEdge(motions[i]->parent ? motions[i]->parent->state : NULL, motions[i]->state);
00435 }