contrib/mul/vpdfl/vpdfl_mixture_builder.cxx
Go to the documentation of this file.
00001 // This is mul/vpdfl/vpdfl_mixture_builder.cxx
00002 //=======================================================================
00003 //
00004 //  Copyright: (C) 2000 Victoria University of Manchester
00005 //
00006 //=======================================================================
00007 #include "vpdfl_mixture_builder.h"
00008 //:
00009 // \file
00010 // \brief Implements builder for a mixture model PDF.
00011 // \author Tim Cootes
00012 // \date 21-July-98
00013 //
00014 // Modifications
00015 // \verbatim
00016 //    IMS   Converted to VXL 14 May 2000, with redesign
00017 // \endverbatim
00018 
00019 #include <vcl_cassert.h>
00020 #include <vcl_cmath.h>
00021 #include <vcl_cstdlib.h> // for vcl_abort()
00022 #include <vsl/vsl_indent.h>
00023 #include <vsl/vsl_vector_io.h>
00024 #include <vsl/vsl_binary_loader.h>
00025 #include <vpdfl/vpdfl_mixture.h>
00026 #include <mbl/mbl_data_wrapper.h>
00027 #include <mbl/mbl_data_array_wrapper.h>
00028 #include <vnl/vnl_math.h>
00029 
00030 #include <mbl/mbl_parse_block.h>
00031 #include <mbl/mbl_read_props.h>
00032 #include <vul/vul_string.h>
00033 #include <mbl/mbl_exception.h>
00034 
// Mixture component weights smaller than this are treated as zero:
// m_step() sets such weights to 0.0, after which the component is skipped.
static const double min_wt = 1e-8;
00037 
00038 //=======================================================================
00039 void vpdfl_mixture_builder::init()
00040 {
00041   min_var_ = 1.0e-6;
00042   max_its_ = 10;
00043   weights_fixed_ = false;
00044   initial_means_.clear();
00045 }
00046 
00047 //=======================================================================
00048 
00049 vpdfl_mixture_builder::vpdfl_mixture_builder()
00050 {
00051   init();
00052 }
00053 
00054 //=======================================================================
00055 
00056 vpdfl_mixture_builder::vpdfl_mixture_builder(const vpdfl_mixture_builder& b):
00057   vpdfl_builder_base()
00058 {
00059   init();
00060   *this = b;
00061 }
00062 
00063 //=======================================================================
00064 
00065 vpdfl_mixture_builder& vpdfl_mixture_builder::operator=(const vpdfl_mixture_builder& b)
00066 {
00067   if (&b==this) return *this;
00068 
00069   delete_stuff();
00070 
00071   unsigned int n = b.builder_.size();
00072   builder_.resize(n);
00073   for (unsigned int i=0;i<n;++i)
00074     builder_[i] = b.builder_[i]->clone();
00075 
00076   min_var_ = b.min_var_;
00077   max_its_ = b.max_its_;
00078   weights_fixed_ = b.weights_fixed_;
00079   initial_means_ = b.initial_means_;
00080   
00081   return *this;
00082 }
00083 
00084 //=======================================================================
00085 
00086 void vpdfl_mixture_builder::delete_stuff()
00087 {
00088   unsigned int n = builder_.size();
00089   for (unsigned int i=0;i<n;++i)
00090     delete builder_[i];
00091   builder_.resize(0);
00092   initial_means_.clear();
00093 }
00094 
00095 vpdfl_mixture_builder::~vpdfl_mixture_builder()
00096 {
00097   delete_stuff();
00098 }
00099 
00100 //=======================================================================
00101 
00102 //: Initialise n builders of type builder
00103 //  Clone taken of builder
00104 void vpdfl_mixture_builder::init(const vpdfl_builder_base& builder, int n)
00105 {
00106   delete_stuff();
00107   builder_.resize(n);
00108   for (int i=0;i<n;++i)
00109     builder_[i] = builder.clone();
00110 }
00111 
00112 //=======================================================================
00113 
00114 //: Define maximum number of EM iterations allowed
00115 void vpdfl_mixture_builder::set_max_iterations(int n)
00116 {
00117   max_its_ = n;
00118 }
00119 //: Define whether weights on components can change or not
00120 void vpdfl_mixture_builder::set_weights_fixed(bool b)
00121 {
00122   weights_fixed_ = b;
00123 }
00124 
00125 //=======================================================================
00126 
00127 //: Create empty model
00128 vpdfl_pdf_base* vpdfl_mixture_builder::new_model() const
00129 {
00130   return new vpdfl_mixture;
00131 }
00132 
00133 //=======================================================================
00134 
00135 //: Define lower threshold on variance for built models
00136 void vpdfl_mixture_builder::set_min_var(double min_var)
00137 {
00138   min_var_ = min_var;
00139 }
00140 
00141 //=======================================================================
00142 
00143 //: Get lower threshold on variance for built models
00144 double vpdfl_mixture_builder::min_var() const
00145 {
00146   return min_var_;
00147 }
00148 
00149 //=======================================================================
00150 
00151 //: Build default model with given mean
00152 void vpdfl_mixture_builder::build(vpdfl_pdf_base& /*model*/,
00153                                   const vnl_vector<double>& /*mean*/) const
00154 {
00155   vcl_cerr<<"vpdfl_mixture_builder::build(model,mean) Not yet implemented.\n";
00156   vcl_abort();
00157 }
00158 
00159 //=======================================================================
00160 
00161 //: Build model from data
00162 void vpdfl_mixture_builder::build(vpdfl_pdf_base& model,
00163                                   mbl_data_wrapper<vnl_vector<double> >& data) const
00164 {
00165   vcl_vector<double> wts(int(data.size()), 1.0);
00166   weighted_build(model,data,wts);
00167 }
00168 
00169 //=======================================================================
00170 
00171 //: Build model from weighted data
00172 void vpdfl_mixture_builder::weighted_build(vpdfl_pdf_base& base_model,
00173                                            mbl_data_wrapper<vnl_vector<double> >& data,
00174                                            const vcl_vector<double>& wts) const
00175 {
00176   assert(base_model.is_class("vpdfl_mixture"));
00177   vpdfl_mixture& model = static_cast<vpdfl_mixture&>( base_model);
00178 
00179   unsigned int n = builder_.size();
00180 
00181   bool model_setup = (model.n_components()==n);
00182 
00183   if (!model_setup)
00184   {
00185     // Create default model components
00186     model.clear();
00187     model.components().resize(n);
00188     model.weights().resize(n);
00189     for (unsigned int i=0;i<n;++i)
00190     {
00191       builder_[i]->set_min_var(min_var_);
00192       model.components()[i] = builder_[i]->new_model();
00193       model.weights()[i]      = 1.0/n;
00194     }
00195   }
00196 
00197   // Get vectors into an array for rapid access
00198   const vnl_vector<double>* data_ptr;
00199   vcl_vector<vnl_vector<double> > data_array;
00200 
00201   {
00202     unsigned int n=data.size();
00203     data.reset();
00204     data_array.resize(n);
00205     for (unsigned int i=0;i<n;++i)
00206     {
00207       data_array[i] = data.current();
00208       data.next();
00209     }
00210 
00211     data_ptr = &data_array[0]/*.begin()*/;
00212   }
00213 
00214   if (!model_setup || !initial_means_.empty())
00215     initialise(model,data_ptr,wts);
00216 
00217   vcl_vector<vnl_vector<double> > probs;
00218 
00219   int n_its = 0;
00220   double max_move = 1e-6;
00221   double move = max_move+1;
00222   while (move>max_move && n_its<max_its_)
00223   {
00224     e_step(model,probs,data_ptr,wts);
00225     move = m_step(model,probs,data_ptr,wts);
00226     n_its++;
00227   }
00228   calc_mean_and_variance(model);
00229   assert(model.is_valid_pdf());
00230 }
00231 
00232 static void UpdateRange(vnl_vector<double>& min_vec, vnl_vector<double>& max_vec, const vnl_vector<double>& vec)
00233 {
00234   unsigned int n=vec.size();
00235   for (unsigned int i=0;i<n;++i)
00236   {
00237     if (vec(i)<min_vec(i))
00238       min_vec(i)=vec(i);
00239     else
00240     if (vec(i)>max_vec(i))
00241       max_vec(i)=vec(i);
00242   }
00243 }
00244 
00245 //: Assumes means set up.  Estimates starting components.
00246 void vpdfl_mixture_builder::initialise_given_means(vpdfl_mixture& model,
00247                                                    const vnl_vector<double>* data,
00248                                                    const vcl_vector<vnl_vector<double> >& mean,
00249                                                    const vcl_vector<double>& wts) const
00250 {
00251   const unsigned int n_comp = builder_.size();
00252   const unsigned int n_samples = wts.size();
00253 
00254   // Compute range of data
00255   vnl_vector<double> min_v(mean[0]);
00256   vnl_vector<double> max_v(min_v);
00257   for (unsigned int i=1;i<n_comp;++i)
00258     UpdateRange(min_v,max_v,mean[i]);
00259 
00260   double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00261   if (mean_sep<=1e-6) mean_sep = 1e-6;
00262 
00263 
00264   vcl_vector<double> wts_i(n_samples);
00265 
00266   mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_samples);
00267 
00268   for (unsigned int i=0;i<n_comp;++i)
00269   {
00270     // Compute weights proportional to inverse square to the mean
00271     double w_sum = 0.0;
00272     for (unsigned int j=0;j<n_samples;++j)
00273     {
00274       wts_i[j] = wts[j]*mean_sep/(mean_sep+ vnl_vector_ssd(data[j], mean[i]));
00275       w_sum+=wts_i[j];
00276     }
00277 
00278     // Normalise so weights add to n_samples/n_comp
00279     double f = n_samples/(n_comp*w_sum);
00280     for (unsigned int j=0;j<n_samples;++j)
00281       wts_i[j]*=f;
00282 
00283     // Build i'th component, biasing data toward mean(i)
00284     builder_[i]->weighted_build(*(model.components()[i]),data_array,wts_i);
00285   }
00286 }
00287 
00288 //=======================================================================
00289 
00290 void vpdfl_mixture_builder::initialise_diagonal(vpdfl_mixture& model,
00291                                                 const vnl_vector<double>* data,
00292                                                 const vcl_vector<double>& wts) const
00293 {
00294   // Build each component using randomly weighted data
00295   const unsigned int n_comp = builder_.size();
00296   const unsigned int n_samples = wts.size();
00297 
00298   // Compute range of data
00299   vnl_vector<double> min_v(data[0]);
00300   vnl_vector<double> max_v(min_v);
00301   for (unsigned int i=1;i<n_samples;++i)
00302     UpdateRange(min_v,max_v,data[i]);
00303 
00304 #if 0 // unused variable
00305   double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00306 #endif
00307 
00308   // Create means along diagonal of bounding box
00309   vcl_vector<vnl_vector<double> > mean(n_comp);
00310   for (unsigned int i=0;i<n_comp;++i)
00311   {
00312     double f = (i+1.0)/(n_comp+1);
00313     mean[i] = (1-f)*min_v + f*max_v;
00314   }
00315 
00316   initialise_given_means(model,data,mean,wts);
00317 }
00318 
00319 //=======================================================================
00320 
00321 void vpdfl_mixture_builder::initialise_to_regular_samples(vpdfl_mixture& model,
00322                                                           const vnl_vector<double>* data,
00323                                                           const vcl_vector<double>& wts) const
00324 {
00325   // Build each component using randomly weighted data
00326   const unsigned int n_comp = builder_.size();
00327   const unsigned int n_samples = wts.size();
00328 
00329   double f = double(n_samples)/n_comp;
00330 
00331   // Select means from data
00332   vcl_vector<vnl_vector<double> > mean(n_comp);
00333   for (unsigned int i=0;i<n_comp;++i)
00334   {
00335     unsigned int j = vnl_math_rnd((i+0.5)*f); // must not be negative!
00336     if (j>=n_samples) j=n_samples-1;
00337       mean[i] = data[j];
00338   }
00339 
00340   initialise_given_means(model,data,mean,wts);
00341 }
00342 
00343 void vpdfl_mixture_builder::initialise(vpdfl_mixture& model,
00344           const vnl_vector<double>* data,
00345           const vcl_vector<double>& wts) const
00346 {
00347   // Later add a switch to decide on how to initialise
00348   if(!initial_means_.empty() )
00349   {
00350     initialise_given_means(model,data,initial_means_,wts);
00351   }
00352   else
00353   {
00354     initialise_to_regular_samples(model,data,wts);
00355   }
00356 }
00357 
00358 void vpdfl_mixture_builder::preset_initial_means(const vcl_vector<vnl_vector<double> >& component_means)
00359 {
00360     initial_means_ = component_means;
00361 }
00362 //=======================================================================
00363 
00364 void vpdfl_mixture_builder::e_step(vpdfl_mixture& model,
00365                                    vcl_vector<vnl_vector<double> >& probs,
00366                                    const vnl_vector<double>* data,
00367                                    const vcl_vector<double>& wts) const
00368 {
00369   const unsigned int n_comp = builder_.size();
00370   const unsigned int n_egs = wts.size();
00371   const vcl_vector<double>& m_wts = model.weights();
00372 
00373   if (probs.size()!=n_comp) probs.resize(n_comp);
00374 
00375   // Compute log probs
00376   // probs(i)(j+1) is logProb that e.g. j was drawn from component i
00377   for (unsigned int i=0;i<n_comp;++i)
00378   {
00379     if (probs[i].size()!=n_egs) probs[i].set_size(n_egs);
00380 
00381   // Any components with zero weights are ignored.
00382   // Eventually they should be pruned.
00383     if (m_wts[i]<=0) continue;
00384 
00385     double *p_data = probs[i].begin();
00386 
00387     double log_wt_i = vcl_log(m_wts[i]);
00388 
00389     for (unsigned int j=0;j<n_egs;++j)
00390     {
00391       p_data[j] = log_wt_i+model.components()[i]->log_p(data[j]);
00392     }
00393   }
00394 
00395   // Turn into probabilities and normalise.
00396   // Normalise so that sum_i probs(i)(j) = 1.0;
00397   for (unsigned int j=0;j<n_egs;++j)
00398   {
00399     // To minimise rounding errors, first find largest value
00400     double max_log_p=0;
00401     for (unsigned int i=0;i<n_comp;++i)
00402     {
00403       if (m_wts[i]<=0) continue;
00404       if (i==0 || probs[i](j)>max_log_p) max_log_p = probs[i](j);
00405     }
00406 
00407     // Turn into probabilities and sum
00408     double sum = 0.0;
00409     for (unsigned int i=0;i<n_comp;++i)
00410     {
00411       if (m_wts[i]<=0) continue;
00412       double p = vcl_exp(probs[i](j)-max_log_p);
00413       probs[i](j) = p;
00414       sum+=p;
00415     }
00416 
00417     // Divide through by sum to normalise
00418     if (sum>0.0)
00419       for (unsigned int i=0;i<n_comp;++i)
00420         probs[i](j)/=sum;
00421 
00422     if (sum<=0)
00423       vcl_cerr<<"vpdfl_mixture_builder::e_step() Zero sum for probs!\n";
00424   }
00425 }
00426 
00427 //=======================================================================
00428 
00429 double vpdfl_mixture_builder::m_step(vpdfl_mixture& model,
00430                                      const vcl_vector<vnl_vector<double> >& probs,
00431                                      const vnl_vector<double>* data,
00432                                      const vcl_vector<double>& wts) const
00433 {
00434   const unsigned int n_comp = builder_.size();
00435   const unsigned int n_egs = wts.size();
00436   vcl_vector<double> wts_i(n_egs);
00437 
00438   mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_egs);
00439 
00440   double move = 0.0;
00441   vnl_vector<double> old_mean;
00442 
00443   if (!weights_fixed_)
00444   {
00445     double w_sum = 0.0;
00446     // update the model weights
00447     for (unsigned int i=0;i<n_comp;++i)
00448     {
00449       model.weights()[i]=probs[i].mean();
00450 
00451       // Eliminate tiny components
00452       if (model.weights()[i]<min_wt) model.weights()[i]=0.0;
00453 
00454       w_sum += model.weights()[i];
00455     }
00456 
00457     // Ensure they add up to one
00458     for (unsigned int i=0;i<n_comp;++i)
00459     model.weights()[i]/=w_sum;
00460   }
00461 
00462   for (unsigned int i=0;i<n_comp;++i)
00463   {
00464     // Any components with zero weights are ignored.
00465   // Eventually they should be pruned.
00466     if (model.weights()[i]<=0.0) continue;
00467 
00468     // Compute weights
00469     const double* p = probs[i].begin();
00470     double w_sum = 0.0;
00471     for (unsigned int j=0;j<n_egs;++j)
00472     {
00473       wts_i[j] = wts[j]*p[j];
00474       w_sum += wts_i[j];
00475     }
00476 
00477     if (w_sum<=0.0)
00478       vcl_cerr<<"m_step: Dubious weights. sum="<<w_sum<<'\n';
00479 
00480     old_mean = model.components()[i]->mean();
00481     builder_[i]->weighted_build(*(model.components()[i]), data_array, wts_i);
00482 
00483     move += vnl_vector_ssd(old_mean, model.components()[i]->mean());
00484   }
00485 
00486 
00487   return move;
00488 }
00489 
00490 //=======================================================================
00491 
00492 //: Add Y*v to X
00493 static inline void incXbyYv(vnl_vector<double> *X, const vnl_vector<double> &Y, double v)
00494 {
00495   assert(X->size() == Y.size());
00496   int i = ((int)X->size()) - 1;
00497   double * const pX=X->data_block();
00498   while (i >= 0)
00499   {
00500     pX[i] += Y[i] * v;
00501     i--;
00502   }
00503 }
00504 
00505 //: Add (Y + Z.*Z)*v to X
00506 static inline void incXbyYplusXXv(vnl_vector<double> *X, const vnl_vector<double> &Y,
00507                                const vnl_vector<double> &Z, double v)
00508 {
00509   assert(X->size() == Y.size());
00510   int i = ((int)X->size()) - 1;
00511   double * const pX=X->data_block();
00512   while (i >= 0)
00513   {
00514     pX[i] += (Y[i] + vnl_math_sqr(Z[i]))* v;
00515     i--;
00516   }
00517 }
00518 
00519 
00520 //: Calculate and set the mixture's mean and variance.
00521 void vpdfl_mixture_builder::calc_mean_and_variance(vpdfl_mixture& model)
00522 {
00523   unsigned int n = model.component(0).mean().size();
00524   vnl_vector<double> mean(n, 0.0);
00525   vnl_vector<double> var(n, 0.0);
00526 
00527   for (unsigned int i=0; i<model.n_components(); ++i)
00528   {
00529     incXbyYv(&mean, model.component(i).mean(), model.weight(i));
00530     incXbyYplusXXv(&var, model.component(i).variance(),
00531     model.component(i).mean(), model.weight(i));
00532   }
00533 
00534   for (unsigned int i=0; i<n; ++i)
00535     var(i) -= vnl_math_sqr(mean(i));
00536 
00537   model.set_mean_and_variance(mean, var);
00538 }
00539 
00540 //=======================================================================
00541 
00542 vcl_string vpdfl_mixture_builder::is_a() const
00543 {
00544   return vcl_string("vpdfl_mixture_builder");
00545 }
00546 
00547 //=======================================================================
00548 
00549 bool vpdfl_mixture_builder::is_class(vcl_string const& s) const
00550 {
00551   return vpdfl_builder_base::is_class(s) || s==vpdfl_mixture_builder::is_a();
00552 }
00553 
00554 //=======================================================================
00555 
00556 short vpdfl_mixture_builder::version_no() const
00557 {
00558   return 1;
00559 }
00560 
00561 //=======================================================================
00562 
00563 vpdfl_builder_base* vpdfl_mixture_builder::clone() const
00564 {
00565   return new vpdfl_mixture_builder(*this);
00566 }
00567 
00568 //=======================================================================
00569 
00570 void vpdfl_mixture_builder::print_summary(vcl_ostream& os) const
00571 {
00572   if (weights_fixed_) os<<vsl_indent()<<"Weights fixed"<<'\n';
00573   else                os<<vsl_indent()<<"Weights may vary"<<'\n';
00574   os<<vsl_indent()<<"Max iterations: "<<max_its_<<'\n';
00575   for (unsigned int i=0;i<builder_.size();++i)
00576   {
00577     os<<vsl_indent()<<"Builder "<<i<<": ";
00578     vsl_print_summary(os, builder_[i]); os << '\n';
00579   }
00580 }
00581 
00582 //=======================================================================
00583 
00584 void vpdfl_mixture_builder::b_write(vsl_b_ostream& bfs) const
00585 {
00586   vsl_b_write(bfs,is_a());
00587   vsl_b_write(bfs,version_no());
00588   vsl_b_write(bfs,builder_);
00589   vsl_b_write(bfs,max_its_);
00590   vsl_b_write(bfs,weights_fixed_);
00591 }
00592 
00593 //=======================================================================
00594 
00595 void vpdfl_mixture_builder::b_read(vsl_b_istream& bfs)
00596 {
00597   if (!bfs) return;
00598 
00599   vcl_string name;
00600   vsl_b_read(bfs,name);
00601   if (name != is_a())
00602   {
00603     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00604              << "           Attempted to load object of type "
00605              << name <<" into object of type " << is_a() << '\n';
00606     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00607     return;
00608   }
00609 
00610   delete_stuff();
00611 
00612   short version;
00613   vsl_b_read(bfs,version);
00614   switch (version)
00615   {
00616     case (1):
00617       vsl_b_read(bfs,builder_);
00618       vsl_b_read(bfs,max_its_);
00619       vsl_b_read(bfs,weights_fixed_);
00620       break;
00621     default:
00622       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00623                << "           Unknown version number "<< version << '\n';
00624       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00625       return;
00626   }
00627 }
00628 //: Read initialisation settings from a stream.
00629 // Parameters:
00630 // \verbatim
00631 // {
00632 //   min_var: 1.0e-6
00633 //   n_pdfs: 3
00634 //   // Type of basis pdf
00635 //   basis_pdf: axis_gaussian { min_var: 0.0001 }
00636 // }
00637 // \endverbatim
00638 // \throw mbl_exception_parse_error if the parse fails.
00639 void vpdfl_mixture_builder::config_from_stream(vcl_istream & is)
00640 {
00641   vcl_string s = mbl_parse_block(is);
00642 
00643   vcl_istringstream ss(s);
00644   mbl_read_props_type props = mbl_read_props_ws(ss);
00645 
00646   double mv=1.0e-6;
00647   if (props.find("min_var")!=props.end())
00648   {
00649     mv=vul_string_atof(props["min_var"]);
00650     props.erase("min_var");
00651   }
00652   set_min_var(mv);
00653 
00654   unsigned n_pdfs = 2;
00655   if (props.find("n_pdfs")!=props.end())
00656   {
00657     n_pdfs=vul_string_atoi(props["n_pdfs"]);
00658     props.erase("n_pdfs");
00659   }
00660 
00661   max_its_=10;
00662   if (props.find("max_its")!=props.end())
00663   {
00664     max_its_=vul_string_atoi(props["max_its"]);
00665     props.erase("max_its");
00666   }
00667 
00668   weights_fixed_=false;
00669   if (props.find("weights_fixed")!=props.end())
00670   {
00671     weights_fixed_=vul_string_to_bool(props["weights_fixed"]);
00672     props.erase("weights_fixed");
00673   }
00674 
00675   if (props.find("basis_pdf")!=props.end())
00676   {
00677     vcl_istringstream pdf_ss(props["basis_pdf"]);
00678     vcl_auto_ptr<vpdfl_builder_base>
00679             b = vpdfl_builder_base::new_pdf_builder_from_stream(pdf_ss);
00680     init(*b,n_pdfs);
00681     props.erase("basis_pdf");
00682   }
00683 
00684   try
00685   {
00686     mbl_read_props_look_for_unused_props(
00687         "vpdfl_mixture_builder::config_from_stream", props);
00688   }
00689 
00690   catch(mbl_exception_unused_props &e)
00691   {
00692     throw mbl_exception_parse_error(e.what());
00693   }
00694 
00695 }
00696 
00697 //==================< end of file: vpdfl_mixture_builder.cxx >====================