contrib/mul/clsfy/clsfy_binary_tree.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree.h
00002 #ifndef clsfy_binary_tree_h_
00003 #define clsfy_binary_tree_h_
00004 //:
00005 // \file
00006 // \brief Binary tree classifier
00007 // \author Martin Roberts
00008 #include <clsfy/clsfy_classifier_base.h>
00009 #include <clsfy/clsfy_binary_threshold_1d.h>
00010 #include <vcl_iosfwd.h>
00011 
00012 
00013 //: One node of a binary tree classifier - wrapper round clsfy_binary_threshold_1d
00014 //  Needs also to store the data feature index associated with the node
00015 //  Then it calls its binary classifier for that node
00016 //  Returns class zero if s_*x[i]<threshold_
00017 
00018 class clsfy_binary_tree_op
00019 {
00020  protected:
00021     //Index within data of variable used at this node (set to -1 if none assigned)
00022     int data_index_;
00023     const vnl_vector<double>* data_ptr_;
00024     clsfy_binary_threshold_1d classifier_;
00025     
00026       
00027  public:
00028 
00029   clsfy_binary_tree_op():data_index_(-1),data_ptr_(0) {}
00030     clsfy_binary_tree_op(const vnl_vector<double>* data_ptr,
00031                          int data_index=-1):
00032     data_ptr_(data_ptr),data_index_(data_index) {}
00033 
00034   clsfy_binary_threshold_1d& classifier() {return classifier_;}
00035   unsigned data_index() const {return data_index_;}
00036   void set_data_index(unsigned index) {data_index_=index;}
00037   void set_data_ptr(const vnl_vector<double>* data_ptr)
00038   {data_ptr_= data_ptr;}
00039 
00040   //: Return reference to data - NB throws std::bad_cast if null
00041   const vnl_vector<double >& data() const {return *data_ptr_;}
00042 
00043   void set_data(const vnl_vector<double >& inputs) {data_ptr_=&inputs;}
00044   //: Return value
00045   double val() const {return (*data_ptr_)[data_index_];}
00046 
00047   //: Classify
00048   unsigned classify() {return classifier_.classify(val());}
00049 
00050   unsigned ndims() {return (data_ptr_ ? data_ptr_->size() : 0 );} 
00051       
00052   //: Save class to a binary File Stream
00053   void b_write(vsl_b_ostream& bfs) const;
00054 
00055   //: Load the class from a Binary File Stream
00056   void b_read(vsl_b_istream& bfs);
00057 
00058   short version_no() const {return 1;}
00059 
00060 };
00061 
00062 
00063 class clsfy_binary_tree_node {
00064     int nodeId_;
00065     clsfy_binary_tree_node* parent_;
00066     clsfy_binary_tree_node* left_child_;
00067     clsfy_binary_tree_node* right_child_;
00068     clsfy_binary_tree_op op_;
00069     double prob_; //Only used on terminal nodes
00070   public:
00071     
00072   clsfy_binary_tree_node(clsfy_binary_tree_node* parent,
00073                          const clsfy_binary_tree_op& op):
00074     nodeId_(-1),parent_(parent),left_child_(0),right_child_(0),op_(op),prob_(0.5)
00075     {
00076     }
00077     
00078 
00079     virtual clsfy_binary_tree_node* create_child(const clsfy_binary_tree_op& op);
00080     void add_child(const clsfy_binary_tree_op& op,bool bLeft)
00081     {
00082         clsfy_binary_tree_node* child=create_child(op);
00083         if(bLeft)
00084             left_child_=child;
00085         else
00086             right_child_=child;                
00087     }
00088 
00089     //Note the owning classifier removes the tree - beware as once deleted its children
00090     //may be inaccessible for deletion
00091     virtual ~clsfy_binary_tree_node() {}
00092     
00093     friend class clsfy_binary_tree;
00094     friend class clsfy_binary_tree_builder;
00095 };
00096 
00097 
00098 //: A binary tree classifer
00099 // Drop down the tree using a binary threshold on a specific variable from the set at each node.
00100 // Branch left for one classification, right for the other
00101 // Eventually a node is reached with no children and that node's
00102 // binary threshold classification is returned
00103 
00104 class clsfy_binary_tree : public clsfy_classifier_base
00105 {
00106 
00107   public:
00108 
00109     struct graph_rep
00110     {
00111         int me;
00112         int left_child;
00113         int right_child;
00114     };
00115 
00116     
00117   //: Constructor
00118   clsfy_binary_tree(): root_(0),cache_node_(0) {}
00119 
00120     virtual ~clsfy_binary_tree();
00121     
00122     clsfy_binary_tree(const clsfy_binary_tree& srcTree);
00123     
00124     clsfy_binary_tree& operator=(const clsfy_binary_tree& srcTree);
00125     
00126     static void remove_tree(clsfy_binary_tree_node* root);
00127   //: Return the classification of the given probe vector.
00128   virtual unsigned classify(const vnl_vector<double> &input) const;
00129 
00130   //: Provides a probability-like value that the input being in each class.
00131   // output(i) i<nClasses, contains the probability that the input is in class i
00132   virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const;
00133 
00134 
00135   //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00136   // class probability = exp(logL) / (1+exp(logL))
00137   virtual double log_l(const vnl_vector<double> &input) const;
00138 
00139   //: The number of possible output classes.
00140   virtual unsigned n_classes() const {return 1;}
00141 
00142   //: The dimensionality of input vectors.
00143   virtual unsigned n_dims() const;
00144 
00145   //: Storage version number
00146   virtual short version_no() const;
00147 
00148   //: Name of the class
00149   virtual vcl_string is_a() const;
00150 
00151   //: Name of the class
00152   virtual bool is_class(vcl_string const& s) const;
00153 
00154   //: Create a copy on the heap and return base class pointer
00155   virtual clsfy_classifier_base* clone() const;
00156 
00157   //: Print class to os
00158   virtual void print_summary(vcl_ostream& os) const;
00159 
00160   //: Save class to binary file stream
00161   virtual void b_write(vsl_b_ostream& bfs) const;
00162 
00163   //: Load class from binary file stream
00164   virtual void b_read(vsl_b_istream& bfs);
00165 
00166   //: Normally only the builder uses this
00167   void set_root(  clsfy_binary_tree_node* root);
00168   private:
00169   clsfy_binary_tree_node* root_;
00170   mutable clsfy_binary_tree_node* cache_node_;
00171   private:
00172     void copy(const clsfy_binary_tree& srcTree);
00173     void copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode);
00174 
00175 
00176 };
00177 
00178 #endif // clsfy_binary_tree_h_