contrib/mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx
00002 // Copyright: (C) 2001 British Telecommunications plc.
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004 //:
00005 // \file
00006 // \brief Implement an interface to SMO algorithm SVM builder and additional logic
00007 // \author Ian Scott
00008 // \date Dec 2001
00009 
00010 //=======================================================================
00011 
00012 #include <vcl_string.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_algorithm.h>
00015 #include <vcl_cassert.h>
00016 #include <vul/vul_string.h>
00017 
00018 #include <mbl/mbl_data_wrapper.h>
00019 #include <mbl/mbl_parse_block.h>
00020 #include <mbl/mbl_read_props.h>
00021 
00022 #include <clsfy/clsfy_smo_1.h>
00023 
00024 //=======================================================================
00025 
//: Map a clsfy class label onto the +1/-1 target convention used by the SMO solver.
// Label 1 becomes +1; every other label becomes -1.
inline int class_to_svm_target (unsigned v)
{
  if (v == 1)
    return 1;
  return -1;
}
00027 
00028 //=======================================================================
00029 //: Build classifier from data
00030 // returns the training error, or +INF if there is an error.
00031 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00032                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00033                                           const vcl_vector<unsigned> &outputs) const
00034 {
00035   inputs.reset();
00036 //const unsigned int nDims = inputs.current().size(); // unused variable
00037   const unsigned int nSamples = inputs.size();
00038   assert(outputs.size() == nSamples);
00039   assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00040 
00041   assert(classifier.is_class("clsfy_rbf_svm"));
00042   clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00043 
00044   clsfy_smo_1_rbf svAPI;
00045   vcl_vector<int> targets(nSamples);
00046   vcl_transform(outputs.begin(), outputs.end(),
00047                 targets.begin(), class_to_svm_target);
00048 
00049   svAPI.set_data(inputs, targets);
00050 
00051 
00052   // Set the SVM solver parameters
00053   svAPI.set_C(boundC_);
00054   svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00055   // Solve the SVM
00056   svAPI.calc();
00057 
00058 
00059   // Get the SVM description, and build an SVM machine
00060   {
00061     vcl_vector<vnl_vector<double> > supportVectors;
00062     const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00063     vcl_vector<double> alphas;
00064     vcl_vector<unsigned> labels;
00065     for (unsigned i=0; i<nSamples; ++i)
00066       if (allAlphas[i]!=0.0)
00067       {
00068         alphas.push_back(allAlphas[i]);
00069         labels.push_back(outputs[i]);
00070         inputs.set_index(i);
00071         supportVectors.push_back(inputs.current());
00072       }
00073     svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00074   }
00075 
00076   return svAPI.error_rate();
00077 }
00078 
00079 //=======================================================================
00080 //: Build classifier from data.
00081 // returns the training error, or +INF if there is an error.
00082 // nClasses must be 1.
00083 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00084                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00085                                           unsigned nClasses,
00086                                           const vcl_vector<unsigned> &outputs) const
00087 {
00088   assert(nClasses == 1);
00089   return build(classifier, inputs, outputs);
00090 }
00091 
00092 //=======================================================================
00093 
00094 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00095 {
00096   return rbf_width_;
00097 }
00098 
00099 //=======================================================================
00100 
00101 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00102 {
00103   rbf_width_ = rbf_width;
00104 }
00105 //=======================================================================
00106 
00107 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00108 {
00109   return vcl_string("clsfy_rbf_svm_smo_1_builder");
00110 }
00111 
00112 //=======================================================================
00113 
00114 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00115 {
00116   return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00117 }
00118 
00119 //=======================================================================
00120 
00121 short clsfy_rbf_svm_smo_1_builder::version_no() const
00122 {
00123   return 1;
00124 }
00125 
00126 //=======================================================================
00127 
00128 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00129 {
00130   return new clsfy_rbf_svm_smo_1_builder(*this);
00131 }
00132 
00133 //=======================================================================
00134 
00135 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00136 {
00137   // os << data_; // example of data output
00138   os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00139 }
00140 
00141 //=======================================================================
00142 
00143 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00144 {
00145   vsl_b_write(bfs,version_no());
00146   vsl_b_write(bfs,boundC_);
00147   vsl_b_write(bfs,rbf_width_);
00148 }
00149 
00150 //=======================================================================
00151 
00152 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00153 {
00154   if (!bfs) return;
00155 
00156   short version;
00157   vsl_b_read(bfs,version);
00158   switch (version)
00159   {
00160   case (1):
00161     vsl_b_read(bfs,boundC_);
00162     vsl_b_read(bfs,rbf_width_);
00163     break;
00164   default:
00165     vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00166              << "           Unknown version number "<< version << '\n';
00167     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00168     return;
00169   }
00170 }
00171 
00172 
00173 //=======================================================================
00174 //: Initialise the parameters from a text stream.
00175 // The next non-ws character in the stream should be a '{'
00176 // \verbatim
00177 // {
00178 //   boundC: 3  (default 0 meaning no bound) Upper bound on the Lagrange multiplies.
00179 //              Smaller non-zero values result in a osftening of the boundary.
00180 //
00181 //   rbf_width: 3.0  (required) - A good guess is the mean euclidean distance
00182 //                    to every examples nearest neighbour.
00183 // }
00184 // \endverbatim
00185 // \throw mbl_exception_parse_error if the parse fails.
00186 void clsfy_rbf_svm_smo_1_builder::config(vcl_istream &as)
00187 {
00188  vcl_string s = mbl_parse_block(as);
00189 
00190   vcl_istringstream ss(s);
00191   mbl_read_props_type props = mbl_read_props_ws(ss);
00192 
00193   {
00194     boundC_= vul_string_atof(props.get_optional_property("boundC", "0.0"));
00195     rbf_width_= vul_string_atof(props.get_optional_property("rbf_width", "0.0"));
00196   }
00197 
00198   // Check for unused props
00199   mbl_read_props_look_for_unused_props(
00200     "clsfy_rbf_svm_smo_1_builder::config", props, mbl_read_props_type());
00201 }