00001 #ifndef mbl_rbf_network_h_ 00002 #define mbl_rbf_network_h_ 00003 //: 00004 // \file 00005 // \brief A class to perform some of the functions of a Radial Basis Function Network 00006 // \author tfc 00007 // wondrous VXL conversion started by gvw, errors corrected by ... 00008 00009 #include <vcl_string.h> 00010 #include <vcl_vector.h> 00011 #include <vcl_cmath.h> 00012 #include <vsl/vsl_binary_io.h> 00013 #include <vnl/io/vnl_io_vector.h> 00014 #include <vnl/io/vnl_io_matrix.h> 00015 #include <vnl/vnl_vector.h> 00016 00017 //: A class to perform some of the functions of a Radial Basis Function Network. 00018 // This is a special case of a mixture model pdf, where the same 00019 // (radially symmetric) pdf kernel is used at each node. 00020 // The nodes are supplied by build(). 00021 // calcWts(w,x) calculates the probabilities that x belongs to each 00022 // node. 00023 // Given a set of n training vectors, x_i (i=0..n-1), a set of internal weights are computed. 00024 // Given a new vector, x, a vector of weights, w, are computed such that 00025 // if x = x_i then w(i+1) = 1, w(j !=i+1) = 0 The sum of the weights 00026 // should always be unity. 00027 // If x is not equal to any training vector, the vector of weights varies 00028 // smoothly. This is useful for interpolation purposes. 00029 // It can also be used to define non-linear transformations between 00030 // vector spaces. If Y is a matrix of n columns, each corresponding to 00031 // a vector in a new space which corresponds to one of the original 00032 // training vectors x_i, then a vector x can be mapped to Yw in the 00033 // new space. (Note: y-space does not have to have the same dimension 00034 // as x space). This class is equivalent to 00035 // the basis of thin-plate spline warping. 00036 // 00037 // I'm not sure if this is exactly an RBF network in the original 00038 // definition. I'll check one day. 00039 class mbl_rbf_network 00040 { 00041 vcl_vector<vnl_vector<double> > x_; 00042 vnl_matrix<double> W_; 00043 double s2_; 00044 00045 bool sum_to_one_; 00046 00047 //: workspace 00048 vnl_vector<double> v_; 00049 00050 double distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const; 00051 double rbf(double r2) const 00052 { return r2<=0.0 ? 1.0 : vcl_exp(-r2); } 00053 00054 double rbf(const vnl_vector<double>& x, const vnl_vector<double>& y) 00055 { return rbf(distSqr(x,y)/s2_); } 00056 00057 public: 00058 00059 //: Dflt ctor 00060 mbl_rbf_network(); 00061 00062 //: Build weights given examples x. 00063 // s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s) 00064 // If s<=0 then a suitable s is estimated from the data 00065 void build(const vcl_vector<vnl_vector<double> >& x, double s = -1); 00066 00067 //: Build weights given n examples x[0] to x[n-1]. 00068 // s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s) 00069 // If s<=0 then a suitable s is estimated from the data 00070 void build(const vnl_vector<double>* x, int n, double s = -1); 00071 00072 //: If true, then the returned weights sum to 1.0 00073 bool sumToOne() const { return sum_to_one_; } 00074 00075 //: Set flag. If false, calcWts returns raw weights 00076 void setSumToOne(bool flag); 00077 00078 //: Array of training vectors x, supplied in last build() 00079 const vcl_vector<vnl_vector<double> >& x() const { return x_;} 00080 00081 //: Compute weights for given new_x. 00082 // If new_x = x()(i) then w(i+1)==1, w(j!=i+1)==0 00083 // Otherwise w varies smoothly depending on distance 00084 // of new_x from x()'s 00085 // If sumToOne() then elements of w will sum to 1.0 00086 // otherwise they will sum to <=1.0, decreasing as new_x 00087 // moves away from the training examples x(). 00088 void calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x); 00089 00090 //: Version number for I/O 00091 short version_no() const; 00092 00093 //: Name of the class 00094 vcl_string is_a() const; 00095 00096 //: True if this is (or is derived from) class named s 00097 bool is_class(vcl_string const& s) const; 00098 00099 //: Print class to os 00100 void print_summary(vcl_ostream& os) const; 00101 00102 //: Save class to binary file stream 00103 void b_write(vsl_b_ostream& bfs) const; 00104 00105 //: Load class from binary file stream 00106 void b_read(vsl_b_istream& bfs); 00107 }; 00108 00109 //: Binary file stream output operator for class reference 00110 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b); 00111 00112 //: Binary file stream input operator for class reference 00113 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b); 00114 00115 //: Stream output operator for class reference 00116 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b); 00117 00118 #endif //mbl_rbf_network_h_
1.7.5.1