contrib/mul/clsfy/clsfy_random_forest_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_forest_builder.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief Implement a random_forest classifier builder
00008 // \author Martin Roberts
00009 
00010 #include "clsfy_random_forest_builder.h"
00011 #include <vxl_config.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_string.h>
00014 #include <vcl_algorithm.h>
00015 #include <vcl_numeric.h>
00016 #include <vcl_iterator.h>
00017 #include <vcl_cassert.h>
00018 #include <vsl/vsl_binary_loader.h>
00019 #include <vnl/vnl_math.h>
00020 #include <mbl/mbl_stl.h>
00021 #include <mbl/mbl_data_array_wrapper.h>
00022 #include <clsfy/clsfy_binary_tree_builder.h>
00023 #include "clsfy_random_forest.h"
00024 
00025 //=======================================================================
00026 
00027 clsfy_random_forest_builder::clsfy_random_forest_builder():
00028     max_depth_(-1),min_node_size_(-1),
00029     ntrees_(100),
00030     calc_test_error_(true),
00031     poob_indices_(0)
00032 {
00033     unsigned long default_seed=123654987;
00034     seed_sampler(default_seed);
00035 }
00036 
00037 clsfy_random_forest_builder::clsfy_random_forest_builder(unsigned ntrees,
00038                                                          int max_depth,
00039                                                          int min_node_size):
00040     max_depth_(max_depth),min_node_size_(min_node_size),
00041     ntrees_(ntrees),
00042     calc_test_error_(true),
00043     poob_indices_(0)
00044 {
00045     unsigned long default_seed=123654987;
00046     seed_sampler(default_seed);
00047 }
00048 
00049 clsfy_random_forest_builder::~clsfy_random_forest_builder()
00050 {
00051 }
00052 //=======================================================================
00053 
00054 short clsfy_random_forest_builder::version_no() const
00055 {
00056     return 1;
00057 }
00058 
00059 //=======================================================================
00060 
00061 vcl_string clsfy_random_forest_builder::is_a() const
00062 {
00063     return vcl_string("clsfy_random_forest_builder");
00064 }
00065 
00066 //=======================================================================
00067 
00068 bool clsfy_random_forest_builder::is_class(vcl_string const& s) const
00069 {
00070     return s == clsfy_random_forest_builder::is_a() || clsfy_builder_base::is_class(s);
00071 }
00072 
00073 //=======================================================================
00074 
00075 clsfy_builder_base* clsfy_random_forest_builder::clone() const
00076 {
00077     return new clsfy_random_forest_builder(*this);
00078 }
00079 
00080 //=======================================================================
00081 
00082 void clsfy_random_forest_builder::print_summary(vcl_ostream& os) const
00083 {
00084     os << "Num trees = "<<ntrees_<<"\tmax_depth = " << max_depth_;
00085 }
00086 
00087 //=======================================================================
00088 
00089 void clsfy_random_forest_builder::b_write(vsl_b_ostream& bfs) const
00090 {
00091     vsl_b_write(bfs, version_no());
00092     vsl_b_write(bfs, ntrees_);
00093     vsl_b_write(bfs, max_depth_);
00094     vsl_b_write(bfs, min_node_size_);
00095     vsl_b_write(bfs,calc_test_error_);
00096     vcl_cerr << "clsfy_random_forest_builder::b_write() NYI\n";
00097 }
00098 
00099 //=======================================================================
00100 
00101 void clsfy_random_forest_builder::b_read(vsl_b_istream& bfs)
00102 {
00103     if (!bfs) return;
00104 
00105     short version;
00106     vsl_b_read(bfs,version);
00107     switch (version)
00108     {
00109         case (1):
00110             vsl_b_read(bfs, ntrees_);
00111             vsl_b_read(bfs, max_depth_);
00112             vsl_b_read(bfs, min_node_size_);
00113             vsl_b_read(bfs,calc_test_error_);
00114             break;
00115         default:
00116             vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_random_forest_builder&)\n"
00117                      << "           Unknown version number "<< version << "\n";
00118             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00119     }
00120 }
00121 
00122 //=======================================================================
00123 
00124 //: Build model from data
00125 // return the mean error over the training set.
00126 // For many classifiers, you may use nClasses==1 to
00127 // indicate a binary classifier
00128 double clsfy_random_forest_builder::build(clsfy_classifier_base& classifier,
00129                                         mbl_data_wrapper<vnl_vector<double> >& inputs,
00130                                         unsigned nClasses,
00131                                         const vcl_vector<unsigned> &outputs) const
00132 {
00133     assert(classifier.is_class("clsfy_random_forest")); // equiv to dynamic_cast<> != 0
00134     assert(inputs.size()==outputs.size());
00135     assert(nClasses=1);
00136     
00137 
00138     clsfy_random_forest &random_forest = static_cast<clsfy_random_forest&>(classifier);
00139     unsigned npoints=inputs.size();
00140     vcl_vector<vnl_vector<double> > vin(npoints);
00141 
00142     inputs.reset();
00143     unsigned i=0;
00144     do
00145     {
00146         vin[i++] = inputs.current();
00147     } while (inputs.next());
00148 
00149     assert(i==inputs.size());
00150 
00151     unsigned ndims=vin[0].size();
00152     int nbranch_params=select_nbranch_params(ndims);
00153 
00154     //Start with all parameter indices
00155     vcl_cout<<"npoints= "<<npoints<<"\tndims= "<<ndims<<vcl_endl;
00156     vcl_vector<unsigned> indices(ndims,0);
00157 
00158     mbl_stl_increments(indices.begin(),indices.end(),0);
00159 
00160     //Clean any old trees
00161     random_forest.prune();
00162 
00163     if(poob_indices_)
00164     {
00165         poob_indices_->clear();
00166         poob_indices_->reserve(ntrees_);
00167     }
00168 
00169     
00170     vcl_vector<vnl_vector<double> > bootstrapped_inputs;
00171     vcl_vector<unsigned  > bootstrapped_outputs;
00172 
00173     for(i=0;i<ntrees_;++i)
00174     {
00175         if(i %10 == 0)
00176             vcl_cout<<"Building tree "<<i<<vcl_endl;
00177 
00178         select_data(vin,outputs,bootstrapped_inputs,bootstrapped_outputs);
00179 
00180         clsfy_binary_tree_builder builder;
00181         builder.set_calc_test_error(false);
00182 
00183         clsfy_classifier_base* pBaseClassifier=builder.new_classifier();
00184         clsfy_binary_tree* pTreeClassifier=dynamic_cast<clsfy_binary_tree*>(pBaseClassifier);
00185         assert(pTreeClassifier);
00186         builder.set_nbranch_params(nbranch_params);
00187 
00188         unsigned long seed=get_tree_builder_seed();
00189 //        vcl_cout<<"The seed is "<<seed<<vcl_endl;
00190         builder.seed_sampler(seed);
00191         
00192         builder.set_max_depth(max_depth_);
00193         builder.set_min_node_size(min_node_size_);
00194         mbl_data_array_wrapper<vnl_vector<double> > bootstrapped_inputs_mbl(bootstrapped_inputs);
00195 
00196         builder.build(*pTreeClassifier,
00197                        bootstrapped_inputs_mbl,
00198                       1,
00199                       bootstrapped_outputs);
00200 
00201         mbl_cloneable_ptr<clsfy_classifier_base> treeClassifier(pTreeClassifier);
00202         random_forest.trees_.push_back(treeClassifier);
00203     }
00204 
00205     if(calc_test_error_)
00206         return clsfy_test_error(classifier, inputs, outputs);
00207     else
00208         return 0.0;
00209 
00210 
00211 }
00212 //=======================================================================
00213 //: Create empty classifier
00214 // Caller is responsible for deletion
00215 clsfy_classifier_base* clsfy_random_forest_builder::new_classifier() const
00216 {
00217     return new clsfy_random_forest();
00218 }
00219 
00220 
00221 
00222 void clsfy_random_forest_builder::select_data(vcl_vector<vnl_vector<double> >& inputs,
00223                                               const vcl_vector<unsigned> &outputs,
00224                                               vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00225                                               vcl_vector<unsigned> & bootstrapped_outputs) const
00226 {
00227     unsigned npoints=inputs.size();
00228     bootstrapped_inputs.resize(npoints);
00229     bootstrapped_outputs.resize(npoints);
00230     unsigned ndims=  inputs.front().size();
00231     double dn=double(npoints);
00232     if(poob_indices_)
00233     {
00234         poob_indices_->push_back(vcl_vector<unsigned>());
00235         poob_indices_->back().reserve(npoints);
00236     }
00237     for(unsigned i=0;i<npoints;++i)
00238     {
00239         bootstrapped_inputs[i].set_size(ndims);
00240         unsigned index=random_sampler_(npoints);
00241         bootstrapped_inputs[i]=inputs[index];
00242         bootstrapped_outputs[i]=outputs[index];
00243         if(poob_indices_)
00244             poob_indices_->back().push_back(index); //store index of point for later OOB estimates
00245     }
00246 }
00247 
00248 unsigned  clsfy_random_forest_builder::select_nbranch_params(unsigned ndims) const
00249 {
00250     unsigned nbranch_params=1;
00251     if(ndims>2)
00252     {
00253         double dnbranch_params=vcl_sqrt(double(ndims)+0.1); //round up if close 
00254         nbranch_params=unsigned (dnbranch_params); //round 
00255     }
00256     return nbranch_params;
00257 }
00258 
00259 void clsfy_random_forest_builder::seed_sampler(unsigned long seed)
00260 {
00261 
00262     random_sampler_.reseed(seed);
00263 }
00264  
00265 unsigned long clsfy_random_forest_builder::get_tree_builder_seed() const
00266 {
00267 
00268     //generate some bytes from the original seeded random number generator
00269     unsigned long N=256;
00270     unsigned nbytes=sizeof(unsigned long);
00271     vcl_vector<vxl_byte> seedAsBytes(nbytes,1);
00272 
00273     for(unsigned ib=0;ib<nbytes;++ib)
00274     {
00275         seedAsBytes[ib]=static_cast<vxl_byte>(random_sampler_(N));
00276     }
00277 
00278     unsigned long* pSeed=reinterpret_cast<unsigned long*>(&seedAsBytes[0]);
00279     return *pSeed;
00280 }