contrib/mul/clsfy/clsfy_random_forest_builder.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_forest_builder.h
00002 #ifndef clsfy_random_forest_builder_h_
00003 #define clsfy_random_forest_builder_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007 //:
00008 // \file
00009 // \brief Build a random forest classifier
00010 // \author Martin Roberts
00011 
00012 #include <clsfy/clsfy_builder_base.h>
00013 #include <clsfy/clsfy_random_forest.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_set.h>
00016 #include <vcl_string.h>
00017 #include <vcl_iosfwd.h>
00018 #include <vnl/vnl_vector.h>
00019 #include <vnl/vnl_random.h>
00020 
00021 #include <mbl/mbl_data_wrapper.h>
00022 
00023 
00024 //: Builds clsfy_random_forest classifiers
00025 class clsfy_random_forest_builder : public clsfy_builder_base
00026 {
00027   public:
00028     //: Default constructor
00029     clsfy_random_forest_builder();
00030     //: Construct with ntrees trees; max_depth and min_node_size default to -1 (no limit applied)
00031     clsfy_random_forest_builder(unsigned ntrees,
00032                                 int max_depth=-1,int min_node_size=-1);
00033     virtual ~clsfy_random_forest_builder();
00034 
00035     //: Create empty model
00036     // Caller is responsible for deletion
00037     virtual clsfy_classifier_base* new_classifier() const;
00038 
00039     //: Build classifier from data
00040     // Returns the mean error over the training set (see also set_calc_test_error()).
00041     virtual double build(clsfy_classifier_base& classifier,
00042                          mbl_data_wrapper<vnl_vector<double> >& inputs,
00043                          unsigned nClasses,
00044                          const vcl_vector<unsigned> &outputs) const;
00045 
00046     //: Name of the class
00047     virtual vcl_string is_a() const;
00048 
00049     //: Class name test: presumably true if s matches this class's name (confirm in the .cxx)
00050     virtual bool is_class(vcl_string const& s) const;
00051 
00052     //: IO Version number
00053     short version_no() const;
00054 
00055     //: Create a copy on the heap and return base class pointer
00056     virtual clsfy_builder_base* clone() const;
00057 
00058     //: Print class to os
00059     virtual void print_summary(vcl_ostream& os) const;
00060 
00061     //: Save class to binary file stream
00062     virtual void b_write(vsl_b_ostream& bfs) const;
00063 
00064     //: Load class from binary file stream
00065     virtual void b_read(vsl_b_istream& bfs);
00066 
00067     //: The max tree depth (default -1 means no max set)
00068     int max_depth() const {return max_depth_;};
00069 
00070     //: Set the max tree depth.
00071     // If no max is applied, tree growth continues till
00072     // all final leaf nodes are pure (i.e. single class)
00073     // If set negative the value is ignored
00074     void set_max_depth(int max_depth) {max_depth_=max_depth;}
00075     //: Minimum number of points associated with any node (negative means ignored)
00076     int min_node_size() const {return min_node_size_;}
00077 
00078     //: Set minimum number of points associated with any node
00079     // If negative this is ignored, otherwise if a split would produce a child
00080     // node less than this, then the split does not occur and the branch is
00081     // terminated
00082     void set_min_node_size(int min_node_size) {min_node_size_=min_node_size;}
00083 
00084 
00085     //: Set number of trees in forest
00086     // Note this must be set before calling build
00087     // Default is 100
00088     void set_ntrees(unsigned ntrees) {ntrees_=ntrees;}
00089     //: Number of trees in forest
00090     unsigned ntrees() const {return ntrees_;}
00091     //: Seed the sampler (presumably random_sampler_; confirm in the .cxx)
00092     virtual void seed_sampler(unsigned long seed);
00093 
00094     //: Set whether the build calculates a test error over the input training set
00095     // Default is on, but this can be turned off
00096     // e.g. for a parallel build of many partial random forests
00097     // which can later be merged
00098     void set_calc_test_error(bool on) {calc_test_error_=on;}
00099 
00100     //: Save a pointer to storage for out of bag indices
00101     void set_oob_indices( vcl_vector<vcl_vector<unsigned > >* poobIndices)
00102     {poob_indices_=poobIndices;}
00103 
00104   protected:
00105     //: Pick the number of parameters that the tree builder branches on
00106     // Default uses sqrt of ndims
00107     virtual unsigned select_nbranch_params(unsigned ndims) const;
00108 
00109     //: Pick a random data subset (with replacement)
00110     virtual void select_data(vcl_vector<vnl_vector<double> >& inputs,
00111                              const vcl_vector<unsigned> &outputs,
00112                              vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00113                              vcl_vector<unsigned> & bootstrapped_outputs) const;
00114     //: Supply a seed for a tree builder (how it is derived is not visible in this header)
00115     virtual unsigned long get_tree_builder_seed() const;
00116 
00117     //: Number of trees
00118     unsigned ntrees_;
00119     //: The max depth of any child tree
00120     // If negative no max is applied, and all final leaf nodes are pure
00121     // (i.e. single class)
00122     int max_depth_;
00123 
00124 
00125     //: Minimum number of points associated with any node
00126     // If negative this is ignored, otherwise if a split would produce a child
00127     // node less than this, then the split does not occur and the branch is
00128     // terminated
00129     int min_node_size_;
00130 
00131     //: Uniform sampler on 0,1 (for bootstrapping)
00132     mutable  vnl_random random_sampler_; // mutable: can be modified from const member functions
00133 
00134     //: Pointer to storage of point indices for each bootstrapped tree
00135     // Can be used for out of bag estimates
00136     // Saves for tree i the indices of all points used in its training
00137     // Note the storage is supplied from outside this class, as this is a kind of bolt-on
00138     vcl_vector<vcl_vector<unsigned > >* poob_indices_;
00139   private:
00140     //: Does the builder calculate the error on the training set?
00141     bool calc_test_error_;
00142 };
00143 
00144 
00145 #endif // clsfy_random_forest_builder_h_