Go to the documentation of this file.00001
00002 #include "clsfy_binary_threshold_1d_gini_builder.h"
00003
00004
00005
00006
00007 #include <vcl_iostream.h>
00008 #include <vcl_string.h>
00009 #include <vcl_cassert.h>
00010 #include <vsl/vsl_binary_loader.h>
00011 #include <vnl/vnl_double_2.h>
00012 #include <clsfy/clsfy_builder_1d.h>
00013 #include <clsfy/clsfy_binary_threshold_1d.h>
00014 #include <vcl_algorithm.h>
00015
00016
00017
00018
00019
00020
00021
00022
00023 clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder()
00024 {
00025 }
00026
00027
00028
00029 clsfy_binary_threshold_1d_gini_builder::~clsfy_binary_threshold_1d_gini_builder()
00030 {
00031 }
00032
00033
00034
00035 short clsfy_binary_threshold_1d_gini_builder::version_no() const
00036 {
00037 return 1;
00038 }
00039
00040
00041
00042
00043 clsfy_classifier_1d* clsfy_binary_threshold_1d_gini_builder::new_classifier() const
00044 {
00045 return new clsfy_binary_threshold_1d();
00046 }
00047
00048
00049
00050
00051
00052
00053
00054
00055 double clsfy_binary_threshold_1d_gini_builder::build_gini(clsfy_classifier_1d& classifier,
00056 const vnl_vector<double>& inputs,
00057 const vcl_vector<unsigned> &outputs) const
00058 {
00059 assert(classifier.is_class("clsfy_binary_threshold_1d"));
00060
00061 unsigned n = inputs.size();
00062 assert ( outputs.size() == n );
00063
00064
00065 vcl_vector<vbl_triple<double,int,int> > data;
00066 data.reserve(n);
00067
00068
00069 vcl_vector<unsigned >::const_iterator classIter=outputs.begin();
00070 vnl_vector<double >::const_iterator inputIter=inputs.begin();
00071 vnl_vector<double >::const_iterator inputIterEnd=inputs.end();
00072 vbl_triple<double,int,int> t;
00073 unsigned i=0;
00074 while (inputIter != inputIterEnd)
00075 {
00076 t.first = *inputIter++;
00077 t.second=*classIter++;
00078 t.third = i++;
00079 data.push_back(t);
00080 }
00081
00082 assert(i==inputs.size());
00083
00084 vcl_sort(data.begin(),data.end());
00085 return build_gini_from_sorted_data(static_cast<clsfy_classifier_1d&>(classifier), data);
00086 }
00087
00088
00089
00090
00091
00092
00093 double clsfy_binary_threshold_1d_gini_builder::build_gini_from_sorted_data(
00094 clsfy_classifier_1d& classifier,
00095 const vcl_vector<vbl_triple<double,int,int> >& data) const
00096 {
00097
00098
00099
00100
00101
00102
00103 const double epsilon=1.0E-20;
00104 if (vcl_fabs(data.front().first-data.back().first)<epsilon)
00105 {
00106 vcl_cerr<<"WARNING - clsfy_binary_threshold_1d_gini_builder::build_from_sorted_data - homogeneous data - cannot split\n";
00107 int polarity=1;
00108 double threshold=data[0].first;
00109 vnl_double_2 params(polarity, threshold*polarity);
00110 classifier.set_params(params.as_vector());
00111 return 1.0;
00112 }
00113
00114 unsigned int ntot=data.size();
00115 double dntot=double (ntot);
00116 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIter=data.begin();
00117 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterEnd=data.end();
00118 unsigned n0Tot=0;
00119 unsigned n1Tot=0;
00120 while (dataIter != dataIterEnd)
00121 {
00122 if (dataIter->second==0)
00123 ++n0Tot;
00124 else
00125 ++n1Tot;
00126 ++dataIter;
00127 }
00128
00129 double parentImp=0.0;
00130
00131
00132 double p=double (n0Tot)/dntot;
00133 parentImp=2.0*p*(1-p);
00134
00135 dataIter=data.begin();
00136 double s=dataIter->first-epsilon;
00137 double deltaImpBest= -1.0;
00138 double sbest=s;
00139 double ibest=0;
00140
00141 unsigned nL0=0;
00142 unsigned nL1=0;
00143 unsigned nR0=n0Tot;
00144 unsigned nR1=n1Tot;
00145 double parity=1.0;
00146 while (dataIter != dataIterEnd)
00147 {
00148 s=dataIter->first;
00149 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterNext=dataIter;
00150
00151
00152 while (dataIterNext != dataIterEnd && (dataIterNext->first-s)<epsilon)
00153 {
00154 if (dataIterNext->second==0)
00155 {
00156 ++nL0;
00157 --nR0;
00158 }
00159 else
00160 {
00161 ++nL1;
00162 --nR1;
00163 }
00164 ++dataIterNext;
00165 }
00166
00167 unsigned nLTot=nL0+nL1;
00168 unsigned nRTot=nR0+nR1;
00169 double probL=double(nL0)/double(nLTot);
00170 double probR=double(nR1)/double(nRTot);
00171
00172 double impL=2.0*probL*(1-probL);
00173 double impR=2.0*probR*(1-probR);
00174
00175
00176 double pL=double (nLTot)/dntot;
00177 double pR=1.0-pL;
00178
00179 double deltaImp=parentImp-(pL*impL + pR*impR);
00180 if (deltaImp>deltaImpBest)
00181 {
00182 deltaImpBest=deltaImp;
00183 sbest=s;
00184 if (nR1>=nL1)
00185 parity=1;
00186 else
00187 parity=-1;
00188 }
00189
00190 dataIter=dataIterNext;
00191 }
00192
00193 double threshold=sbest;
00194
00195
00196 vnl_double_2 params(parity, threshold*parity);
00197 classifier.set_params(params.as_vector());
00198 return -deltaImpBest;
00199 }
00200
00201
00202
00203 vcl_string clsfy_binary_threshold_1d_gini_builder::is_a() const
00204 {
00205 return vcl_string("clsfy_binary_threshold_1d_gini_builder");
00206 }
00207
00208 bool clsfy_binary_threshold_1d_gini_builder::is_class(vcl_string const& s) const
00209 {
00210 return s == clsfy_binary_threshold_1d_gini_builder::is_a() || clsfy_builder_1d::is_class(s);
00211 }
00212
#if 0
// Disabled code: explicit copy constructor and assignment operator.
// Everything between #if 0 and #endif is excluded from compilation.

