contrib/mul/clsfy/clsfy_binary_threshold_1d_gini_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_threshold_1d_gini_builder.cxx
00002 #include "clsfy_binary_threshold_1d_gini_builder.h"
00003 //:
00004 // \file
00005 // \author Martin Roberts
00006 
00007 #include <vcl_iostream.h>
00008 #include <vcl_string.h>
00009 #include <vcl_cassert.h>
00010 #include <vsl/vsl_binary_loader.h>
00011 #include <vnl/vnl_double_2.h>
00012 #include <clsfy/clsfy_builder_1d.h>
00013 #include <clsfy/clsfy_binary_threshold_1d.h>
00014 #include <vcl_algorithm.h>
00015 
// Note this is used by clsfy_binary_tree_builder
// Derived from clsfy_binary_threshold_1d_builder but uses a slightly different
// interface to do the gini index optimisation, as this returns the reduction
// in the gini impurity (not classification error).
00020 
00021 //=======================================================================
00022 
00023 clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder()
00024 {
00025 }
00026 
00027 //=======================================================================
00028 
00029 clsfy_binary_threshold_1d_gini_builder::~clsfy_binary_threshold_1d_gini_builder()
00030 {
00031 }
00032 
00033 //=======================================================================
00034 
00035 short clsfy_binary_threshold_1d_gini_builder::version_no() const
00036 {
00037     return 1;
00038 }
00039 
00040 
00041 //: Create empty classifier
00042 // Caller is responsible for deletion
00043 clsfy_classifier_1d* clsfy_binary_threshold_1d_gini_builder::new_classifier() const
00044 {
00045     return new clsfy_binary_threshold_1d();
00046 }
00047 
00048 //: Build a binary_threshold classifier
00049 //  Train classifier
00050 //  Selects parameters of classifier which best separate examples from two classes,
00051 // Uses the gini impurity index
00052 // Note it returns the -reduction in Gini impurity produced by the split
00053 // Not the misclassification rate
00054 // (i.e. but minimise as per error rate)
00055 double clsfy_binary_threshold_1d_gini_builder::build_gini(clsfy_classifier_1d& classifier,
00056                                                           const vnl_vector<double>&  inputs,
00057                                                           const vcl_vector<unsigned> &outputs) const
00058 {
00059     assert(classifier.is_class("clsfy_binary_threshold_1d"));
00060 
00061     unsigned n = inputs.size();
00062     assert ( outputs.size() == n );
00063 
00064     // create triples data, so can sort
00065     vcl_vector<vbl_triple<double,int,int> > data;
00066     data.reserve(n);
00067 
00068     //First just create sorted data
00069     vcl_vector<unsigned >::const_iterator classIter=outputs.begin();
00070     vnl_vector<double  >::const_iterator inputIter=inputs.begin();
00071     vnl_vector<double  >::const_iterator inputIterEnd=inputs.end();
00072     vbl_triple<double,int,int> t;
00073     unsigned i=0;
00074     while (inputIter != inputIterEnd)
00075     {
00076         t.first = *inputIter++;
00077         t.second=*classIter++;
00078         t.third = i++;
00079         data.push_back(t);
00080     }
00081 
00082     assert(i==inputs.size());
00083 
00084     vcl_sort(data.begin(),data.end());
00085     return build_gini_from_sorted_data(static_cast<clsfy_classifier_1d&>(classifier), data);
00086 }
00087 
00088 
00089 //: Train classifier, returning weighted error
00090 //   Assumes two classes
00091 //  Note that input "data" must be sorted to use this routine
00092 //Return -improvement in impurity (as normally these builders minimise)
00093 double clsfy_binary_threshold_1d_gini_builder::build_gini_from_sorted_data(
00094     clsfy_classifier_1d& classifier,
00095     const vcl_vector<vbl_triple<double,int,int> >& data) const
00096 {
00097     // here the triple consists of (value, class number, example index)
00098     // the example index specifies the weight of each example
00099     //
00100     // NB DATA must be sorted for this to work!!!!
00101 
00102     //Validate that the data is not homogeneous
00103     const double epsilon=1.0E-20;
00104     if (vcl_fabs(data.front().first-data.back().first)<epsilon)
00105     {
00106         vcl_cerr<<"WARNING - clsfy_binary_threshold_1d_gini_builder::build_from_sorted_data - homogeneous data - cannot split\n";
00107         int polarity=1;
00108         double threshold=data[0].first;
00109         vnl_double_2 params(polarity, threshold*polarity);
00110         classifier.set_params(params.as_vector());
00111         return 1.0;
00112     }
00113 
00114     unsigned int ntot=data.size();
00115     double dntot=double (ntot);
00116     vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIter=data.begin();
00117     vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterEnd=data.end();
00118     unsigned n0Tot=0;
00119     unsigned n1Tot=0;
00120     while (dataIter != dataIterEnd)
00121     {
00122         if (dataIter->second==0)
00123             ++n0Tot;
00124         else
00125             ++n1Tot;
00126         ++dataIter;
00127     }
00128 
00129     double parentImp=0.0;
00130     //Parent level impurity to start with
00131 
00132     double p=double (n0Tot)/dntot;
00133     parentImp=2.0*p*(1-p);
00134 
00135     dataIter=data.begin();
00136     double s=dataIter->first-epsilon;
00137     double deltaImpBest= -1.0; //initialise to split makes it worse
00138     double  sbest=s;
00139     double ibest=0;
00140     //Put none into left bin, all else go right
00141     unsigned nL0=0;
00142     unsigned nL1=0;
00143     unsigned nR0=n0Tot;
00144     unsigned nR1=n1Tot;
00145     double parity=1.0;
00146     while (dataIter != dataIterEnd)
00147     {
00148         s=dataIter->first;
00149         vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterNext=dataIter;
00150 
00151         //Increment till threshold increases (may have some same data values)
00152         while (dataIterNext != dataIterEnd && (dataIterNext->first-s)<epsilon)
00153         {
00154             if (dataIterNext->second==0)
00155             {
00156                 ++nL0;
00157                 --nR0;
00158             }
00159             else
00160             {
00161                 ++nL1;
00162                 --nR1;
00163             }
00164             ++dataIterNext;
00165         }
00166 
00167         unsigned nLTot=nL0+nL1;
00168         unsigned nRTot=nR0+nR1;
00169         double probL=double(nL0)/double(nLTot);
00170         double probR=double(nR1)/double(nRTot);
00171         //Two-class Gini index for left and right
00172         double impL=2.0*probL*(1-probL);
00173         double impR=2.0*probR*(1-probR);
00174 
00175         //Proportional weights
00176         double pL=double (nLTot)/dntot;
00177         double pR=1.0-pL;
00178 
00179         double deltaImp=parentImp-(pL*impL + pR*impR);
00180         if (deltaImp>deltaImpBest)
00181         {
00182             deltaImpBest=deltaImp;
00183             sbest=s;
00184             if (nR1>=nL1) //More class 1 are going above thresh
00185                 parity=1;
00186             else
00187                 parity=-1; //Reverse sign as more class one are going below thresh
00188         }
00189 
00190         dataIter=dataIterNext;
00191     }
00192 
00193     double threshold=sbest;
00194 
00195     // pass parameters to classifier
00196     vnl_double_2 params(parity, threshold*parity);
00197     classifier.set_params(params.as_vector());
00198     return -deltaImpBest;
00199 }
00200 
00201 //=======================================================================
00202 
00203 vcl_string clsfy_binary_threshold_1d_gini_builder::is_a() const
00204 {
00205   return vcl_string("clsfy_binary_threshold_1d_gini_builder");
00206 }
00207 
00208 bool clsfy_binary_threshold_1d_gini_builder::is_class(vcl_string const& s) const
00209 {
00210   return s == clsfy_binary_threshold_1d_gini_builder::is_a() || clsfy_builder_1d::is_class(s);
00211 }
00212 
#if 0
// Everything down to the matching #endif is compiled out.
//=======================================================================


