contrib/mul/clsfy/clsfy_binary_tree_builder.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree_builder.h
00002 #ifndef clsfy_binary_tree_builder_h_
00003 #define clsfy_binary_tree_builder_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007 //:
00008 // \file
00009 // \brief Build a binary tree classifier
00010 // \author Martin Roberts
00011 
00012 #include <clsfy/clsfy_builder_base.h>
00013 #include <clsfy/clsfy_binary_tree.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_set.h>
00016 #include <vcl_string.h>
00017 #include <vcl_iosfwd.h>
00018 #include <mbl/mbl_data_wrapper.h>
00019 #include <vnl/vnl_vector.h>
00020 #include <vnl/vnl_random.h>
00021 
00022 
00023 class clsfy_binary_tree_bnode :public  clsfy_binary_tree_node
00024 {
00025     //Similar to classifiers tree node but the builder also needs
00026     //to keep track of relevant data subsets at each node
00027     vcl_set<unsigned> subIndicesL;
00028     vcl_set<unsigned> subIndicesR;
00029  
00030     
00031   clsfy_binary_tree_bnode(clsfy_binary_tree_node* parent,
00032                           const clsfy_binary_tree_op& op):
00033     clsfy_binary_tree_node(parent,op) {}
00034 
00035     virtual clsfy_binary_tree_node* create_child(const clsfy_binary_tree_op& op);
00036 
00037     //Note the owning classifier removes the tree - beware as once deleted its children
00038     //may be inaccessible for deletion
00039     virtual ~clsfy_binary_tree_bnode();
00040     
00041     friend class clsfy_binary_tree_builder;
00042 };
00043 
00044 
00045 //: Builds clsfy_binary_tree classifiers
00046 // Keep finding the variable split that gives the least min_error for
00047 // a binary threshold. Divide up the dataset by that and keep recursively
00048 // building binary threshold classifiers in a tree structure till either
00049 // Max depth level reached, or a node is pure, or node's data <min_nide_size
00050 
00051 class clsfy_binary_tree_builder : public clsfy_builder_base
00052 {
00053     //: The max depth of any leaf node in the tree
00054     //If negative no max is applied, and all final leaf nodes are pure
00055     //(i.e. single class)
00056     int max_depth_;
00057 
00058     //: Minimum number of points associated with any node
00059     // If negative this is ignored, otherwise if a split would produce a child
00060     // node less than this, then the split does not occur and the branch is
00061     // terminated
00062     int min_node_size_;
00063     
00064     //: Set this for random forest behaviour
00065     //At each split the selection is only from a random subset of this size
00066     //If negative (default) it is ignored and all are used
00067     int nbranch_params_;
00068 
00069     //: Work space for randomising params (NB not thread safe)
00070     mutable vcl_vector<unsigned > base_indices_;
00071 
00072 
00073   public:
00074     // Dflt ctor
00075     clsfy_binary_tree_builder();
00076 
00077     //: Create empty model
00078     // Caller is responsible for deletion
00079     virtual clsfy_classifier_base* new_classifier() const;
00080 
00081     //: Build classifier from data
00082     // return the mean error over the training set.
00083     virtual double build(clsfy_classifier_base& classifier,
00084                          mbl_data_wrapper<vnl_vector<double> >& inputs,
00085                          unsigned nClasses,
00086                          const vcl_vector<unsigned> &outputs) const;
00087 
00088     //: Name of the class
00089     virtual vcl_string is_a() const;
00090 
00091     //: Name of the class
00092     virtual bool is_class(vcl_string const& s) const;
00093 
00094     //: IO Version number
00095     short version_no() const;
00096 
00097     //: Create a copy on the heap and return base class pointer
00098     virtual clsfy_builder_base* clone() const;
00099 
00100     //: Print class to os
00101     virtual void print_summary(vcl_ostream& os) const;
00102 
00103     //: Save class to binary file stream
00104     virtual void b_write(vsl_b_ostream& bfs) const;
00105 
00106     //: Load class from binary file stream
00107     virtual void b_read(vsl_b_istream& bfs);
00108 
00109     //: The max tree depth (default -1 means no max set )
00110     int max_depth() const {return max_depth_;};
00111 
00112     //: Set the number of nearest neighbours to look for.
00113     // If not see default is high value to force continuation till
00114     // all final leaf nodes are pure (i.e. single class)
00115     // If set negative the value is ignored
00116     void set_max_depth(int max_depth) {max_depth_=max_depth;}
00117 
00118     int min_node_size() const {return min_node_size_;}
00119 
00120     //: Set minimum number of points associated with any node
00121     // If negative this is ignored, otherwise if a split would produce a child
00122     // node less than this, then the split does not occur and the branch is
00123     // terminated
00124     void set_min_node_size(int min_node_size) {min_node_size_=min_node_size;}
00125     
00126     //: Set this for random forest behaviour
00127     //At each split the selection is only from a random subset of this size 
00128     //If negative then it is ignored
00129     void set_nbranch_params(int nbranch_params) {nbranch_params_ = nbranch_params;}
00130 
00131     //: set whether the build calculatates a test error over the input training set
00132     //Default is on, but this can be turned off e.g. for a random forest of
00133     //many child trees
00134     void set_calc_test_error(bool on) {calc_test_error_=on;}
00135     
00136     //: Seed the sample used to select branching parameter subsets
00137     void seed_sampler(unsigned long seed);
00138   protected:
00139     //: Randomly select  the ndimsUsed dimensions for current branch
00140     // Return indices of selected parameters
00141     // Best of these is then chosen as the branch
00142     virtual void randomise_parameters(unsigned ndimsUsed,
00143                                       vcl_vector<unsigned  >& param_indices) const;
00144 
00145     mutable  vnl_random random_sampler_;
00146 
00147     
00148   private:
00149     void build_children(
00150         const vcl_vector<vnl_vector<double> >& vin,
00151         const vcl_vector<unsigned>& outputs,
00152         clsfy_binary_tree_bnode* parent,bool left) const;
00153     
00154     void copy_children(clsfy_binary_tree_bnode* pBuilderNode,clsfy_binary_tree_node* pNode) const;
00155 
00156     void set_node_prob(clsfy_binary_tree_node* pNode,
00157                        clsfy_binary_tree_bnode* pBuilderNode) const ;
00158 
00159     void build_a_node(
00160         const vcl_vector<vnl_vector<double> >& vin,
00161         const vcl_vector<unsigned>& outputs,
00162         const vcl_set<unsigned >& subIndices,
00163         clsfy_binary_tree_bnode* pNode) const;
00164     
00165     bool isNodePure(const vcl_set<unsigned >& subIndices,
00166                     const vcl_vector<unsigned>& outputs) const;
00167     
00168     void add_terminator(
00169         const vcl_vector<vnl_vector<double> >& vin,
00170         const vcl_vector<unsigned>& outputs,
00171         clsfy_binary_tree_bnode* parent,
00172         bool left, bool pure) const;
00173 
00174     bool calc_test_error_;
00175     
00176 };
00177 
00178 
00179 #endif // clsfy_binary_tree_builder_h_