// Copy constructor: zero data_ptr_, then delegate to operator=.
clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder(
  const clsfy_binary_threshold_1d_gini_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

// Assignment: copies the clsfy_binary_threshold_1d_builder base part only.
clsfy_binary_threshold_1d_gini_builder&
clsfy_binary_threshold_1d_gini_builder::operator=(const clsfy_binary_threshold_1d_gini_builder& new_b)
{
  // Self-assignment needs no work.
  if (this != &new_b)
  {
    clsfy_binary_threshold_1d_builder& self_base = *this;
    const clsfy_binary_threshold_1d_builder& other_base = new_b;
    self_base = other_base;
  }

  return *this;
}
#endif
00239
00240
00241
00242 void clsfy_binary_threshold_1d_gini_builder::print_summary(vcl_ostream& os) const
00243 {
00244 os<<"clsfy_binary_threshold_1d_gini_builder"<<vcl_endl;
00245 }
00246
00247
00248
00249
00250 clsfy_builder_1d* clsfy_binary_threshold_1d_gini_builder::clone() const
00251 {
00252 return new clsfy_binary_threshold_1d_gini_builder(*this);
00253 }
00254
00255
00256
00257 void clsfy_binary_threshold_1d_gini_builder::b_write(vsl_b_ostream& bfs) const
00258 {
00259 vsl_b_write(bfs, version_no());
00260 clsfy_binary_threshold_1d_builder::b_write(bfs);
00261 }
00262
00263
00264
00265
00266 void clsfy_binary_threshold_1d_gini_builder::b_read(vsl_b_istream& bfs)
00267 {
00268 if (!bfs) return;
00269
00270 short version;
00271 vsl_b_read(bfs,version);
00272 switch (version)
00273 {
00274 case (1):
00275 clsfy_binary_threshold_1d_builder::b_read(bfs);
00276 break;
00277 default:
00278 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_gini_builder&)\n"
00279 << " Unknown version number "<< version << '\n';
00280 bfs.is().clear(vcl_ios::badbit);
00281 return;
00282 }
00283 }