00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
00010
00011 #include "mbl_lda.h"
00012
00013 #include <vcl_algorithm.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cstddef.h>
00016 #include <vcl_cstring.h>
00017 #include <vsl/vsl_indent.h>
00018 #include <vsl/vsl_vector_io.h>
00019 #include <vsl/vsl_binary_io.h>
00020 #include <vnl/algo/vnl_svd.h>
00021 #include <vnl/algo/vnl_symmetric_eigensystem.h>
00022 #include <vnl/algo/vnl_generalized_eigensystem.h>
00023 #include <vnl/io/vnl_io_vector.h>
00024 #include <mbl/mbl_matxvec.h>
00025 #include <mbl/mbl_log.h>
00026 #include <mbl/mbl_exception.h>
00027
00028
00029
00030
00031
00032 static mbl_logger& logger()
00033 {
00034 static mbl_logger l("mul.mbl.lda");
00035 return l;
00036 }
00037
00038
00039
00040 mbl_lda::mbl_lda()
00041 {
00042 }
00043
00044
00045
00046 mbl_lda::~mbl_lda()
00047 {
00048 }
00049
00050
00051
00052
00053
00054 int mbl_lda::classify( const vnl_vector<double>& x )
00055 {
00056 vnl_vector<double> d;
00057 x_to_d(d, x);
00058 int nc=n_classes();
00059 double min_d=(d-d_class_mean(0)).squared_magnitude();
00060 int min_i=0;
00061 for (int i=1; i<nc; ++i)
00062 {
00063 double dist=(d-d_class_mean(i)).squared_magnitude();
00064 if (dist<min_d ) { min_d= dist; min_i=i; }
00065 }
00066 return min_i;
00067 }
00068
00069
00070
00071
00072 bool mbl_lda::operator==(const mbl_lda& that) const
00073 {
00074 return mean_ == that.mean_ &&
00075 d_mean_ == that.d_mean_ &&
00076 mean_class_mean_ == that.mean_class_mean_ &&
00077 n_samples_ == that.n_samples_ &&
00078 withinS_ == that.withinS_ &&
00079 betweenS_ == that.betweenS_ &&
00080 basis_ == that.basis_ &&
00081 evals_ == that.evals_ &&
00082 d_m_mean_ == that.d_m_mean_;
00083 }
00084
00085
00086
00087 void mbl_lda::updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& V)
00088 {
00089 unsigned int n = V.size();
00090 if (S.rows()!=n)
00091 {
00092 S.set_size(n,n);
00093 S.fill(0);
00094 }
00095
00096 double** s = S.data_array();
00097 const double* v = V.data_block();
00098 for (unsigned int i=0;i<n;++i)
00099 {
00100 double *row = s[i];
00101 double vi = v[i];
00102 for (unsigned int j=0;j<n;++j)
00103 row[j] += vi*v[j];
00104 }
00105 }
00106
00107
00108
00109
00110 int mbl_lda::nDistinctIDs(const int* id, const int n)
00111 {
00112 vcl_vector<int> dids;
00113 for (int i=0;i<n;++i)
00114 {
00115 if (vcl_find(dids.begin(), dids.end(), id[i])==dids.end())
00116 dids.push_back(id[i]);
00117 }
00118
00119 return dids.size();
00120 }
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130 void mbl_lda::build(const vnl_vector<double>* v, const int * label, int n,
00131 const vnl_matrix<double>& wS, bool compute_wS)
00132 {
00133
00134 int lo_i=label[0];
00135 int hi_i=-1;
00136 int n_valid = 0;
00137 for (int i=0;i<n;++i)
00138 {
00139 if (label[i]>=0)
00140 {
00141 if (label[i]<lo_i) lo_i=label[i];
00142 if (label[i]>hi_i) hi_i=label[i];
00143 n_valid++;
00144 }
00145 }
00146
00147
00148
00149
00150 int n_classes = nDistinctIDs(label,n);
00151 MBL_LOG(INFO, logger(), "There are " <<n_classes << " classes to build LDA space");
00152 MBL_LOG(INFO, logger(), "Max label index is " << hi_i);
00153 MBL_LOG(INFO, logger(), "Min label index is " << lo_i);
00154
00155 int n_size=hi_i+1;
00156 mean_.resize(n_size);
00157 n_samples_.resize(n_size);
00158 for (int i=0;i<n_size;++i)
00159 n_samples_[i]=0;
00160
00161 for (int i=0;i<n;++i)
00162 {
00163 int l = label[i];
00164 if (l<0) continue;
00165 if (mean_[l].size()==0)
00166 {
00167 mean_[l] = v[i];
00168 n_samples_[l] = 1;
00169 }
00170 else
00171 {
00172 mean_[l] += v[i];
00173 n_samples_[l] += 1;
00174 }
00175 }
00176
00177 int n_used_classes = 0;
00178 for (int i=0;i<n_size;++i)
00179 {
00180 if (n_samples_[i]>0)
00181 {
00182 mean_[i]/=n_samples_[i];
00183 if (i==lo_i) mean_class_mean_ = mean_[i];
00184 else mean_class_mean_ += mean_[i];
00185 n_used_classes++;
00186 }
00187 }
00188 MBL_LOG(INFO, logger(), "Number of used classes: " << n_used_classes);
00189
00190 mean_class_mean_/=n_used_classes;
00191
00192
00193
00194 betweenS_.set_size(0,0);
00195
00196 for (int i=0;i<n_size;++i)
00197 {
00198 if (n_samples_[i]>0)
00199 updateCovar(betweenS_,mean_[i] - mean_class_mean_);
00200 }
00201
00202 betweenS_/=n_used_classes;
00203
00204 if (compute_wS)
00205 {
00206 withinS_.set_size(0,0);
00207
00208 int n_used=0;
00209 for (int i=0;i<n;++i)
00210 {
00211 int l=label[i];
00212 if (l>=0 && n_samples_[l]>1)
00213 {
00214 updateCovar(withinS_,v[i]-mean_[l]);
00215 n_used++;
00216 }
00217 }
00218 withinS_/=n_used;
00219 }
00220 else
00221 withinS_ = wS;
00222
00223 #if 0
00224 vnl_matrix<double> wS_inv;
00225
00226 vnl_svd<double> wS_svd(withinS_, -1.0e-10);
00227
00228 wS_inv = wS_svd.inverse();
00229
00230 vnl_matrix<double> B=withinS_*wS_inv;
00231 vcl_cout<<B<<vcl_endl;
00232
00233 vnl_matrix<double> A = wS_inv* betweenS_;
00234
00235
00236 vnl_matrix<double> EVecs(A.rows(), A.columns());
00237 vnl_vector<double> evals(A.columns());
00238
00239
00240
00241 vnl_symmetric_eigensystem_compute(A, EVecs, evals);
00242 #endif // 0
00243
00244 vnl_generalized_eigensystem gen_eigs(betweenS_,withinS_);
00245 vnl_matrix<double> EVecs= gen_eigs.V;
00246 vnl_vector<double> evals= gen_eigs.D.diagonal();
00247
00248
00249 if (logger().level()>=mbl_logger::DEBUG)
00250 {
00251 MBL_LOG(DEBUG, logger(), "eigen decomp in original order:");
00252 unsigned nvec = EVecs.cols();
00253 for (unsigned i=0; i<nvec; ++i)
00254 MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00255 << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00256 for (unsigned i=0; i<nvec; ++i)
00257 MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00258 }
00259
00260
00261
00262
00263
00264
00265
00266 for (unsigned i=0; i<evals.size(); ++i)
00267 {
00268 if (evals[i]<-1e-12)
00269 throw mbl_exception_abort("mbl_lda::build(): found negative eigenvalue(s)");
00270 }
00271 evals.flip();
00272 EVecs.fliplr();
00273
00274
00275 if (logger().level()>=mbl_logger::DEBUG)
00276 {
00277 MBL_LOG(DEBUG, logger(), "eigen decomp in sorted order:");
00278 unsigned nvec = EVecs.cols();
00279 for (unsigned i=0; i<nvec; ++i)
00280 MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00281 << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00282 for (unsigned i=0; i<nvec; ++i)
00283 MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00284 }
00285
00286
00287 int m = EVecs.rows();
00288 int t = n_used_classes-1;
00289 if (t>m) t=m;
00290
00291
00292 basis_.set_size(m,t);
00293 double **E = EVecs.data_array();
00294 double **b = basis_.data_array();
00295 vcl_size_t bytes_per_row = t * sizeof(double);
00296 for (int i=0;i<m;++i)
00297 {
00298 vcl_memcpy(b[i],E[i],bytes_per_row);
00299 }
00300
00301
00302 MBL_LOG(DEBUG, logger(), "basis matrix before normalization:");
00303 basis_.print(logger().log(mbl_logger::DEBUG));
00304 basis_.normalize_columns();
00305 MBL_LOG(DEBUG, logger(), "basis matrix after normalization:");
00306 basis_.print(logger().log(mbl_logger::DEBUG));
00307
00308
00309 evals_.set_size(t);
00310 for (int i=0;i<t;++i)
00311 evals_[i] = evals[i];
00312
00313
00314 d_m_mean_.set_size(t);
00315 mbl_matxvec_prod_vm(mean_class_mean_,basis_,d_m_mean_);
00316
00317
00318 d_mean_.resize(n_size);
00319 for (int i=0;i<n_size;++i)
00320 if (n_samples_[i]>0)
00321 x_to_d(d_mean_[i],mean_[i]);
00322 }
00323
00324
00325
00326
00327 void mbl_lda::build(const vnl_vector<double>* v, const int* label, int n)
00328 {
00329 build(v,label,n,vnl_matrix<double>(),true);
00330 }
00331
00332
00333
00334 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label)
00335 {
00336 build(v,&label.front(),label.size(),vnl_matrix<double>(),true);
00337 }
00338
00339
00340
00341 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label,
00342 const vnl_matrix<double>& wS)
00343 {
00344 build(v,&label.front(),label.size(),wS,false);
00345 }
00346
00347
00348
00349 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label)
00350 {
00351 assert(v.size()==label.size());
00352 build(&v.front(),&label.front(),label.size(),vnl_matrix<double>(),true);
00353 }
00354
00355
00356
00357 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
00358 const vnl_matrix<double>& wS)
00359 {
00360 assert(v.size()==label.size());
00361 build(&v.front(),&label.front(),label.size(),wS,false);
00362 }
00363
00364
00365
00366
00367
00368
00369 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label)
00370 {
00371 unsigned int n_egs = M.columns();
00372 assert(n_egs==label.size());
00373
00374 vcl_vector<vnl_vector<double> > v(n_egs);
00375 for (unsigned int i=0;i<n_egs;++i)
00376 {
00377 v[i] = M.get_column(i);
00378 }
00379 build(&v.front(),&label.front(),n_egs,vnl_matrix<double>(),true);
00380 }
00381
00382
00383
00384
00385
00386
00387 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
00388 const vnl_matrix<double>& wS)
00389 {
00390 unsigned int n_egs = M.columns();
00391 assert(n_egs==label.size());
00392
00393 vcl_vector<vnl_vector<double> > v(n_egs);
00394 for (unsigned int i=0;i<n_egs;++i)
00395 {
00396 v[i] = M.get_column(i);
00397 }
00398 build(&v.front(),&label.front(),n_egs,wS,false);
00399 }
00400
00401
00402
00403
00404 void mbl_lda::x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const
00405 {
00406 d.set_size(d_m_mean_.size());
00407 mbl_matxvec_prod_vm(x,basis_,d);
00408 d-=d_m_mean_;
00409 }
00410
00411
00412
00413 void mbl_lda::d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const
00414 {
00415 mbl_matxvec_prod_mv(basis_,d,x);
00416 x+=mean_class_mean_;
00417 }
00418
00419
00420
00421 short mbl_lda::version_no() const
00422 {
00423 return 1;
00424 }
00425
00426
00427
00428 vcl_string mbl_lda::is_a() const
00429 {
00430 return vcl_string("mbl_lda");
00431 }
00432
00433 bool mbl_lda::is_class(vcl_string const& s) const
00434 {
00435 return s==is_a();
00436 }
00437
00438
00439
00440 void mbl_lda::print_summary(vcl_ostream& os) const
00441 {
00442 int n_classes= n_samples_.size();
00443 os << "n_classes= "<<n_classes<<'\n';
00444 for (int i=0; i<n_classes; ++i)
00445 {
00446 os <<"n_samples_["<<i<<"]= "<<n_samples_[i]<<'\n'
00447 <<"mean_["<<i<<"]= "<<mean_[i]<<'\n'
00448 <<"d_mean_["<<i<<"]= "<<d_mean_[i]<<'\n';
00449 }
00450
00451 os << "withinS_= "<<withinS_<<'\n'
00452 << "betweenS_= "<<betweenS_<<'\n'
00453 << "basis_= "<<basis_<<'\n'
00454 << "evals_= "<<evals_<<'\n'
00455 << "d_m_mean_= "<<d_m_mean_<<'\n';
00456 }
00457
00458
00459
00460 void mbl_lda::b_write(vsl_b_ostream& bfs) const
00461 {
00462 vsl_b_write(bfs,version_no());
00463 vsl_b_write(bfs,mean_);
00464 vsl_b_write(bfs,d_mean_);
00465 vsl_b_write(bfs,mean_class_mean_);
00466 vsl_b_write(bfs,n_samples_);
00467 vsl_b_write(bfs,withinS_);
00468 vsl_b_write(bfs,betweenS_);
00469 vsl_b_write(bfs,basis_);
00470 vsl_b_write(bfs,evals_);
00471 vsl_b_write(bfs,d_m_mean_);
00472 }
00473
00474
00475
00476 void mbl_lda::b_read(vsl_b_istream& bfs)
00477 {
00478 if (!bfs) return;
00479
00480 short version;
00481 vsl_b_read(bfs,version);
00482 switch (version)
00483 {
00484 case (1):
00485 vsl_b_read(bfs,mean_);
00486 vsl_b_read(bfs,d_mean_);
00487 vsl_b_read(bfs,mean_class_mean_);
00488 vsl_b_read(bfs,n_samples_);
00489 vsl_b_read(bfs,withinS_);
00490 vsl_b_read(bfs,betweenS_);
00491 vsl_b_read(bfs,basis_);
00492 vsl_b_read(bfs,evals_);
00493 vsl_b_read(bfs,d_m_mean_);
00494 break;
00495 default:
00496
00497 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_lda &)\n"
00498 << " Unknown version number "<< version << vcl_endl;
00499 bfs.is().clear(vcl_ios::badbit);
00500 return;
00501 }
00502 }
00503
00504
00505
00506 void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b)
00507 {
00508 b.b_write(bfs);
00509 }
00510
00511
00512
00513 void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b)
00514 {
00515 b.b_read(bfs);
00516 }
00517
00518
00519
00520 vcl_ostream& operator<<(vcl_ostream& os,const mbl_lda& b)
00521 {
00522 os << b.is_a() << ": ";
00523 vsl_indent_inc(os);
00524 b.print_summary(os);
00525 vsl_indent_dec(os);
00526 return os;
00527 }