contrib/mul/mbl/mbl_rbf_network.h
Go to the documentation of this file.
00001 #ifndef mbl_rbf_network_h_
00002 #define mbl_rbf_network_h_
00003 //:
00004 // \file
00005 // \brief A class to perform some of the functions of a Radial Basis Function Network
00006 // \author tfc
00007 //         wondrous VXL conversion started by gvw, errors corrected by ...
00008 
00009 #include <vcl_string.h>
00010 #include <vcl_vector.h>
00011 #include <vcl_cmath.h>
00012 #include <vsl/vsl_binary_io.h>
00013 #include <vnl/io/vnl_io_vector.h>
00014 #include <vnl/io/vnl_io_matrix.h>
00015 #include <vnl/vnl_vector.h>
00016 
00017 //: A class to perform some of the functions of a Radial Basis Function Network.
00018 //  This is a special case of a mixture model pdf, where the same
00019 //  (radially symmetric) pdf kernel is used at each node.
00020 //  The nodes are supplied by build().
00021 //  calcWts(w,x) calculates the probabilities that x belongs to each
00022 //  node.
00023 //  Given a set of n training vectors, x_i (i=0..n-1), a set of internal weights are computed.
00024 //  Given a new vector, x, a vector of weights, w, are computed such that
00025 //  if x = x_i then w(i+1) = 1, w(j !=i+1) = 0  The sum of the weights
00026 //  should always be unity.
00027 //  If x is not equal to any training vector, the vector of weights varies
00028 //  smoothly.  This is useful for interpolation purposes.
00029 //  It can also be used to define non-linear transformations between
00030 //  vector spaces.  If Y is a matrix of n columns, each corresponding to
00031 //  a vector in a new space which corresponds to one of the original
00032 //  training vectors x_i, then a vector x can be mapped to Yw in the
00033 //  new space.  (Note: y-space does not have to have the same dimension
00034 //  as x space). This class is equivalent to
00035 //  the basis of thin-plate spline warping.
00036 //
00037 //  I'm not sure if this is exactly an RBF network in the original
00038 //  definition. I'll check one day.
00039 class mbl_rbf_network
00040 {
00041   vcl_vector<vnl_vector<double> > x_;
00042   vnl_matrix<double> W_;
00043   double s2_;
00044 
00045   bool sum_to_one_;
00046 
00047   //: workspace
00048   vnl_vector<double> v_;
00049 
00050   double distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const;
00051   double rbf(double r2) const
00052     { return r2<=0.0 ? 1.0 : vcl_exp(-r2); }
00053 
00054   double rbf(const vnl_vector<double>& x, const vnl_vector<double>& y)
00055     { return rbf(distSqr(x,y)/s2_); }
00056 
00057  public:
00058 
00059   //: Dflt ctor
00060   mbl_rbf_network();
00061 
00062   //: Build weights given examples x.
00063   //  s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s)
00064   //  If s<=0 then a suitable s is estimated from the data
00065   void build(const vcl_vector<vnl_vector<double> >& x, double s = -1);
00066 
00067   //: Build weights given n examples x[0] to x[n-1].
00068   //  s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s)
00069   //  If s<=0 then a suitable s is estimated from the data
00070   void build(const vnl_vector<double>* x, int n, double s = -1);
00071 
00072   //: If true, then the returned weights sum to 1.0
00073   bool sumToOne() const { return sum_to_one_; }
00074 
00075   //: Set flag.  If false, calcWts returns raw weights
00076   void setSumToOne(bool flag);
00077 
00078   //: Array of training vectors x, supplied in last build()
00079   const vcl_vector<vnl_vector<double> >& x() const { return x_;}
00080 
00081   //: Compute weights for given new_x.
00082   //  If new_x = x()(i) then w(i+1)==1, w(j!=i+1)==0
00083   //  Otherwise w varies smoothly depending on distance
00084   //  of new_x from x()'s
00085   //  If sumToOne() then elements of w will sum to 1.0
00086   //  otherwise they will sum to <=1.0, decreasing as new_x
00087   //  moves away from the training examples x().
00088   void calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x);
00089 
00090   //: Version number for I/O
00091   short version_no() const;
00092 
00093   //: Name of the class
00094   vcl_string is_a() const;
00095 
00096   //: True if this is (or is derived from) class named s
00097   bool is_class(vcl_string const& s) const;
00098 
00099   //: Print class to os
00100   void print_summary(vcl_ostream& os) const;
00101 
00102   //: Save class to binary file stream
00103   void b_write(vsl_b_ostream& bfs) const;
00104 
00105   //: Load class from binary file stream
00106   void b_read(vsl_b_istream& bfs);
00107 };
00108 
00109 //: Binary file stream output operator for class reference
00110 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b);
00111 
00112 //: Binary file stream input operator for class reference
00113 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b);
00114 
00115 //: Stream output operator for class reference
00116 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b);
00117 
00118 #endif //mbl_rbf_network_h_