Go to the documentation of this file.00001
00002
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vcl_string.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_algorithm.h>
00015 #include <vcl_cassert.h>
00016 #include <vul/vul_string.h>
00017
00018 #include <mbl/mbl_data_wrapper.h>
00019 #include <mbl/mbl_parse_block.h>
00020 #include <mbl/mbl_read_props.h>
00021
00022 #include <clsfy/clsfy_smo_1.h>
00023
00024
00025
// Map a class label onto an SVM target value:
// label 1 becomes +1, every other label becomes -1.
inline int class_to_svm_target (unsigned v)
{
  if (v == 1)
    return 1;
  return -1;
}
00027
00028
00029
00030
00031 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00032 mbl_data_wrapper<vnl_vector<double> >& inputs,
00033 const vcl_vector<unsigned> &outputs) const
00034 {
00035 inputs.reset();
00036
00037 const unsigned int nSamples = inputs.size();
00038 assert(outputs.size() == nSamples);
00039 assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00040
00041 assert(classifier.is_class("clsfy_rbf_svm"));
00042 clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00043
00044 clsfy_smo_1_rbf svAPI;
00045 vcl_vector<int> targets(nSamples);
00046 vcl_transform(outputs.begin(), outputs.end(),
00047 targets.begin(), class_to_svm_target);
00048
00049 svAPI.set_data(inputs, targets);
00050
00051
00052
00053 svAPI.set_C(boundC_);
00054 svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00055
00056 svAPI.calc();
00057
00058
00059
00060 {
00061 vcl_vector<vnl_vector<double> > supportVectors;
00062 const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00063 vcl_vector<double> alphas;
00064 vcl_vector<unsigned> labels;
00065 for (unsigned i=0; i<nSamples; ++i)
00066 if (allAlphas[i]!=0.0)
00067 {
00068 alphas.push_back(allAlphas[i]);
00069 labels.push_back(outputs[i]);
00070 inputs.set_index(i);
00071 supportVectors.push_back(inputs.current());
00072 }
00073 svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00074 }
00075
00076 return svAPI.error_rate();
00077 }
00078
00079
00080
00081
00082
00083 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00084 mbl_data_wrapper<vnl_vector<double> >& inputs,
00085 unsigned nClasses,
00086 const vcl_vector<unsigned> &outputs) const
00087 {
00088 assert(nClasses == 1);
00089 return build(classifier, inputs, outputs);
00090 }
00091
00092
00093
00094 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00095 {
00096 return rbf_width_;
00097 }
00098
00099
00100
00101 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00102 {
00103 rbf_width_ = rbf_width;
00104 }
00105
00106
00107 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00108 {
00109 return vcl_string("clsfy_rbf_svm_smo_1_builder");
00110 }
00111
00112
00113
00114 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00115 {
00116 return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00117 }
00118
00119
00120
00121 short clsfy_rbf_svm_smo_1_builder::version_no() const
00122 {
00123 return 1;
00124 }
00125
00126
00127
00128 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00129 {
00130 return new clsfy_rbf_svm_smo_1_builder(*this);
00131 }
00132
00133
00134
00135 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00136 {
00137
00138 os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00139 }
00140
00141
00142
00143 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00144 {
00145 vsl_b_write(bfs,version_no());
00146 vsl_b_write(bfs,boundC_);
00147 vsl_b_write(bfs,rbf_width_);
00148 }
00149
00150
00151
00152 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00153 {
00154 if (!bfs) return;
00155
00156 short version;
00157 vsl_b_read(bfs,version);
00158 switch (version)
00159 {
00160 case (1):
00161 vsl_b_read(bfs,boundC_);
00162 vsl_b_read(bfs,rbf_width_);
00163 break;
00164 default:
00165 vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00166 << " Unknown version number "<< version << '\n';
00167 bfs.is().clear(vcl_ios::badbit);
00168 return;
00169 }
00170 }
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186 void clsfy_rbf_svm_smo_1_builder::config(vcl_istream &as)
00187 {
00188 vcl_string s = mbl_parse_block(as);
00189
00190 vcl_istringstream ss(s);
00191 mbl_read_props_type props = mbl_read_props_ws(ss);
00192
00193 {
00194 boundC_= vul_string_atof(props.get_optional_property("boundC", "0.0"));
00195 rbf_width_= vul_string_atof(props.get_optional_property("rbf_width", "0.0"));
00196 }
00197
00198
00199 mbl_read_props_look_for_unused_props(
00200 "clsfy_rbf_svm_smo_1_builder::config", props, mbl_read_props_type());
00201 }