[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_common.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 
00037 #ifndef VIGRA_RF_COMMON_HXX
00038 #define VIGRA_RF_COMMON_HXX
00039 
00040 namespace vigra
00041 {
00042 
00043 
/** \brief tag class signalling a classification problem
 *  \sa RegressionTag
 */
struct ClassificationTag
{};
00046 
/** \brief tag class signalling a regression problem
 *  \sa ClassificationTag
 */
struct RegressionTag
{};
00049 
00050 namespace detail
00051 {
00052     class RF_DEFAULT;
00053 }
00054 inline detail::RF_DEFAULT& rf_default();
00055 namespace detail
00056 {
00057 
/**\brief singleton default tag class -
 *
 *  signals that an argument should be chosen automatically.
 *  Use the rf_default() factory function to obtain the (single)
 *  instance of this tag.
 *  \sa RandomForest<>::learn();
 */
class RF_DEFAULT
{
    private:
        // not publicly constructible: the only instance lives inside
        // the rf_default() factory function.
        RF_DEFAULT()
        {}
    public:
        // rf_default() is the sole function allowed to construct the tag.
        friend RF_DEFAULT& ::vigra::rf_default();

        /** ok workaround for automatic choice of the decisiontree
         * stackentry.
         */
};
00075 
00076 /**\brief chooses between default type and type supplied
00077  * 
00078  * This is an internal class and you shouldn't really care about it.
00079  * Just pass on used in RandomForest.learn()
00080  * Usage:
00081  *\code
00082  *      // example: use container type supplied by user or ArrayVector if 
00083  *      //          rf_default() was specified as argument;
00084  *      template<class Container_t>
00085  *      void do_some_foo(Container_t in)
00086  *      {
00087  *          typedef ArrayVector<int>    Default_Container_t;
00088  *          Default_Container_t         default_value;
00089  *          Value_Chooser<Container_t,  Default_Container_t> 
00090  *                      choose(in, default_value);
00091  *
00092  *          // if the user didn't care and the in was of type 
00093  *          // RF_DEFAULT then default_value is used.
00094  *          do_some_more_foo(choose.value());
00095  *      }
00096  *      Value_Chooser choose_val<Type, Default_Type>
00097  *\endcode
00098  */
00099 template<class T, class C>
00100 class Value_Chooser
00101 {
00102 public:
00103     typedef T type;
00104     static T & choose(T & t, C &)
00105     {
00106         return t; 
00107     }
00108 };
00109 
00110 template<class C>
00111 class Value_Chooser<detail::RF_DEFAULT, C>
00112 {
00113 public:
00114     typedef C type;
00115     
00116     static C & choose(detail::RF_DEFAULT &, C & c)
00117     {
00118         return c; 
00119     }
00120 };
00121 
00122 
00123 
00124 
00125 } //namespace detail
00126 
00127 
00128 /**\brief factory function to return a RF_DEFAULT tag
00129  * \sa RandomForest<>::learn()
00130  */
00131 detail::RF_DEFAULT& rf_default()
00132 {
00133     static detail::RF_DEFAULT result;
00134     return result;
00135 }
00136 
/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag   { RF_EQUAL,        //!< stratification: equal number of samples per class
                      RF_PROPORTIONAL, //!< stratification: proportional to class frequency
                      RF_EXTERNAL,     //!< strata weights are supplied externally
                      RF_NONE,         //!< no stratification
                      RF_FUNCTION,     //!< value is computed by a user-supplied callback
                      RF_LOG,          //!< derive mtry via a logarithmic rule (see features_per_node())
                      RF_SQRT,         //!< derive mtry via a square-root rule (see features_per_node())
                      RF_CONST,        //!< use a fixed, user-specified value
                      RF_ALL};         //!< use all features
00149 
00150 
00151 /** \addtogroup MachineLearning 
00152 **/
00153 //@{
00154 
00155 /**\brief Options object for the random forest
00156  *
00157  * usage:
00158  * RandomForestOptions a =  RandomForestOptions()
00159  *                              .param1(value1)
00160  *                              .param2(value2)
00161  *                              ...
00162  *
00163  * This class only contains options/parameters that are not problem
00164  * dependent. The ProblemSpec class contains methods to set class weights
00165  * if necessary.
00166  *
00167  * Note that the return value of all methods is *this which makes
00168  * concatenating of options as above possible.
00169  */
00170 class RandomForestOptions
00171 {
00172   public:
00173     /**\name sampling options*/
00174     /*\{*/
00175     // look at the member access functions for documentation
00176     double  training_set_proportion_;
00177     int     training_set_size_;
00178     int (*training_set_func_)(int);
00179     RF_OptionTag
00180         training_set_calc_switch_;
00181 
00182     bool    sample_with_replacement_;
00183     RF_OptionTag
00184             stratification_method_;
00185 
00186 
00187     /**\name general random forest options
00188      *
00189      * these usually will be used by most split functors and
00190      * stopping predicates
00191      */
00192     /*\{*/
00193     RF_OptionTag    mtry_switch_;
00194     int     mtry_;
00195     int (*mtry_func_)(int) ;
00196 
00197     bool predict_weighted_; 
00198     int tree_count_;
00199     int min_split_node_size_;
00200     bool prepare_online_learning_;
00201     /*\}*/
00202 
00203     int serialized_size() const
00204     {
00205         return 12;
00206     }
00207     
00208 
00209     bool operator==(RandomForestOptions & rhs) const
00210     {
00211         bool result = true;
00212         #define COMPARE(field) result = result && (this->field == rhs.field); 
00213         COMPARE(training_set_proportion_);
00214         COMPARE(training_set_size_);
00215         COMPARE(training_set_calc_switch_);
00216         COMPARE(sample_with_replacement_);
00217         COMPARE(stratification_method_);
00218         COMPARE(mtry_switch_);
00219         COMPARE(mtry_);
00220         COMPARE(tree_count_);
00221         COMPARE(min_split_node_size_);
00222         COMPARE(predict_weighted_);
00223         #undef COMPARE
00224 
00225         return result;
00226     }
00227     bool operator!=(RandomForestOptions & rhs_) const
00228     {
00229         return !(*this == rhs_);
00230     }
00231     template<class Iter>
00232     void unserialize(Iter const & begin, Iter const & end)
00233     {
00234         Iter iter = begin;
00235         vigra_precondition(static_cast<int>(end - begin) == serialized_size(), 
00236                            "RandomForestOptions::unserialize():"
00237                            "wrong number of parameters");
00238         #define PULL(item_, type_) item_ = type_(*iter); ++iter;
00239         PULL(training_set_proportion_, double);
00240         PULL(training_set_size_, int);
00241         ++iter; //PULL(training_set_func_, double);
00242         PULL(training_set_calc_switch_, (RF_OptionTag)int);
00243         PULL(sample_with_replacement_, 0 != );
00244         PULL(stratification_method_, (RF_OptionTag)int);
00245         PULL(mtry_switch_, (RF_OptionTag)int);
00246         PULL(mtry_, int);
00247         ++iter; //PULL(mtry_func_, double);
00248         PULL(tree_count_, int);
00249         PULL(min_split_node_size_, int);
00250         PULL(predict_weighted_, 0 !=);
00251         #undef PULL
00252     }
00253     template<class Iter>
00254     void serialize(Iter const &  begin, Iter const & end) const
00255     {
00256         Iter iter = begin;
00257         vigra_precondition(static_cast<int>(end - begin) == serialized_size(), 
00258                            "RandomForestOptions::serialize():"
00259                            "wrong number of parameters");
00260         #define PUSH(item_) *iter = double(item_); ++iter;
00261         PUSH(training_set_proportion_);
00262         PUSH(training_set_size_);
00263         if(training_set_func_ != 0)
00264         {
00265             PUSH(1);
00266         }
00267         else
00268         {
00269             PUSH(0);
00270         }
00271         PUSH(training_set_calc_switch_);
00272         PUSH(sample_with_replacement_);
00273         PUSH(stratification_method_);
00274         PUSH(mtry_switch_);
00275         PUSH(mtry_);
00276         if(mtry_func_ != 0)
00277         {
00278             PUSH(1);
00279         }
00280         else
00281         {
00282             PUSH(0);
00283         }
00284         PUSH(tree_count_);
00285         PUSH(min_split_node_size_);
00286         PUSH(predict_weighted_);
00287         #undef PUSH
00288     }
00289     
00290     void make_from_map(std::map<std::string, ArrayVector<double> > & in)
00291     {
00292         typedef MultiArrayShape<2>::type Shp; 
00293         #define PULL(item_, type_) item_ = type_(in[#item_][0]); 
00294         #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0); 
00295         PULL(training_set_proportion_,double);
00296         PULL(training_set_size_, int);
00297         PULL(mtry_, int);
00298         PULL(tree_count_, int);
00299         PULL(min_split_node_size_, int);
00300         PULLBOOL(sample_with_replacement_, bool);
00301         PULLBOOL(prepare_online_learning_, bool);
00302         PULLBOOL(predict_weighted_, bool);
00303         
00304         PULL(training_set_calc_switch_, (RF_OptionTag)int);
00305         PULL(stratification_method_, (RF_OptionTag)int);
00306         PULL(mtry_switch_, (RF_OptionTag)int);
00307         
00308         /*don't pull*/
00309         //PULL(mtry_func_!=0, int);
00310         //PULL(training_set_func,int);
00311         #undef PULL
00312         #undef PULLBOOL
00313     }
00314     void make_map(std::map<std::string, ArrayVector<double> > & in) const
00315     {
00316         typedef MultiArrayShape<2>::type Shp; 
00317         #define PUSH(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_)); 
00318         #define PUSHFUNC(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_!=0)); 
00319         PUSH(training_set_proportion_,double);
00320         PUSH(training_set_size_, int);
00321         PUSH(mtry_, int);
00322         PUSH(tree_count_, int);
00323         PUSH(min_split_node_size_, int);
00324         PUSH(sample_with_replacement_, bool);
00325         PUSH(prepare_online_learning_, bool);
00326         PUSH(predict_weighted_, bool);
00327         
00328         PUSH(training_set_calc_switch_, RF_OptionTag);
00329         PUSH(stratification_method_, RF_OptionTag);
00330         PUSH(mtry_switch_, RF_OptionTag);
00331         
00332         PUSHFUNC(mtry_func_, int);
00333         PUSHFUNC(training_set_func_,int);
00334         #undef PUSH
00335         #undef PUSHFUNC
00336     }
00337 
00338 
00339     /**\brief create a RandomForestOptions object with default initialisation.
00340      *
00341      * look at the other member functions for more information on default
00342      * values
00343      */
00344     RandomForestOptions()
00345     :
00346         training_set_proportion_(1.0),
00347         training_set_size_(0),
00348         training_set_func_(0),
00349         training_set_calc_switch_(RF_PROPORTIONAL),
00350         sample_with_replacement_(true),
00351         stratification_method_(RF_NONE),
00352         mtry_switch_(RF_SQRT),
00353         mtry_(0),
00354         mtry_func_(0),
00355         predict_weighted_(false),
00356         tree_count_(256),
00357         min_split_node_size_(1),
00358         prepare_online_learning_(false)
00359     {}
00360 
00361     /**\brief specify stratification strategy
00362      *
00363      * default: RF_NONE
00364      * possible values: RF_EQUAL, RF_PROPORTIONAL,
00365      *                  RF_EXTERNAL, RF_NONE
00366      * RF_EQUAL:        get equal amount of samples per class.
00367      * RF_PROPORTIONAL: sample proportional to fraction of class samples
00368      *                  in population
00369      * RF_EXTERNAL:     strata_weights_ field of the ProblemSpec_t object
00370      *                  has been set externally. (defunct)
00371      */
00372     RandomForestOptions & use_stratification(RF_OptionTag in)
00373     {
00374         vigra_precondition(in == RF_EQUAL ||
00375                            in == RF_PROPORTIONAL ||
00376                            in == RF_EXTERNAL ||
00377                            in == RF_NONE,
00378                            "RandomForestOptions::use_stratification()"
00379                            "input must be RF_EQUAL, RF_PROPORTIONAL,"
00380                            "RF_EXTERNAL or RF_NONE");
00381         stratification_method_ = in;
00382         return *this;
00383     }
00384 
00385     RandomForestOptions & prepare_online_learning(bool in)
00386     {
00387         prepare_online_learning_=in;
00388         return *this;
00389     }
00390 
00391     /**\brief sample from training population with or without replacement?
00392      *
00393      * <br> Default: true
00394      */
00395     RandomForestOptions & sample_with_replacement(bool in)
00396     {
00397         sample_with_replacement_ = in;
00398         return *this;
00399     }
00400 
00401     /**\brief  specify the fraction of the total number of samples 
00402      * used per tree for learning. 
00403      *
00404      * This value should be in [0.0 1.0] if sampling without
00405      * replacement has been specified.
00406      *
00407      * <br> default : 1.0
00408      */
00409     RandomForestOptions & samples_per_tree(double in)
00410     {
00411         training_set_proportion_ = in;
00412         training_set_calc_switch_ = RF_PROPORTIONAL;
00413         return *this;
00414     }
00415 
00416     /**\brief directly specify the number of samples per tree
00417      */
00418     RandomForestOptions & samples_per_tree(int in)
00419     {
00420         training_set_size_ = in;
00421         training_set_calc_switch_ = RF_CONST;
00422         return *this;
00423     }
00424 
00425     /**\brief use external function to calculate the number of samples each
00426      *        tree should be learnt with.
00427      *
00428      * \param in function pointer that takes the number of rows in the
00429      *           learning data and outputs the number samples per tree.
00430      */
00431     RandomForestOptions & samples_per_tree(int (*in)(int))
00432     {
00433         training_set_func_ = in;
00434         training_set_calc_switch_ = RF_FUNCTION;
00435         return *this;
00436     }
00437     
00438     /**\brief weight each tree with number of samples in that node
00439      */
00440     RandomForestOptions & predict_weighted()
00441     {
00442         predict_weighted_ = true;
00443         return *this;
00444     }
00445 
00446     /**\brief use built in mapping to calculate mtry
00447      *
00448      * Use one of the built in mappings to calculate mtry from the number
00449      * of columns in the input feature data.
00450      * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
00451      *           <br> default: RF_SQRT.
00452      */
00453     RandomForestOptions & features_per_node(RF_OptionTag in)
00454     {
00455         vigra_precondition(in == RF_LOG ||
00456                            in == RF_SQRT||
00457                            in == RF_ALL,
00458                            "RandomForestOptions()::features_per_node():"
00459                            "input must be of type RF_LOG or RF_SQRT");
00460         mtry_switch_ = in;
00461         return *this;
00462     }
00463 
00464     /**\brief Set mtry to a constant value
00465      *
00466      * mtry is the number of columns/variates/variables randomly choosen
00467      * to select the best split from.
00468      *
00469      */
00470     RandomForestOptions & features_per_node(int in)
00471     {
00472         mtry_ = in;
00473         mtry_switch_ = RF_CONST;
00474         return *this;
00475     }
00476 
00477     /**\brief use a external function to calculate mtry
00478      *
00479      * \param in function pointer that takes int (number of columns
00480      *           of the and outputs int (mtry)
00481      */
00482     RandomForestOptions & features_per_node(int(*in)(int))
00483     {
00484         mtry_func_ = in;
00485         mtry_switch_ = RF_FUNCTION;
00486         return *this;
00487     }
00488 
00489     /** How many trees to create?
00490      *
00491      * <br> Default: 255.
00492      */
00493     RandomForestOptions & tree_count(int in)
00494     {
00495         tree_count_ = in;
00496         return *this;
00497     }
00498 
00499     /**\brief Number of examples required for a node to be split.
00500      *
00501      *  When the number of examples in a node is below this number,
00502      *  the node is not split even if class separation is not yet perfect.
00503      *  Instead, the node returns the proportion of each class
00504      *  (among the remaining examples) during the prediction phase.
00505      *  <br> Default: 1 (complete growing)
00506      */
00507     RandomForestOptions & min_split_node_size(int in)
00508     {
00509         min_split_node_size_ = in;
00510         return *this;
00511     }
00512 };
00513 
00514 
/** \brief problem types 
 *
 * REGRESSION     - continuous labels
 * CLASSIFICATION - discrete class labels
 * CHECKLATER     - not yet specified; determined prior to learning
 *                  (default in ProblemSpec)
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
00518 
00519 
00520 /** \brief problem specification class for the random forest.
00521  *
00522  * This class contains all the problem specific parameters the random
00523  * forest needs for learning. Specification of an instance of this class
00524  * is optional as all necessary fields will be computed prior to learning
00525  * if not specified.
00526  *
00527  * if needed usage is similar to that of RandomForestOptions
00528  */
00529 
00530 template<class LabelType = double>
00531 class ProblemSpec
00532 {
00533 
00534 
00535 public:
00536 
00537     /** \brief  problem class
00538      */
00539 
00540     typedef LabelType       Label_t;
00541     ArrayVector<Label_t>    classes;
00542 
00543     int                     column_count_;    // number of features
00544     int                     class_count_;     // number of classes
00545     int                     row_count_;       // number of samples
00546 
00547     int                     actual_mtry_;     // mtry used in training
00548     int                     actual_msample_;  // number if in-bag samples per tree
00549 
00550     Problem_t               problem_type_;    // classification or regression
00551     
00552     int used_;                                // this ProblemSpec is valid
00553     ArrayVector<double>     class_weights_;   // if classes have different importance
00554     int                     is_weighted_;     // class_weights_ are used
00555     double                  precision_;       // termination criterion for regression loss
00556     
00557         
00558     template<class T> 
00559     void to_classlabel(int index, T & out) const
00560     {
00561         out = T(classes[index]);
00562     }
00563     template<class T> 
00564     int to_classIndex(T index) const
00565     {
00566         return std::find(classes.begin(), classes.end(), index) - classes.begin();
00567     }
00568 
00569     #define EQUALS(field) field(rhs.field)
00570     ProblemSpec(ProblemSpec const & rhs)
00571     : 
00572         EQUALS(column_count_),
00573         EQUALS(class_count_),
00574         EQUALS(row_count_),
00575         EQUALS(actual_mtry_),
00576         EQUALS(actual_msample_),
00577         EQUALS(problem_type_),
00578         EQUALS(used_),
00579         EQUALS(class_weights_),
00580         EQUALS(is_weighted_),
00581         EQUALS(precision_)
00582     {
00583         std::back_insert_iterator<ArrayVector<Label_t> >
00584                         iter(classes);
00585         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00586     }
00587     #undef EQUALS
00588     #define EQUALS(field) field(rhs.field)
00589     template<class T>
00590     ProblemSpec(ProblemSpec<T> const & rhs)
00591     : 
00592         EQUALS(column_count_),
00593         EQUALS(class_count_),
00594         EQUALS(row_count_),
00595         EQUALS(actual_mtry_),
00596         EQUALS(actual_msample_),
00597         EQUALS(problem_type_),
00598         EQUALS(used_),
00599         EQUALS(class_weights_),
00600         EQUALS(is_weighted_),
00601         EQUALS(precision_)
00602     {
00603         std::back_insert_iterator<ArrayVector<Label_t> >
00604                         iter(classes);
00605         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00606     }
00607     #undef EQUALS
00608 
00609     // for some reason the function below does not match
00610     // the default copy constructor
00611     #define EQUALS(field) (this->field = rhs.field);
00612     ProblemSpec & operator=(ProblemSpec const & rhs)
00613     {
00614         EQUALS(column_count_);
00615         EQUALS(class_count_);
00616         EQUALS(row_count_);
00617         EQUALS(actual_mtry_);
00618         EQUALS(actual_msample_);
00619         EQUALS(problem_type_);
00620         EQUALS(used_);
00621         EQUALS(is_weighted_);
00622         EQUALS(precision_);
00623         class_weights_.clear();
00624         std::back_insert_iterator<ArrayVector<double> >
00625                         iter2(class_weights_);
00626         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 
00627         classes.clear();
00628         std::back_insert_iterator<ArrayVector<Label_t> >
00629                         iter(classes);
00630         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00631         return *this;
00632     }
00633 
00634     template<class T>
00635     ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
00636     {
00637         EQUALS(column_count_);
00638         EQUALS(class_count_);
00639         EQUALS(row_count_);
00640         EQUALS(actual_mtry_);
00641         EQUALS(actual_msample_);
00642         EQUALS(problem_type_);
00643         EQUALS(used_);
00644         EQUALS(is_weighted_);
00645         EQUALS(precision_);
00646         class_weights_.clear();
00647         std::back_insert_iterator<ArrayVector<double> >
00648                         iter2(class_weights_);
00649         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 
00650         classes.clear();
00651         std::back_insert_iterator<ArrayVector<Label_t> >
00652                         iter(classes);
00653         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00654         return *this;
00655     }
00656     #undef EQUALS
00657 
00658     template<class T>
00659     bool operator==(ProblemSpec<T> const & rhs)
00660     {
00661         bool result = true;
00662         #define COMPARE(field) result = result && (this->field == rhs.field);
00663         COMPARE(column_count_);
00664         COMPARE(class_count_);
00665         COMPARE(row_count_);
00666         COMPARE(actual_mtry_);
00667         COMPARE(actual_msample_);
00668         COMPARE(problem_type_);
00669         COMPARE(is_weighted_);
00670         COMPARE(precision_);
00671         COMPARE(used_);
00672         COMPARE(class_weights_);
00673         COMPARE(classes);
00674         #undef COMPARE
00675         return result;
00676     }
00677 
00678     bool operator!=(ProblemSpec & rhs)
00679     {
00680         return !(*this == rhs);
00681     }
00682 
00683 
00684     size_t serialized_size() const
00685     {
00686         return 9 + class_count_ *int(is_weighted_+1);
00687     }
00688 
00689 
00690     template<class Iter>
00691     void unserialize(Iter const & begin, Iter const & end)
00692     {
00693         Iter iter = begin;
00694         vigra_precondition(end - begin >= 9, 
00695                            "ProblemSpec::unserialize():"
00696                            "wrong number of parameters");
00697         #define PULL(item_, type_) item_ = type_(*iter); ++iter;
00698         PULL(column_count_,int);
00699         PULL(class_count_, int);
00700 
00701         vigra_precondition(end - begin >= 9 + class_count_, 
00702                            "ProblemSpec::unserialize(): 1");
00703         PULL(row_count_, int);
00704         PULL(actual_mtry_,int);
00705         PULL(actual_msample_, int);
00706         PULL(problem_type_, Problem_t);
00707         PULL(is_weighted_, int);
00708         PULL(used_, int);
00709         PULL(precision_, double);
00710         if(is_weighted_)
00711         {
00712             vigra_precondition(end - begin == 9 + 2*class_count_, 
00713                                "ProblemSpec::unserialize(): 2");
00714             class_weights_.insert(class_weights_.end(),
00715                                   iter, 
00716                                   iter + class_count_);
00717             iter += class_count_; 
00718         }
00719         classes.insert(classes.end(), iter, end);
00720         #undef PULL
00721     }
00722 
00723 
00724     template<class Iter>
00725     void serialize(Iter const & begin, Iter const & end) const
00726     {
00727         Iter iter = begin;
00728         vigra_precondition(end - begin == serialized_size(), 
00729                            "RandomForestOptions::serialize():"
00730                            "wrong number of parameters");
00731         #define PUSH(item_) *iter = double(item_); ++iter;
00732         PUSH(column_count_);
00733         PUSH(class_count_)
00734         PUSH(row_count_);
00735         PUSH(actual_mtry_);
00736         PUSH(actual_msample_);
00737         PUSH(problem_type_);
00738         PUSH(is_weighted_);
00739         PUSH(used_);
00740         PUSH(precision_);
00741         if(is_weighted_)
00742         {
00743             std::copy(class_weights_.begin(),
00744                       class_weights_.end(),
00745                       iter);
00746             iter += class_count_; 
00747         }
00748         std::copy(classes.begin(),
00749                   classes.end(),
00750                   iter);
00751         #undef PUSH
00752     }
00753 
00754     void make_from_map(std::map<std::string, ArrayVector<double> > & in)
00755     {
00756         typedef MultiArrayShape<2>::type Shp; 
00757         #define PULL(item_, type_) item_ = type_(in[#item_][0]); 
00758         PULL(column_count_,int);
00759         PULL(class_count_, int);
00760         PULL(row_count_, int);
00761         PULL(actual_mtry_,int);
00762         PULL(actual_msample_, int);
00763         PULL(problem_type_, (Problem_t)int);
00764         PULL(is_weighted_, int);
00765         PULL(used_, int);
00766         PULL(precision_, double);
00767         class_weights_ = in["class_weights_"];
00768         #undef PUSH
00769     }
00770     void make_map(std::map<std::string, ArrayVector<double> > & in) const
00771     {
00772         typedef MultiArrayShape<2>::type Shp; 
00773         #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_)); 
00774         PUSH(column_count_);
00775         PUSH(class_count_)
00776         PUSH(row_count_);
00777         PUSH(actual_mtry_);
00778         PUSH(actual_msample_);
00779         PUSH(problem_type_);
00780         PUSH(is_weighted_);
00781         PUSH(used_);
00782         PUSH(precision_);
00783         in["class_weights_"] = class_weights_;
00784         #undef PUSH
00785     }
00786     
00787     /**\brief set default values (-> values not set)
00788      */
00789     ProblemSpec()
00790     :   column_count_(0),
00791         class_count_(0),
00792         row_count_(0),
00793         actual_mtry_(0),
00794         actual_msample_(0),
00795         problem_type_(CHECKLATER),
00796         used_(false),
00797         is_weighted_(false),
00798         precision_(0.0)
00799     {}
00800 
00801 
00802     ProblemSpec & column_count(int in)
00803     {
00804         column_count_ = in;
00805         return *this;
00806     }
00807 
00808     /**\brief supply with class labels -
00809      * 
00810      * the preprocessor will not calculate the labels needed in this case.
00811      */
00812     template<class C_Iter>
00813     ProblemSpec & classes_(C_Iter begin, C_Iter end)
00814     {
00815         int size = end-begin;
00816         for(int k=0; k<size; ++k, ++begin)
00817             classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
00818         class_count_ = size;
00819         return *this;
00820     }
00821 
00822     /** \brief supply with class weights  -
00823      *
00824      * this is the only case where you would really have to 
00825      * create a ProblemSpec object.
00826      */
00827     template<class W_Iter>
00828     ProblemSpec & class_weights(W_Iter begin, W_Iter end)
00829     {
00830         class_weights_.insert(class_weights_.end(), begin, end);
00831         is_weighted_ = true;
00832         return *this;
00833     }
00834 
00835 
00836 
00837     void clear()
00838     {
00839         used_ = false; 
00840         classes.clear();
00841         class_weights_.clear();
00842         column_count_ = 0 ;
00843         class_count_ = 0;
00844         actual_mtry_ = 0;
00845         actual_msample_ = 0;
00846         problem_type_ = CHECKLATER;
00847         is_weighted_ = false;
00848         precision_   = 0.0;
00849 
00850     }
00851 
00852     bool used() const
00853     {
00854         return used_ != 0;
00855     }
00856 };
00857 
00858 
00859 //@}
00860 
00861 
00862 
00863 /**\brief Standard early stopping criterion
00864  *
00865  * Stop if region.size() < min_split_node_size_;
00866  */
00867 class EarlyStoppStd
00868 {
00869     public:
00870     int min_split_node_size_;
00871 
00872     template<class Opt>
00873     EarlyStoppStd(Opt opt)
00874     :   min_split_node_size_(opt.min_split_node_size_)
00875     {}
00876 
00877     template<class T>
00878     void set_external_parameters(ProblemSpec<T>const  &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
00879     {}
00880 
00881     template<class Region>
00882     bool operator()(Region& region)
00883     {
00884         return region.size() < min_split_node_size_;
00885     }
00886 
00887     template<class WeightIter, class T, class C>
00888     bool after_prediction(WeightIter,  int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
00889     {
00890         return false; 
00891     }
00892 };
00893 
00894 
00895 } // namespace vigra
00896 
00897 #endif //VIGRA_RF_COMMON_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (3 Dec 2010)