// Copy constructor
// required if data stored on the heap is present in this derived class
clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder(
    const clsfy_binary_threshold_1d_gini_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

//=======================================================================

// Assignment operator
// required if data stored on the heap is present in this derived class
clsfy_binary_threshold_1d_gini_builder&
clsfy_binary_threshold_1d_gini_builder::operator=(const clsfy_binary_threshold_1d_gini_builder& new_b)
{
    // Self-assignment leaves the object untouched
    if (this != &new_b)
    {
        // Assign the clsfy_binary_threshold_1d_builder part
        clsfy_binary_threshold_1d_builder& this_base = *this;
        const clsfy_binary_threshold_1d_builder& other_base = new_b;
        this_base = other_base;
    }

    return *this;
}
#endif
00239 //=======================================================================
00240 
00241 // required if data is present in this base class
00242 void clsfy_binary_threshold_1d_gini_builder::print_summary(vcl_ostream& os) const
00243 {
00244     os<<"clsfy_binary_threshold_1d_gini_builder"<<vcl_endl;
00245 }
00246 
00247 //=======================================================================
00248 
00249 
00250 clsfy_builder_1d* clsfy_binary_threshold_1d_gini_builder::clone() const
00251 {
00252     return new clsfy_binary_threshold_1d_gini_builder(*this);
00253 }
00254 //=======================================================================
00255 
00256 // required if data is present in this base class
00257 void clsfy_binary_threshold_1d_gini_builder::b_write(vsl_b_ostream& bfs) const
00258 {
00259     vsl_b_write(bfs, version_no());
00260     clsfy_binary_threshold_1d_builder::b_write(bfs);
00261 }
00262 
00263 //=======================================================================
00264 
00265   // required if data is present in this base class
00266 void clsfy_binary_threshold_1d_gini_builder::b_read(vsl_b_istream& bfs)
00267 {
00268     if (!bfs) return;
00269 
00270     short version;
00271     vsl_b_read(bfs,version);
00272     switch (version)
00273     {
00274         case (1):
00275             clsfy_binary_threshold_1d_builder::b_read(bfs);
00276             break;
00277         default:
00278             vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_gini_builder&)\n"
00279                      << "           Unknown version number "<< version << '\n';
00280             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00281             return;
00282     }
00283 }