contrib/mul/clsfy/clsfy_binary_tree.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree.cxx
00002 #include "clsfy_binary_tree.h"
00003 //:
00004 // \file
00005 // \brief Binary tree classifier
00006 // \author Martin Roberts
00007 
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_stl.h>
00018 
00019 
00020 clsfy_binary_tree::clsfy_binary_tree(const clsfy_binary_tree& srcTree)
00021 {
00022     root_=cache_node_=0;
00023     copy(srcTree);
00024 }
00025 
00026 clsfy_binary_tree& clsfy_binary_tree::operator=(const clsfy_binary_tree& srcTree)
00027 {
00028     if (&srcTree != this)
00029     {
00030         copy(srcTree);
00031     }
00032     return *this;
00033 }
00034 
00035 void clsfy_binary_tree::copy(const clsfy_binary_tree& srcTree)
00036 {
00037     remove_tree(root_);
00038     //Then copy into the classifier
00039     if (srcTree.root_)
00040     {
00041         root_ = new clsfy_binary_tree_node(0,srcTree.root_->op_);
00042         root_->prob_ = srcTree.root_->prob_;
00043         copy_children(srcTree.root_,root_);
00044     }
00045     else
00046         root_=0;
00047     cache_node_ = root_;
00048 }
00049 
00050 void clsfy_binary_tree::copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode)
00051 {
00052     bool left=true;
00053     pNode->prob_ = pSrcNode->prob_;
00054     if (pSrcNode->left_child_)
00055     {
00056         pNode->add_child(pSrcNode->left_child_->op_,left);
00057         copy_children(pSrcNode->left_child_,
00058                       pNode->left_child_);
00059     }
00060     if (pSrcNode->right_child_)
00061     {
00062         pNode->add_child(pSrcNode->right_child_->op_,!left);
00063         copy_children(pSrcNode->right_child_,
00064                       pNode->right_child_);
00065     }
00066 }
00067 
00068 //=======================================================================
00069 //: Return the classification of the given probe vector.
00070 unsigned clsfy_binary_tree::classify(const vnl_vector<double> &input) const
00071 {
00072     unsigned outClass=0;
00073     //Traverse the tree
00074     clsfy_binary_tree_node* pNode=root_;
00075     if (!pNode)
00076     {
00077         vcl_cerr<<"WARNING - empty tree in clsfy_binary_tree::classify\n"
00078                 <<"Return default classification zero\n";
00079         return 0;
00080     }
00081     clsfy_binary_tree_node* pChild=0;
00082     do //Keep dropping down the tree till reach base level
00083     {
00084         pNode->op_.set_data(input);
00085         unsigned indicator=pNode->op_.classify();
00086         if (indicator==0)
00087         {
00088             pChild=pNode->left_child_;
00089         }
00090         else
00091         {
00092             pChild=pNode->right_child_;
00093         }
00094         if (pChild)
00095             pNode=pChild;
00096         else
00097         {
00098             cache_node_ = pNode; //Store final node (in case probability accessed)
00099             outClass=(pNode->prob_>0.5 ? 1 : 0);
00100         }
00101     }while (pChild);
00102 
00103     return outClass;
00104 }
00105 
00106 //=======================================================================
00107 //: Return a probability like value that the input being in each class.
00108 // output(i) i<<nClasses, contains the probability that the input is in class i
00109 void clsfy_binary_tree::class_probabilities(vcl_vector<double>& outputs,
00110                                             vnl_vector<double>const& input) const
00111 {
00112     outputs.resize(1);
00113     unsigned dummy=classify(input);
00114     outputs[0] = cache_node_->prob_;
00115 }
00116 
00117 
00118 //=======================================================================
00119 //: The dimensionality of input vectors.
00120 unsigned clsfy_binary_tree::n_dims() const
00121 {
00122     clsfy_binary_tree_node* pNode=root_;
00123     if (pNode)
00124         return pNode->op_.ndims();
00125     else
00126         return 0;
00127 }
00128 
00129 //=======================================================================
00130 //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00131 // class probability = exp(logL) / (1+exp(logL))
00132 double clsfy_binary_tree::log_l(const vnl_vector<double> &input) const
00133 {
00134     vcl_vector<double > probs;
00135     class_probabilities(probs,input);
00136     double p1=probs[0];
00137     double p0=1-p1;
00138     const double epsilon=1.0E-8;
00139     if (p0<epsilon) p0=epsilon;
00140     double L=vcl_log(p1/p0);
00141 
00142     return L;
00143 }
00144 
00145 
00146 //=======================================================================
00147 
00148 vcl_string clsfy_binary_tree::is_a() const
00149 {
00150     return vcl_string("clsfy_binary_tree");
00151 }
00152 
00153 //=======================================================================
00154 
00155 bool clsfy_binary_tree::is_class(vcl_string const& s) const
00156 {
00157     return s == clsfy_binary_tree::is_a() || clsfy_classifier_base::is_class(s);
00158 }
00159 
00160 //=======================================================================
00161 
00162 short clsfy_binary_tree::version_no() const
00163 {
00164     return 1;
00165 }
00166 
00167 //=======================================================================
00168 
00169 clsfy_classifier_base* clsfy_binary_tree::clone() const
00170 {
00171     return new clsfy_binary_tree(*this);
00172 }
00173 
00174 //=======================================================================
00175 
00176 void clsfy_binary_tree::print_summary(vcl_ostream& os) const
00177 {
00178 }
00179 
00180 //=======================================================================
00181 
00182 void clsfy_binary_tree::b_write(vsl_b_ostream& bfs) const
00183 {
00184     vsl_b_write(bfs,version_no());
00185     int nodeId=0; //used numeric ids for parent child relations
00186     // -1 means none
00187     vcl_deque<clsfy_binary_tree_node*> stack;
00188     vcl_deque<clsfy_binary_tree_node*> outlist;
00189     vcl_vector<graph_rep> arcs;
00190     clsfy_binary_tree_node* pNode=root_;
00191 
00192     stack.push_back(pNode);
00193     pNode->nodeId_=0;
00194     while (!stack.empty())
00195     {
00196         pNode=stack.front();
00197         stack.pop_front();
00198         outlist.push_back(pNode);
00199         graph_rep link;
00200         link.me=pNode->nodeId_;
00201         link.left_child = link.right_child = -1;
00202 
00203         if (pNode)
00204         {
00205             if (pNode->left_child_)
00206             {
00207                 stack.push_back(pNode->left_child_);
00208                 pNode->left_child_->nodeId_= ++nodeId;
00209                 link.left_child=nodeId;
00210             }
00211             if (pNode->right_child_)
00212             {
00213                 stack.push_back(pNode->right_child_);
00214                 pNode->right_child_->nodeId_= ++nodeId;
00215                 link.right_child=nodeId;
00216             }
00217 
00218             arcs.push_back(link);
00219         }
00220     }
00221 
00222     unsigned N=outlist.size();
00223     vsl_b_write(bfs,N);
00224 
00225     vcl_deque<clsfy_binary_tree_node*>::iterator outIter=outlist.begin();
00226     vcl_deque<clsfy_binary_tree_node*>::iterator outIterEnd=outlist.end();
00227     while (outIter != outIterEnd)
00228     {
00229         clsfy_binary_tree_node* pNode=*outIter;
00230         vsl_b_write(bfs,pNode->nodeId_);
00231         pNode->op_.b_write(bfs);
00232         vsl_b_write(bfs,pNode->prob_);
00233         ++outIter;
00234     }
00235 
00236     //Now write out the links graph
00237     N=arcs.size();
00238     vsl_b_write(bfs,N);
00239 
00240     vcl_vector<graph_rep>::iterator arcIter=arcs.begin();
00241     vcl_vector<graph_rep>::iterator arcIterEnd=arcs.end();
00242 
00243     while (arcIter != arcIterEnd)
00244     {
00245         vsl_b_write(bfs,arcIter->me);
00246         vsl_b_write(bfs,arcIter->left_child);
00247         vsl_b_write(bfs,arcIter->right_child);
00248         ++arcIter;
00249     }
00250 }
00251 
00252 //=======================================================================
00253 
00254 void clsfy_binary_tree::b_read(vsl_b_istream& bfs)
00255 {
00256     if (!bfs) return;
00257 
00258     remove_tree(root_);
00259     root_=0;
00260 
00261     short version;
00262     vsl_b_read(bfs,version);
00263     switch (version)
00264     {
00265         case (1):
00266         {
00267             vcl_map<int,clsfy_binary_tree_node*> workmap;
00268             vcl_vector<graph_rep> arcs;
00269             clsfy_binary_tree_node* pNode=0;
00270 
00271             clsfy_binary_tree_node* pNull=0;
00272             unsigned N;
00273             vsl_b_read(bfs,N);
00274             for (unsigned i=0;i<N;++i)
00275             {
00276                 int nodeId=-1;
00277                 vsl_b_read(bfs,nodeId);
00278                 clsfy_binary_tree_op op;
00279                 op.b_read(bfs);
00280                 clsfy_binary_tree_node* pNode=new clsfy_binary_tree_node(pNull,op);
00281                 pNode->nodeId_=nodeId;
00282                 vsl_b_read(bfs,pNode->prob_);
00283                 workmap[nodeId]=pNode;
00284             }
00285             vsl_b_read(bfs,N);
00286             arcs.reserve(N);
00287             for (unsigned i=0;i<N;++i)
00288             {
00289                 graph_rep link;
00290                 vsl_b_read(bfs,link.me);
00291                 vsl_b_read(bfs,link.left_child);
00292                 vsl_b_read(bfs,link.right_child);
00293                 arcs.push_back(link);
00294             }
00295             root_=workmap[0];
00296             for (unsigned i=0;i<N;++i)
00297             {
00298                 graph_rep link=arcs[i];
00299                 if (link.me!= -1)
00300                 {
00301                     clsfy_binary_tree_node* parent=workmap[link.me];
00302                     clsfy_binary_tree_node* left_child=0;
00303                     clsfy_binary_tree_node* right_child=0;
00304                     if (link.left_child != -1)
00305                         left_child=workmap[link.left_child];
00306                     if (link.right_child != -1)
00307                         right_child=workmap[link.right_child];
00308 
00309                     if (!parent || parent->nodeId_ != link.me)
00310                     {
00311                         vcl_cerr<<"ERROR - Inconsistent parent in tree set up in clsfy_binary_tree::b_read\n";
00312                         assert(0);
00313                     }
00314                     if ((link.left_child != -1) &&
00315                         (!left_child || left_child->nodeId_ != link.left_child))
00316                                         {
00317                         vcl_cerr<<"ERROR - Inconsistent left child in tree set up in clsfy_binary_tree::b_read\n";
00318                         assert(0);
00319                     }
00320                     if ((link.right_child != -1) &&
00321                         (!right_child || right_child->nodeId_ != link.right_child))
00322                                         {
00323                         vcl_cerr<<"ERROR - Inconsistent right child in tree set up in clsfy_binary_tree::b_read\n";
00324                         assert(0);
00325                     }
00326 
00327                     //And link these into the tree
00328                     parent->left_child_=left_child;
00329                     if (left_child)
00330                         left_child->parent_=parent;
00331 
00332                     parent->right_child_=right_child;
00333                     if (right_child)
00334                         right_child->parent_=parent;
00335                 }
00336             }
00337 
00338             //Validate the tree
00339             assert(root_);
00340             vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIter =workmap.begin();
00341             vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIterEnd =workmap.end();
00342             while (nodeIter != nodeIterEnd)
00343             {
00344                 clsfy_binary_tree_node* pNode=nodeIter->second;
00345                 assert(pNode->nodeId_==nodeIter->first);
00346                 if (pNode != root_)
00347                 {
00348                     assert(pNode->parent_);
00349                     assert(pNode->parent_->left_child_==pNode ||
00350                            pNode->parent_->right_child_ == pNode);
00351                 }
00352                 if (pNode->left_child_)
00353                     assert(pNode->left_child_->parent_==pNode);
00354                 if (pNode->right_child_)
00355                     assert(pNode->right_child_->parent_==pNode);
00356 
00357                 //Check all nodes connect back up to root
00358                 while (pNode->parent_)
00359                 {
00360                     pNode=pNode->parent_;
00361                 }
00362                 assert(pNode==root_);
00363 
00364                 ++nodeIter;
00365             }
00366         }
00367         break;
00368 
00369         default:
00370             vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00371                      << "           Unknown version number "<< version << '\n';
00372             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00373     }
00374 }
00375 
00376 clsfy_binary_tree::~clsfy_binary_tree()
00377 {
00378     remove_tree(root_);
00379     root_=0;
00380 }
00381 
00382 void  clsfy_binary_tree::remove_tree(clsfy_binary_tree_node* root)
00383 {
00384     vcl_deque<clsfy_binary_tree_node*> stack;
00385     vcl_deque<clsfy_binary_tree_node*> killset;
00386     stack.push_back(root);
00387     while (!stack.empty())
00388     {
00389         clsfy_binary_tree_node* pNode=stack.front();
00390         stack.pop_front();
00391 
00392         if (pNode)
00393         {
00394             killset.push_back(pNode);
00395             if (pNode->left_child_)
00396             {
00397                 stack.push_back(pNode->left_child_);
00398             }
00399             if (pNode->right_child_)
00400             {
00401                 stack.push_back(pNode->right_child_);
00402             }
00403         }
00404     }
00405 
00406     mbl_stl_clean(killset.begin(),killset.end());
00407 }
00408 
00409 void clsfy_binary_tree::set_root(  clsfy_binary_tree_node* root)
00410 {
00411     if ((root != root_) && root_)
00412         remove_tree(root_);
00413     root_=root;
00414 }
00415 
00416 
00417 //--------------- HELPER CLASSES---------------------------------------------------------
00418 
00419 clsfy_binary_tree_node* clsfy_binary_tree_node::create_child(const clsfy_binary_tree_op& op)
00420 {
00421     return new clsfy_binary_tree_node(this,op);
00422 }
00423 
00424 void clsfy_binary_tree_op::b_write(vsl_b_ostream& bfs) const
00425 {
00426     vsl_b_write(bfs,version_no());
00427     vsl_b_write(bfs,data_index_);
00428     vsl_b_write(bfs,classifier_);
00429 }
00430 
00431 //: Load the class from a Binary File Stream
00432 void clsfy_binary_tree_op::b_read(vsl_b_istream& bfs)
00433 {
00434     short version;
00435     vsl_b_read(bfs,version);
00436     if (version != 1)
00437     {
00438         vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00439                  << "           Unknown version number "<< version << '\n';
00440         bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00441     }
00442     else
00443     {
00444         vsl_b_read(bfs,data_index_);
00445         vsl_b_read(bfs,classifier_);
00446     }
00447 }
00448