contrib/mul/mbl/mbl_lda.cxx
Go to the documentation of this file.
00001 // This is mul/mbl/mbl_lda.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief  Class to perform linear discriminant analysis
00008 // \author Tim Cootes
00009 //         Converted to VXL by Gavin Wheeler
00010 
00011 #include "mbl_lda.h"
00012 
00013 #include <vcl_algorithm.h>  // for vcl_find
00014 #include <vcl_cassert.h>
00015 #include <vcl_cstddef.h> // for size_t
00016 #include <vcl_cstring.h> // for memcpy()
00017 #include <vsl/vsl_indent.h>
00018 #include <vsl/vsl_vector_io.h>
00019 #include <vsl/vsl_binary_io.h>
00020 #include <vnl/algo/vnl_svd.h>
00021 #include <vnl/algo/vnl_symmetric_eigensystem.h>
00022 #include <vnl/algo/vnl_generalized_eigensystem.h>
00023 #include <vnl/io/vnl_io_vector.h>
00024 #include <mbl/mbl_matxvec.h>
00025 #include <mbl/mbl_log.h>
00026 #include <mbl/mbl_exception.h>
00027 
00028 
00029 //=========================================================================
00030 // Static function to create a static logger when first required
00031 //=========================================================================
00032 static mbl_logger& logger()
00033 {
00034   static mbl_logger l("mul.mbl.lda");
00035   return l;
00036 }
00037 
00038 
00039 //=======================================================================
00040 mbl_lda::mbl_lda()
00041 {
00042 }
00043 
00044 
00045 //=======================================================================
00046 mbl_lda::~mbl_lda()
00047 {
00048 }
00049 
00050 
00051 //=======================================================================
00052 //: Classify a new data point.
00053 // Projects into discriminant space and picks closest mean class vector
00054 int mbl_lda::classify( const vnl_vector<double>& x )
00055 {
00056   vnl_vector<double> d;
00057   x_to_d(d, x);
00058   int nc=n_classes();
00059   double min_d=(d-d_class_mean(0)).squared_magnitude();
00060   int min_i=0;
00061   for (int i=1; i<nc; ++i)
00062   {
00063     double dist=(d-d_class_mean(i)).squared_magnitude();
00064     if (dist<min_d ) { min_d= dist; min_i=i; }
00065   }
00066   return min_i;
00067 }
00068 
00069 
00070 //=======================================================================
00071 //: Comparison
00072 bool mbl_lda::operator==(const mbl_lda& that) const
00073 {
00074   return mean_ == that.mean_ &&
00075          d_mean_ == that.d_mean_ &&
00076          mean_class_mean_ == that.mean_class_mean_ &&
00077          n_samples_ == that.n_samples_ &&
00078          withinS_ == that.withinS_ &&
00079          betweenS_ == that.betweenS_ &&
00080          basis_ == that.basis_ &&
00081          evals_ == that.evals_ &&
00082          d_m_mean_ == that.d_m_mean_;
00083 }
00084 
00085 
00086 //=======================================================================
00087 void mbl_lda::updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& V)
00088 {
00089   unsigned int n = V.size();
00090   if (S.rows()!=n)
00091   {
00092     S.set_size(n,n);
00093     S.fill(0);
00094   }
00095 
00096   double** s = S.data_array();
00097   const double* v = V.data_block();
00098   for (unsigned int i=0;i<n;++i)
00099   {
00100     double *row = s[i];
00101     double vi = v[i];
00102     for (unsigned int j=0;j<n;++j)
00103       row[j] += vi*v[j];
00104   }
00105 }
00106 
00107 
00108 //=======================================================================
00109 // find out how many id in the label vector
00110 int mbl_lda::nDistinctIDs(const int* id, const int n)
00111 {
00112   vcl_vector<int> dids;
00113   for (int i=0;i<n;++i)
00114   {
00115     if (vcl_find(dids.begin(), dids.end(), id[i])==dids.end())  // if (Index(dids,id[i])<0)
00116       dids.push_back(id[i]);
00117   }
00118 
00119   return dids.size();
00120 }
00121 
00122 
00123 //=======================================================================
00124 //: Perform LDA on data
00125 // \param label  Array [0..n-1] of integers indices
00126 // \param v  Set of vectors [0..n-1]
00127 //
00128 // label[i] gives class of v[i]
00129 // Classes must be labeled from 0 to m-1
00130 void mbl_lda::build(const vnl_vector<double>* v, const int * label, int n,
00131                     const vnl_matrix<double>& wS, bool compute_wS)
00132 {
00133   // Find range of class indices and count #valid
00134   int lo_i=label[0]; // =n causes failure if lo_i is less than n
00135   int hi_i=-1;
00136   int n_valid = 0;
00137   for (int i=0;i<n;++i)
00138   {
00139     if (label[i]>=0)
00140     {
00141       if (label[i]<lo_i) lo_i=label[i];
00142       if (label[i]>hi_i) hi_i=label[i];
00143       n_valid++;
00144     }
00145   }
00146 
00147   //  assert(lo_i==0);
00148 
00149   // Compute mean of each class
00150   int n_classes = nDistinctIDs(label,n);
00151   MBL_LOG(INFO, logger(), "There are " <<n_classes << " classes to build LDA space");
00152   MBL_LOG(INFO, logger(), "Max label index is " << hi_i);
00153   MBL_LOG(INFO, logger(), "Min label index is " << lo_i);
00154 
00155   int n_size=hi_i+1;
00156   mean_.resize(n_size);
00157   n_samples_.resize(n_size);
00158   for (int i=0;i<n_size;++i)
00159     n_samples_[i]=0;
00160 
00161   for (int i=0;i<n;++i)
00162   {
00163     int l = label[i];
00164     if (l<0) continue;
00165     if (mean_[l].size()==0)
00166     {
00167       mean_[l] = v[i];
00168       n_samples_[l] = 1;
00169     }
00170     else
00171     {
00172       mean_[l] += v[i];
00173       n_samples_[l] += 1;
00174     }
00175   }
00176 
00177   int n_used_classes = 0;
00178   for (int i=0;i<n_size;++i)
00179   {
00180     if (n_samples_[i]>0)
00181     {
00182       mean_[i]/=n_samples_[i];
00183       if (i==lo_i) mean_class_mean_ = mean_[i];
00184       else      mean_class_mean_ += mean_[i];
00185       n_used_classes++;
00186     }
00187   }
00188   MBL_LOG(INFO, logger(), "Number of used classes: " << n_used_classes);
00189 
00190   mean_class_mean_/=n_used_classes;
00191 
00192   // Build between class covariance
00193   // Zero to start:
00194   betweenS_.set_size(0,0);
00195 
00196   for (int i=0;i<n_size;++i)
00197   {
00198     if (n_samples_[i]>0)
00199       updateCovar(betweenS_,mean_[i] - mean_class_mean_);
00200   }
00201 
00202   betweenS_/=n_used_classes;
00203 
00204   if (compute_wS)
00205   {
00206     withinS_.set_size(0,0);
00207     // Count number of samples used to build matrix
00208     int n_used=0;
00209     for (int i=0;i<n;++i)
00210     {
00211       int l=label[i];
00212       if (l>=0 && n_samples_[l]>1)
00213       {
00214         updateCovar(withinS_,v[i]-mean_[l]);
00215         n_used++;
00216       }
00217     }
00218     withinS_/=n_used;
00219   }
00220   else
00221     withinS_ = wS;
00222 
00223 #if 0
00224   vnl_matrix<double> wS_inv;
00225   //  NR_Inverse(wS_inv,withinS_);
00226   vnl_svd<double> wS_svd(withinS_, -1.0e-10); // important!!! as the sigma_min=0.0
00227 
00228   wS_inv = wS_svd.inverse();
00229 
00230   vnl_matrix<double> B=withinS_*wS_inv;
00231   vcl_cout<<B<<vcl_endl;
00232 
00233   vnl_matrix<double> A = wS_inv* betweenS_; // was: betweenS_ * wS_inv;
00234 
00235   // Compute eigenvectors and eigenvalues (descending order)
00236   vnl_matrix<double> EVecs(A.rows(), A.columns());
00237   vnl_vector<double> evals(A.columns());
00238   //  NR_CalcSymEigens(A,EVecs,evals,false);
00239 
00240   // **** A not necessarily symmetric!!!! ****
00241   vnl_symmetric_eigensystem_compute(A, EVecs, evals);
00242 #endif // 0
00243 
00244   vnl_generalized_eigensystem gen_eigs(betweenS_,withinS_);
00245   vnl_matrix<double> EVecs= gen_eigs.V;
00246   vnl_vector<double> evals= gen_eigs.D.diagonal();
00247 
00248   // Log some information that might be helpful for debugging
00249   if (logger().level()>=mbl_logger::DEBUG)
00250   {
00251     MBL_LOG(DEBUG, logger(), "eigen decomp in original order:");
00252     unsigned nvec = EVecs.cols();
00253     for (unsigned i=0; i<nvec; ++i)
00254       MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00255               << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00256     for (unsigned i=0; i<nvec; ++i)
00257       MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00258   }
00259 
00260   // Re-arrange the eigenvector matrix (columns) and eigenvalue vector into descending order.
00261   // Assume they are in order of increasing eigenvalue magnitude.
00262   // NB The output from vnl_generalized_eigensystem above will be in order of
00263   // increasing (signed) eigenvalue, not magnitude. If we ever get negative eigenvalues,
00264   // then the simple reversal of flip() and fliplr() will not be correct.
00265   // Not sure whether we could get (significant) negative eigenvalues, but let's check.
00266   for (unsigned i=0; i<evals.size(); ++i)
00267   {
00268     if (evals[i]<-1e-12) // tolerance?
00269       throw mbl_exception_abort("mbl_lda::build(): found negative eigenvalue(s)");
00270   }
00271   evals.flip();
00272   EVecs.fliplr();
00273 
00274   // Log some information that might be helpful for debugging
00275   if (logger().level()>=mbl_logger::DEBUG)
00276   {
00277     MBL_LOG(DEBUG, logger(), "eigen decomp in sorted order:");
00278     unsigned nvec = EVecs.cols();
00279     for (unsigned i=0; i<nvec; ++i)
00280       MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00281               << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00282     for (unsigned i=0; i<nvec; ++i)
00283       MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00284   }
00285 
00286   // Record n_classes-1 vector basis
00287   int m = EVecs.rows();
00288   int t = n_used_classes-1;
00289   if (t>m) t=m;
00290 
00291   // Copy first t eigenvectors to basis_
00292   basis_.set_size(m,t);
00293   double **E = EVecs.data_array();
00294   double **b = basis_.data_array();
00295   vcl_size_t bytes_per_row = t * sizeof(double);
00296   for (int i=0;i<m;++i)
00297   {
00298     vcl_memcpy(b[i],E[i],bytes_per_row);
00299   }
00300 
00301   // Normalize the basis vectors
00302   MBL_LOG(DEBUG, logger(), "basis matrix before normalization:");
00303   basis_.print(logger().log(mbl_logger::DEBUG));
00304   basis_.normalize_columns();
00305   MBL_LOG(DEBUG, logger(), "basis matrix after normalization:");
00306   basis_.print(logger().log(mbl_logger::DEBUG));
00307 
00308   // Copy first t eigenvalues
00309   evals_.set_size(t);
00310   for (int i=0;i<t;++i)
00311     evals_[i] = evals[i];
00312 
00313   // Compute projection of mean into d space
00314   d_m_mean_.set_size(t);
00315   mbl_matxvec_prod_vm(mean_class_mean_,basis_,d_m_mean_);
00316 
00317   // Project each mean into d-space
00318   d_mean_.resize(n_size);
00319   for (int i=0;i<n_size;++i)
00320     if (n_samples_[i]>0)
00321       x_to_d(d_mean_[i],mean_[i]);
00322 }
00323 
00324 
00325 //=======================================================================
00326 //: Perform LDA on data
00327 void mbl_lda::build(const vnl_vector<double>* v, const int* label, int n)
00328 {
00329   build(v,label,n,vnl_matrix<double>(),true);
00330 }
00331 
00332 //=======================================================================
00333 //: Perform LDA on data
00334 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label)
00335 {
00336   build(v,&label.front(),label.size(),vnl_matrix<double>(),true);
00337 }
00338 
00339 //=======================================================================
00340 //: Perform LDA on data
00341 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label,
00342                     const vnl_matrix<double>& wS)
00343 {
00344   build(v,&label.front(),label.size(),wS,false);
00345 }
00346 
00347 //=======================================================================
00348 //: Perform LDA on data
00349 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label)
00350 {
00351   assert(v.size()==label.size());
00352   build(&v.front(),&label.front(),label.size(),vnl_matrix<double>(),true);
00353 }
00354 
00355 //=======================================================================
00356 //: Perform LDA on data
00357 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
00358                     const vnl_matrix<double>& wS)
00359 {
00360   assert(v.size()==label.size());
00361   build(&v.front(),&label.front(),label.size(),wS,false);
00362 }
00363 
00364 //=======================================================================
00365 //: Perform LDA on data
00366 //  Columns of M form example vectors
00367 //  i'th column belongs to class label[i]
00368 //  Note: label([1..n]) not label([0..n-1])
00369 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label)
00370 {
00371   unsigned int n_egs = M.columns();
00372   assert(n_egs==label.size());
00373   //  assert(label.lo()==1);
00374   vcl_vector<vnl_vector<double> > v(n_egs);
00375   for (unsigned int i=0;i<n_egs;++i)
00376   {
00377     v[i] = M.get_column(i);
00378   }
00379   build(&v.front(),&label.front(),n_egs,vnl_matrix<double>(),true);
00380 }
00381 
00382 //=======================================================================
00383 //: Perform LDA on data
00384 //  Columns of M form example vectors
00385 //  i'th column belongs to class label[i]
00386 //  Note: label([1..n]) not label([0..n-1])
00387 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
00388                     const vnl_matrix<double>& wS)
00389 {
00390   unsigned int n_egs = M.columns();
00391   assert(n_egs==label.size());
00392   //  assert(label.lo()==1);
00393   vcl_vector<vnl_vector<double> > v(n_egs);
00394   for (unsigned int i=0;i<n_egs;++i)
00395   {
00396     v[i] = M.get_column(i);
00397   }
00398   build(&v.front(),&label.front(),n_egs,wS,false);
00399 }
00400 
00401 
00402 //=======================================================================
00403 //: Project x into discriminant space
00404 void mbl_lda::x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const
00405 {
00406   d.set_size(d_m_mean_.size());
00407   mbl_matxvec_prod_vm(x,basis_,d); // d = x' * M
00408   d-=d_m_mean_;
00409 }
00410 
00411 //=======================================================================
00412 //: Project d from discriminant space into original space
00413 void mbl_lda::d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const
00414 {
00415   mbl_matxvec_prod_mv(basis_,d,x); // x = M * d
00416   x+=mean_class_mean_;
00417 }
00418 
00419 //=======================================================================
00420 
00421 short mbl_lda::version_no() const
00422 {
00423   return 1;
00424 }
00425 
00426 //=======================================================================
00427 
00428 vcl_string mbl_lda::is_a() const
00429 {
00430   return vcl_string("mbl_lda");
00431 }
00432 
00433 bool mbl_lda::is_class(vcl_string const& s) const
00434 {
00435   return s==is_a();
00436 }
00437 
00438 //=======================================================================
00439 
00440 void mbl_lda::print_summary(vcl_ostream& os) const
00441 {
00442   int n_classes= n_samples_.size();
00443   os << "n_classes= "<<n_classes<<'\n';
00444   for (int i=0; i<n_classes; ++i)
00445   {
00446     os <<"n_samples_["<<i<<"]= "<<n_samples_[i]<<'\n'
00447        <<"mean_["<<i<<"]= "<<mean_[i]<<'\n'
00448        <<"d_mean_["<<i<<"]= "<<d_mean_[i]<<'\n';
00449   }
00450 
00451   os << "withinS_= "<<withinS_<<'\n'
00452      << "betweenS_= "<<betweenS_<<'\n'
00453      << "basis_= "<<basis_<<'\n'
00454      << "evals_= "<<evals_<<'\n'
00455      << "d_m_mean_= "<<d_m_mean_<<'\n';
00456 }
00457 
00458 //=======================================================================
00459 
00460 void mbl_lda::b_write(vsl_b_ostream& bfs) const
00461 {
00462   vsl_b_write(bfs,version_no());
00463   vsl_b_write(bfs,mean_);
00464   vsl_b_write(bfs,d_mean_);
00465   vsl_b_write(bfs,mean_class_mean_);
00466   vsl_b_write(bfs,n_samples_);
00467   vsl_b_write(bfs,withinS_);
00468   vsl_b_write(bfs,betweenS_);
00469   vsl_b_write(bfs,basis_);
00470   vsl_b_write(bfs,evals_);
00471   vsl_b_write(bfs,d_m_mean_);
00472 }
00473 
00474 //=======================================================================
00475 
00476 void mbl_lda::b_read(vsl_b_istream& bfs)
00477 {
00478   if (!bfs) return;
00479 
00480   short version;
00481   vsl_b_read(bfs,version);
00482   switch (version)
00483   {
00484     case (1):
00485       vsl_b_read(bfs,mean_);
00486       vsl_b_read(bfs,d_mean_);
00487       vsl_b_read(bfs,mean_class_mean_);
00488       vsl_b_read(bfs,n_samples_);
00489       vsl_b_read(bfs,withinS_);
00490       vsl_b_read(bfs,betweenS_);
00491       vsl_b_read(bfs,basis_);
00492       vsl_b_read(bfs,evals_);
00493       vsl_b_read(bfs,d_m_mean_);
00494       break;
00495     default:
00496       // CHECK FUNCTION SIGNATURE IS CORRECT
00497       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_lda &)\n"
00498                << "           Unknown version number "<< version << vcl_endl;
00499       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00500       return;
00501   }
00502 }
00503 
00504 //=======================================================================
00505 
00506 void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b)
00507 {
00508   b.b_write(bfs);
00509 }
00510 
00511 //=======================================================================
00512 
00513 void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b)
00514 {
00515   b.b_read(bfs);
00516 }
00517 
00518 //=======================================================================
00519 
00520 vcl_ostream& operator<<(vcl_ostream& os,const mbl_lda& b)
00521 {
00522   os << b.is_a() << ": ";
00523   vsl_indent_inc(os);
00524   b.print_summary(os);
00525   vsl_indent_dec(os);
00526   return os;
00527 }