Go to the documentation of this file.00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
00010 #include "clsfy_random_forest_builder.h"
00011 #include <vxl_config.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_string.h>
00014 #include <vcl_algorithm.h>
00015 #include <vcl_numeric.h>
00016 #include <vcl_iterator.h>
00017 #include <vcl_cassert.h>
00018 #include <vsl/vsl_binary_loader.h>
00019 #include <vnl/vnl_math.h>
00020 #include <mbl/mbl_stl.h>
00021 #include <mbl/mbl_data_array_wrapper.h>
00022 #include <clsfy/clsfy_binary_tree_builder.h>
00023 #include "clsfy_random_forest.h"
00024
00025
00026
00027 clsfy_random_forest_builder::clsfy_random_forest_builder():
00028 max_depth_(-1),min_node_size_(-1),
00029 ntrees_(100),
00030 calc_test_error_(true),
00031 poob_indices_(0)
00032 {
00033 unsigned long default_seed=123654987;
00034 seed_sampler(default_seed);
00035 }
00036
00037 clsfy_random_forest_builder::clsfy_random_forest_builder(unsigned ntrees,
00038 int max_depth,
00039 int min_node_size):
00040 max_depth_(max_depth),min_node_size_(min_node_size),
00041 ntrees_(ntrees),
00042 calc_test_error_(true),
00043 poob_indices_(0)
00044 {
00045 unsigned long default_seed=123654987;
00046 seed_sampler(default_seed);
00047 }
00048
00049 clsfy_random_forest_builder::~clsfy_random_forest_builder()
00050 {
00051 }
00052
00053
00054 short clsfy_random_forest_builder::version_no() const
00055 {
00056 return 1;
00057 }
00058
00059
00060
00061 vcl_string clsfy_random_forest_builder::is_a() const
00062 {
00063 return vcl_string("clsfy_random_forest_builder");
00064 }
00065
00066
00067
00068 bool clsfy_random_forest_builder::is_class(vcl_string const& s) const
00069 {
00070 return s == clsfy_random_forest_builder::is_a() || clsfy_builder_base::is_class(s);
00071 }
00072
00073
00074
00075 clsfy_builder_base* clsfy_random_forest_builder::clone() const
00076 {
00077 return new clsfy_random_forest_builder(*this);
00078 }
00079
00080
00081
00082 void clsfy_random_forest_builder::print_summary(vcl_ostream& os) const
00083 {
00084 os << "Num trees = "<<ntrees_<<"\tmax_depth = " << max_depth_;
00085 }
00086
00087
00088
00089 void clsfy_random_forest_builder::b_write(vsl_b_ostream& bfs) const
00090 {
00091 vsl_b_write(bfs, version_no());
00092 vsl_b_write(bfs, ntrees_);
00093 vsl_b_write(bfs, max_depth_);
00094 vsl_b_write(bfs, min_node_size_);
00095 vsl_b_write(bfs,calc_test_error_);
00096 vcl_cerr << "clsfy_random_forest_builder::b_write() NYI\n";
00097 }
00098
00099
00100
00101 void clsfy_random_forest_builder::b_read(vsl_b_istream& bfs)
00102 {
00103 if (!bfs) return;
00104
00105 short version;
00106 vsl_b_read(bfs,version);
00107 switch (version)
00108 {
00109 case (1):
00110 vsl_b_read(bfs, ntrees_);
00111 vsl_b_read(bfs, max_depth_);
00112 vsl_b_read(bfs, min_node_size_);
00113 vsl_b_read(bfs,calc_test_error_);
00114 break;
00115 default:
00116 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_random_forest_builder&)\n"
00117 << " Unknown version number "<< version << "\n";
00118 bfs.is().clear(vcl_ios::badbit);
00119 }
00120 }
00121
00122
00123
00124
00125
00126
00127
00128 double clsfy_random_forest_builder::build(clsfy_classifier_base& classifier,
00129 mbl_data_wrapper<vnl_vector<double> >& inputs,
00130 unsigned nClasses,
00131 const vcl_vector<unsigned> &outputs) const
00132 {
00133 assert(classifier.is_class("clsfy_random_forest"));
00134 assert(inputs.size()==outputs.size());
00135 assert(nClasses=1);
00136
00137
00138 clsfy_random_forest &random_forest = static_cast<clsfy_random_forest&>(classifier);
00139 unsigned npoints=inputs.size();
00140 vcl_vector<vnl_vector<double> > vin(npoints);
00141
00142 inputs.reset();
00143 unsigned i=0;
00144 do
00145 {
00146 vin[i++] = inputs.current();
00147 } while (inputs.next());
00148
00149 assert(i==inputs.size());
00150
00151 unsigned ndims=vin[0].size();
00152 int nbranch_params=select_nbranch_params(ndims);
00153
00154
00155 vcl_cout<<"npoints= "<<npoints<<"\tndims= "<<ndims<<vcl_endl;
00156 vcl_vector<unsigned> indices(ndims,0);
00157
00158 mbl_stl_increments(indices.begin(),indices.end(),0);
00159
00160
00161 random_forest.prune();
00162
00163 if(poob_indices_)
00164 {
00165 poob_indices_->clear();
00166 poob_indices_->reserve(ntrees_);
00167 }
00168
00169
00170 vcl_vector<vnl_vector<double> > bootstrapped_inputs;
00171 vcl_vector<unsigned > bootstrapped_outputs;
00172
00173 for(i=0;i<ntrees_;++i)
00174 {
00175 if(i %10 == 0)
00176 vcl_cout<<"Building tree "<<i<<vcl_endl;
00177
00178 select_data(vin,outputs,bootstrapped_inputs,bootstrapped_outputs);
00179
00180 clsfy_binary_tree_builder builder;
00181 builder.set_calc_test_error(false);
00182
00183 clsfy_classifier_base* pBaseClassifier=builder.new_classifier();
00184 clsfy_binary_tree* pTreeClassifier=dynamic_cast<clsfy_binary_tree*>(pBaseClassifier);
00185 assert(pTreeClassifier);
00186 builder.set_nbranch_params(nbranch_params);
00187
00188 unsigned long seed=get_tree_builder_seed();
00189
00190 builder.seed_sampler(seed);
00191
00192 builder.set_max_depth(max_depth_);
00193 builder.set_min_node_size(min_node_size_);
00194 mbl_data_array_wrapper<vnl_vector<double> > bootstrapped_inputs_mbl(bootstrapped_inputs);
00195
00196 builder.build(*pTreeClassifier,
00197 bootstrapped_inputs_mbl,
00198 1,
00199 bootstrapped_outputs);
00200
00201 mbl_cloneable_ptr<clsfy_classifier_base> treeClassifier(pTreeClassifier);
00202 random_forest.trees_.push_back(treeClassifier);
00203 }
00204
00205 if(calc_test_error_)
00206 return clsfy_test_error(classifier, inputs, outputs);
00207 else
00208 return 0.0;
00209
00210
00211 }
00212
00213
00214
00215 clsfy_classifier_base* clsfy_random_forest_builder::new_classifier() const
00216 {
00217 return new clsfy_random_forest();
00218 }
00219
00220
00221
00222 void clsfy_random_forest_builder::select_data(vcl_vector<vnl_vector<double> >& inputs,
00223 const vcl_vector<unsigned> &outputs,
00224 vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00225 vcl_vector<unsigned> & bootstrapped_outputs) const
00226 {
00227 unsigned npoints=inputs.size();
00228 bootstrapped_inputs.resize(npoints);
00229 bootstrapped_outputs.resize(npoints);
00230 unsigned ndims= inputs.front().size();
00231 double dn=double(npoints);
00232 if(poob_indices_)
00233 {
00234 poob_indices_->push_back(vcl_vector<unsigned>());
00235 poob_indices_->back().reserve(npoints);
00236 }
00237 for(unsigned i=0;i<npoints;++i)
00238 {
00239 bootstrapped_inputs[i].set_size(ndims);
00240 unsigned index=random_sampler_(npoints);
00241 bootstrapped_inputs[i]=inputs[index];
00242 bootstrapped_outputs[i]=outputs[index];
00243 if(poob_indices_)
00244 poob_indices_->back().push_back(index);
00245 }
00246 }
00247
00248 unsigned clsfy_random_forest_builder::select_nbranch_params(unsigned ndims) const
00249 {
00250 unsigned nbranch_params=1;
00251 if(ndims>2)
00252 {
00253 double dnbranch_params=vcl_sqrt(double(ndims)+0.1);
00254 nbranch_params=unsigned (dnbranch_params);
00255 }
00256 return nbranch_params;
00257 }
00258
00259 void clsfy_random_forest_builder::seed_sampler(unsigned long seed)
00260 {
00261
00262 random_sampler_.reseed(seed);
00263 }
00264
00265 unsigned long clsfy_random_forest_builder::get_tree_builder_seed() const
00266 {
00267
00268
00269 unsigned long N=256;
00270 unsigned nbytes=sizeof(unsigned long);
00271 vcl_vector<vxl_byte> seedAsBytes(nbytes,1);
00272
00273 for(unsigned ib=0;ib<nbytes;++ib)
00274 {
00275 seedAsBytes[ib]=static_cast<vxl_byte>(random_sampler_(N));
00276 }
00277
00278 unsigned long* pSeed=reinterpret_cast<unsigned long*>(&seedAsBytes[0]);
00279 return *pSeed;
00280 }