Go to the documentation of this file.00001
00002 #include "clsfy_binary_tree.h"
00003
00004
00005
00006
00007
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_stl.h>
00018
00019
00020 clsfy_binary_tree::clsfy_binary_tree(const clsfy_binary_tree& srcTree)
00021 {
00022 root_=cache_node_=0;
00023 copy(srcTree);
00024 }
00025
00026 clsfy_binary_tree& clsfy_binary_tree::operator=(const clsfy_binary_tree& srcTree)
00027 {
00028 if (&srcTree != this)
00029 {
00030 copy(srcTree);
00031 }
00032 return *this;
00033 }
00034
00035 void clsfy_binary_tree::copy(const clsfy_binary_tree& srcTree)
00036 {
00037 remove_tree(root_);
00038
00039 if (srcTree.root_)
00040 {
00041 root_ = new clsfy_binary_tree_node(0,srcTree.root_->op_);
00042 root_->prob_ = srcTree.root_->prob_;
00043 copy_children(srcTree.root_,root_);
00044 }
00045 else
00046 root_=0;
00047 cache_node_ = root_;
00048 }
00049
00050 void clsfy_binary_tree::copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode)
00051 {
00052 bool left=true;
00053 pNode->prob_ = pSrcNode->prob_;
00054 if (pSrcNode->left_child_)
00055 {
00056 pNode->add_child(pSrcNode->left_child_->op_,left);
00057 copy_children(pSrcNode->left_child_,
00058 pNode->left_child_);
00059 }
00060 if (pSrcNode->right_child_)
00061 {
00062 pNode->add_child(pSrcNode->right_child_->op_,!left);
00063 copy_children(pSrcNode->right_child_,
00064 pNode->right_child_);
00065 }
00066 }
00067
00068
00069
00070 unsigned clsfy_binary_tree::classify(const vnl_vector<double> &input) const
00071 {
00072 unsigned outClass=0;
00073
00074 clsfy_binary_tree_node* pNode=root_;
00075 if (!pNode)
00076 {
00077 vcl_cerr<<"WARNING - empty tree in clsfy_binary_tree::classify\n"
00078 <<"Return default classification zero\n";
00079 return 0;
00080 }
00081 clsfy_binary_tree_node* pChild=0;
00082 do
00083 {
00084 pNode->op_.set_data(input);
00085 unsigned indicator=pNode->op_.classify();
00086 if (indicator==0)
00087 {
00088 pChild=pNode->left_child_;
00089 }
00090 else
00091 {
00092 pChild=pNode->right_child_;
00093 }
00094 if (pChild)
00095 pNode=pChild;
00096 else
00097 {
00098 cache_node_ = pNode;
00099 outClass=(pNode->prob_>0.5 ? 1 : 0);
00100 }
00101 }while (pChild);
00102
00103 return outClass;
00104 }
00105
00106
00107
00108
00109 void clsfy_binary_tree::class_probabilities(vcl_vector<double>& outputs,
00110 vnl_vector<double>const& input) const
00111 {
00112 outputs.resize(1);
00113 unsigned dummy=classify(input);
00114 outputs[0] = cache_node_->prob_;
00115 }
00116
00117
00118
00119
00120 unsigned clsfy_binary_tree::n_dims() const
00121 {
00122 clsfy_binary_tree_node* pNode=root_;
00123 if (pNode)
00124 return pNode->op_.ndims();
00125 else
00126 return 0;
00127 }
00128
00129
00130
00131
00132 double clsfy_binary_tree::log_l(const vnl_vector<double> &input) const
00133 {
00134 vcl_vector<double > probs;
00135 class_probabilities(probs,input);
00136 double p1=probs[0];
00137 double p0=1-p1;
00138 const double epsilon=1.0E-8;
00139 if (p0<epsilon) p0=epsilon;
00140 double L=vcl_log(p1/p0);
00141
00142 return L;
00143 }
00144
00145
00146
00147
00148 vcl_string clsfy_binary_tree::is_a() const
00149 {
00150 return vcl_string("clsfy_binary_tree");
00151 }
00152
00153
00154
00155 bool clsfy_binary_tree::is_class(vcl_string const& s) const
00156 {
00157 return s == clsfy_binary_tree::is_a() || clsfy_classifier_base::is_class(s);
00158 }
00159
00160
00161
00162 short clsfy_binary_tree::version_no() const
00163 {
00164 return 1;
00165 }
00166
00167
00168
00169 clsfy_classifier_base* clsfy_binary_tree::clone() const
00170 {
00171 return new clsfy_binary_tree(*this);
00172 }
00173
00174
00175
00176 void clsfy_binary_tree::print_summary(vcl_ostream& os) const
00177 {
00178 }
00179
00180
00181
00182 void clsfy_binary_tree::b_write(vsl_b_ostream& bfs) const
00183 {
00184 vsl_b_write(bfs,version_no());
00185 int nodeId=0;
00186
00187 vcl_deque<clsfy_binary_tree_node*> stack;
00188 vcl_deque<clsfy_binary_tree_node*> outlist;
00189 vcl_vector<graph_rep> arcs;
00190 clsfy_binary_tree_node* pNode=root_;
00191
00192 stack.push_back(pNode);
00193 pNode->nodeId_=0;
00194 while (!stack.empty())
00195 {
00196 pNode=stack.front();
00197 stack.pop_front();
00198 outlist.push_back(pNode);
00199 graph_rep link;
00200 link.me=pNode->nodeId_;
00201 link.left_child = link.right_child = -1;
00202
00203 if (pNode)
00204 {
00205 if (pNode->left_child_)
00206 {
00207 stack.push_back(pNode->left_child_);
00208 pNode->left_child_->nodeId_= ++nodeId;
00209 link.left_child=nodeId;
00210 }
00211 if (pNode->right_child_)
00212 {
00213 stack.push_back(pNode->right_child_);
00214 pNode->right_child_->nodeId_= ++nodeId;
00215 link.right_child=nodeId;
00216 }
00217
00218 arcs.push_back(link);
00219 }
00220 }
00221
00222 unsigned N=outlist.size();
00223 vsl_b_write(bfs,N);
00224
00225 vcl_deque<clsfy_binary_tree_node*>::iterator outIter=outlist.begin();
00226 vcl_deque<clsfy_binary_tree_node*>::iterator outIterEnd=outlist.end();
00227 while (outIter != outIterEnd)
00228 {
00229 clsfy_binary_tree_node* pNode=*outIter;
00230 vsl_b_write(bfs,pNode->nodeId_);
00231 pNode->op_.b_write(bfs);
00232 vsl_b_write(bfs,pNode->prob_);
00233 ++outIter;
00234 }
00235
00236
00237 N=arcs.size();
00238 vsl_b_write(bfs,N);
00239
00240 vcl_vector<graph_rep>::iterator arcIter=arcs.begin();
00241 vcl_vector<graph_rep>::iterator arcIterEnd=arcs.end();
00242
00243 while (arcIter != arcIterEnd)
00244 {
00245 vsl_b_write(bfs,arcIter->me);
00246 vsl_b_write(bfs,arcIter->left_child);
00247 vsl_b_write(bfs,arcIter->right_child);
00248 ++arcIter;
00249 }
00250 }
00251
00252
00253
00254 void clsfy_binary_tree::b_read(vsl_b_istream& bfs)
00255 {
00256 if (!bfs) return;
00257
00258 remove_tree(root_);
00259 root_=0;
00260
00261 short version;
00262 vsl_b_read(bfs,version);
00263 switch (version)
00264 {
00265 case (1):
00266 {
00267 vcl_map<int,clsfy_binary_tree_node*> workmap;
00268 vcl_vector<graph_rep> arcs;
00269 clsfy_binary_tree_node* pNode=0;
00270
00271 clsfy_binary_tree_node* pNull=0;
00272 unsigned N;
00273 vsl_b_read(bfs,N);
00274 for (unsigned i=0;i<N;++i)
00275 {
00276 int nodeId=-1;
00277 vsl_b_read(bfs,nodeId);
00278 clsfy_binary_tree_op op;
00279 op.b_read(bfs);
00280 clsfy_binary_tree_node* pNode=new clsfy_binary_tree_node(pNull,op);
00281 pNode->nodeId_=nodeId;
00282 vsl_b_read(bfs,pNode->prob_);
00283 workmap[nodeId]=pNode;
00284 }
00285 vsl_b_read(bfs,N);
00286 arcs.reserve(N);
00287 for (unsigned i=0;i<N;++i)
00288 {
00289 graph_rep link;
00290 vsl_b_read(bfs,link.me);
00291 vsl_b_read(bfs,link.left_child);
00292 vsl_b_read(bfs,link.right_child);
00293 arcs.push_back(link);
00294 }
00295 root_=workmap[0];
00296 for (unsigned i=0;i<N;++i)
00297 {
00298 graph_rep link=arcs[i];
00299 if (link.me!= -1)
00300 {
00301 clsfy_binary_tree_node* parent=workmap[link.me];
00302 clsfy_binary_tree_node* left_child=0;
00303 clsfy_binary_tree_node* right_child=0;
00304 if (link.left_child != -1)
00305 left_child=workmap[link.left_child];
00306 if (link.right_child != -1)
00307 right_child=workmap[link.right_child];
00308
00309 if (!parent || parent->nodeId_ != link.me)
00310 {
00311 vcl_cerr<<"ERROR - Inconsistent parent in tree set up in clsfy_binary_tree::b_read\n";
00312 assert(0);
00313 }
00314 if ((link.left_child != -1) &&
00315 (!left_child || left_child->nodeId_ != link.left_child))
00316 {
00317 vcl_cerr<<"ERROR - Inconsistent left child in tree set up in clsfy_binary_tree::b_read\n";
00318 assert(0);
00319 }
00320 if ((link.right_child != -1) &&
00321 (!right_child || right_child->nodeId_ != link.right_child))
00322 {
00323 vcl_cerr<<"ERROR - Inconsistent right child in tree set up in clsfy_binary_tree::b_read\n";
00324 assert(0);
00325 }
00326
00327
00328 parent->left_child_=left_child;
00329 if (left_child)
00330 left_child->parent_=parent;
00331
00332 parent->right_child_=right_child;
00333 if (right_child)
00334 right_child->parent_=parent;
00335 }
00336 }
00337
00338
00339 assert(root_);
00340 vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIter =workmap.begin();
00341 vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIterEnd =workmap.end();
00342 while (nodeIter != nodeIterEnd)
00343 {
00344 clsfy_binary_tree_node* pNode=nodeIter->second;
00345 assert(pNode->nodeId_==nodeIter->first);
00346 if (pNode != root_)
00347 {
00348 assert(pNode->parent_);
00349 assert(pNode->parent_->left_child_==pNode ||
00350 pNode->parent_->right_child_ == pNode);
00351 }
00352 if (pNode->left_child_)
00353 assert(pNode->left_child_->parent_==pNode);
00354 if (pNode->right_child_)
00355 assert(pNode->right_child_->parent_==pNode);
00356
00357
00358 while (pNode->parent_)
00359 {
00360 pNode=pNode->parent_;
00361 }
00362 assert(pNode==root_);
00363
00364 ++nodeIter;
00365 }
00366 }
00367 break;
00368
00369 default:
00370 vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00371 << " Unknown version number "<< version << '\n';
00372 bfs.is().clear(vcl_ios::badbit);
00373 }
00374 }
00375
00376 clsfy_binary_tree::~clsfy_binary_tree()
00377 {
00378 remove_tree(root_);
00379 root_=0;
00380 }
00381
00382 void clsfy_binary_tree::remove_tree(clsfy_binary_tree_node* root)
00383 {
00384 vcl_deque<clsfy_binary_tree_node*> stack;
00385 vcl_deque<clsfy_binary_tree_node*> killset;
00386 stack.push_back(root);
00387 while (!stack.empty())
00388 {
00389 clsfy_binary_tree_node* pNode=stack.front();
00390 stack.pop_front();
00391
00392 if (pNode)
00393 {
00394 killset.push_back(pNode);
00395 if (pNode->left_child_)
00396 {
00397 stack.push_back(pNode->left_child_);
00398 }
00399 if (pNode->right_child_)
00400 {
00401 stack.push_back(pNode->right_child_);
00402 }
00403 }
00404 }
00405
00406 mbl_stl_clean(killset.begin(),killset.end());
00407 }
00408
00409 void clsfy_binary_tree::set_root( clsfy_binary_tree_node* root)
00410 {
00411 if ((root != root_) && root_)
00412 remove_tree(root_);
00413 root_=root;
00414 }
00415
00416
00417
00418
00419 clsfy_binary_tree_node* clsfy_binary_tree_node::create_child(const clsfy_binary_tree_op& op)
00420 {
00421 return new clsfy_binary_tree_node(this,op);
00422 }
00423
00424 void clsfy_binary_tree_op::b_write(vsl_b_ostream& bfs) const
00425 {
00426 vsl_b_write(bfs,version_no());
00427 vsl_b_write(bfs,data_index_);
00428 vsl_b_write(bfs,classifier_);
00429 }
00430
00431
00432 void clsfy_binary_tree_op::b_read(vsl_b_istream& bfs)
00433 {
00434 short version;
00435 vsl_b_read(bfs,version);
00436 if (version != 1)
00437 {
00438 vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00439 << " Unknown version number "<< version << '\n';
00440 bfs.is().clear(vcl_ios::badbit);
00441 }
00442 else
00443 {
00444 vsl_b_read(bfs,data_index_);
00445 vsl_b_read(bfs,classifier_);
00446 }
00447 }
00448