00001
00002
00003
00004
00005
00006
00007 #include "vpdfl_mixture_builder.h"
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include <vcl_cassert.h>
00020 #include <vcl_cmath.h>
00021 #include <vcl_cstdlib.h>
00022 #include <vsl/vsl_indent.h>
00023 #include <vsl/vsl_vector_io.h>
00024 #include <vsl/vsl_binary_loader.h>
00025 #include <vpdfl/vpdfl_mixture.h>
00026 #include <mbl/mbl_data_wrapper.h>
00027 #include <mbl/mbl_data_array_wrapper.h>
00028 #include <vnl/vnl_math.h>
00029
00030 #include <mbl/mbl_parse_block.h>
00031 #include <mbl/mbl_read_props.h>
00032 #include <vul/vul_string.h>
00033 #include <mbl/mbl_exception.h>
00034
00035
//: Mixing weights below this threshold are treated as zero in m_step(),
//  which switches the corresponding component off.
static const double min_wt = 1.0e-8;
00037
00038
00039 void vpdfl_mixture_builder::init()
00040 {
00041 min_var_ = 1.0e-6;
00042 max_its_ = 10;
00043 weights_fixed_ = false;
00044 initial_means_.clear();
00045 }
00046
00047
00048
00049 vpdfl_mixture_builder::vpdfl_mixture_builder()
00050 {
00051 init();
00052 }
00053
00054
00055
00056 vpdfl_mixture_builder::vpdfl_mixture_builder(const vpdfl_mixture_builder& b):
00057 vpdfl_builder_base()
00058 {
00059 init();
00060 *this = b;
00061 }
00062
00063
00064
00065 vpdfl_mixture_builder& vpdfl_mixture_builder::operator=(const vpdfl_mixture_builder& b)
00066 {
00067 if (&b==this) return *this;
00068
00069 delete_stuff();
00070
00071 unsigned int n = b.builder_.size();
00072 builder_.resize(n);
00073 for (unsigned int i=0;i<n;++i)
00074 builder_[i] = b.builder_[i]->clone();
00075
00076 min_var_ = b.min_var_;
00077 max_its_ = b.max_its_;
00078 weights_fixed_ = b.weights_fixed_;
00079 initial_means_ = b.initial_means_;
00080
00081 return *this;
00082 }
00083
00084
00085
00086 void vpdfl_mixture_builder::delete_stuff()
00087 {
00088 unsigned int n = builder_.size();
00089 for (unsigned int i=0;i<n;++i)
00090 delete builder_[i];
00091 builder_.resize(0);
00092 initial_means_.clear();
00093 }
00094
00095 vpdfl_mixture_builder::~vpdfl_mixture_builder()
00096 {
00097 delete_stuff();
00098 }
00099
00100
00101
00102
00103
00104 void vpdfl_mixture_builder::init(const vpdfl_builder_base& builder, int n)
00105 {
00106 delete_stuff();
00107 builder_.resize(n);
00108 for (int i=0;i<n;++i)
00109 builder_[i] = builder.clone();
00110 }
00111
00112
00113
00114
00115 void vpdfl_mixture_builder::set_max_iterations(int n)
00116 {
00117 max_its_ = n;
00118 }
00119
00120 void vpdfl_mixture_builder::set_weights_fixed(bool b)
00121 {
00122 weights_fixed_ = b;
00123 }
00124
00125
00126
00127
00128 vpdfl_pdf_base* vpdfl_mixture_builder::new_model() const
00129 {
00130 return new vpdfl_mixture;
00131 }
00132
00133
00134
00135
00136 void vpdfl_mixture_builder::set_min_var(double min_var)
00137 {
00138 min_var_ = min_var;
00139 }
00140
00141
00142
00143
00144 double vpdfl_mixture_builder::min_var() const
00145 {
00146 return min_var_;
00147 }
00148
00149
00150
00151
00152 void vpdfl_mixture_builder::build(vpdfl_pdf_base& ,
00153 const vnl_vector<double>& ) const
00154 {
00155 vcl_cerr<<"vpdfl_mixture_builder::build(model,mean) Not yet implemented.\n";
00156 vcl_abort();
00157 }
00158
00159
00160
00161
00162 void vpdfl_mixture_builder::build(vpdfl_pdf_base& model,
00163 mbl_data_wrapper<vnl_vector<double> >& data) const
00164 {
00165 vcl_vector<double> wts(int(data.size()), 1.0);
00166 weighted_build(model,data,wts);
00167 }
00168
00169
00170
00171
00172 void vpdfl_mixture_builder::weighted_build(vpdfl_pdf_base& base_model,
00173 mbl_data_wrapper<vnl_vector<double> >& data,
00174 const vcl_vector<double>& wts) const
00175 {
00176 assert(base_model.is_class("vpdfl_mixture"));
00177 vpdfl_mixture& model = static_cast<vpdfl_mixture&>( base_model);
00178
00179 unsigned int n = builder_.size();
00180
00181 bool model_setup = (model.n_components()==n);
00182
00183 if (!model_setup)
00184 {
00185
00186 model.clear();
00187 model.components().resize(n);
00188 model.weights().resize(n);
00189 for (unsigned int i=0;i<n;++i)
00190 {
00191 builder_[i]->set_min_var(min_var_);
00192 model.components()[i] = builder_[i]->new_model();
00193 model.weights()[i] = 1.0/n;
00194 }
00195 }
00196
00197
00198 const vnl_vector<double>* data_ptr;
00199 vcl_vector<vnl_vector<double> > data_array;
00200
00201 {
00202 unsigned int n=data.size();
00203 data.reset();
00204 data_array.resize(n);
00205 for (unsigned int i=0;i<n;++i)
00206 {
00207 data_array[i] = data.current();
00208 data.next();
00209 }
00210
00211 data_ptr = &data_array[0];
00212 }
00213
00214 if (!model_setup || !initial_means_.empty())
00215 initialise(model,data_ptr,wts);
00216
00217 vcl_vector<vnl_vector<double> > probs;
00218
00219 int n_its = 0;
00220 double max_move = 1e-6;
00221 double move = max_move+1;
00222 while (move>max_move && n_its<max_its_)
00223 {
00224 e_step(model,probs,data_ptr,wts);
00225 move = m_step(model,probs,data_ptr,wts);
00226 n_its++;
00227 }
00228 calc_mean_and_variance(model);
00229 assert(model.is_valid_pdf());
00230 }
00231
00232 static void UpdateRange(vnl_vector<double>& min_vec, vnl_vector<double>& max_vec, const vnl_vector<double>& vec)
00233 {
00234 unsigned int n=vec.size();
00235 for (unsigned int i=0;i<n;++i)
00236 {
00237 if (vec(i)<min_vec(i))
00238 min_vec(i)=vec(i);
00239 else
00240 if (vec(i)>max_vec(i))
00241 max_vec(i)=vec(i);
00242 }
00243 }
00244
00245
00246 void vpdfl_mixture_builder::initialise_given_means(vpdfl_mixture& model,
00247 const vnl_vector<double>* data,
00248 const vcl_vector<vnl_vector<double> >& mean,
00249 const vcl_vector<double>& wts) const
00250 {
00251 const unsigned int n_comp = builder_.size();
00252 const unsigned int n_samples = wts.size();
00253
00254
00255 vnl_vector<double> min_v(mean[0]);
00256 vnl_vector<double> max_v(min_v);
00257 for (unsigned int i=1;i<n_comp;++i)
00258 UpdateRange(min_v,max_v,mean[i]);
00259
00260 double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00261 if (mean_sep<=1e-6) mean_sep = 1e-6;
00262
00263
00264 vcl_vector<double> wts_i(n_samples);
00265
00266 mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_samples);
00267
00268 for (unsigned int i=0;i<n_comp;++i)
00269 {
00270
00271 double w_sum = 0.0;
00272 for (unsigned int j=0;j<n_samples;++j)
00273 {
00274 wts_i[j] = wts[j]*mean_sep/(mean_sep+ vnl_vector_ssd(data[j], mean[i]));
00275 w_sum+=wts_i[j];
00276 }
00277
00278
00279 double f = n_samples/(n_comp*w_sum);
00280 for (unsigned int j=0;j<n_samples;++j)
00281 wts_i[j]*=f;
00282
00283
00284 builder_[i]->weighted_build(*(model.components()[i]),data_array,wts_i);
00285 }
00286 }
00287
00288
00289
00290 void vpdfl_mixture_builder::initialise_diagonal(vpdfl_mixture& model,
00291 const vnl_vector<double>* data,
00292 const vcl_vector<double>& wts) const
00293 {
00294
00295 const unsigned int n_comp = builder_.size();
00296 const unsigned int n_samples = wts.size();
00297
00298
00299 vnl_vector<double> min_v(data[0]);
00300 vnl_vector<double> max_v(min_v);
00301 for (unsigned int i=1;i<n_samples;++i)
00302 UpdateRange(min_v,max_v,data[i]);
00303
00304 #if 0 // unused variable
00305 double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00306 #endif
00307
00308
00309 vcl_vector<vnl_vector<double> > mean(n_comp);
00310 for (unsigned int i=0;i<n_comp;++i)
00311 {
00312 double f = (i+1.0)/(n_comp+1);
00313 mean[i] = (1-f)*min_v + f*max_v;
00314 }
00315
00316 initialise_given_means(model,data,mean,wts);
00317 }
00318
00319
00320
00321 void vpdfl_mixture_builder::initialise_to_regular_samples(vpdfl_mixture& model,
00322 const vnl_vector<double>* data,
00323 const vcl_vector<double>& wts) const
00324 {
00325
00326 const unsigned int n_comp = builder_.size();
00327 const unsigned int n_samples = wts.size();
00328
00329 double f = double(n_samples)/n_comp;
00330
00331
00332 vcl_vector<vnl_vector<double> > mean(n_comp);
00333 for (unsigned int i=0;i<n_comp;++i)
00334 {
00335 unsigned int j = vnl_math_rnd((i+0.5)*f);
00336 if (j>=n_samples) j=n_samples-1;
00337 mean[i] = data[j];
00338 }
00339
00340 initialise_given_means(model,data,mean,wts);
00341 }
00342
00343 void vpdfl_mixture_builder::initialise(vpdfl_mixture& model,
00344 const vnl_vector<double>* data,
00345 const vcl_vector<double>& wts) const
00346 {
00347
00348 if(!initial_means_.empty() )
00349 {
00350 initialise_given_means(model,data,initial_means_,wts);
00351 }
00352 else
00353 {
00354 initialise_to_regular_samples(model,data,wts);
00355 }
00356 }
00357
00358 void vpdfl_mixture_builder::preset_initial_means(const vcl_vector<vnl_vector<double> >& component_means)
00359 {
00360 initial_means_ = component_means;
00361 }
00362
00363
00364 void vpdfl_mixture_builder::e_step(vpdfl_mixture& model,
00365 vcl_vector<vnl_vector<double> >& probs,
00366 const vnl_vector<double>* data,
00367 const vcl_vector<double>& wts) const
00368 {
00369 const unsigned int n_comp = builder_.size();
00370 const unsigned int n_egs = wts.size();
00371 const vcl_vector<double>& m_wts = model.weights();
00372
00373 if (probs.size()!=n_comp) probs.resize(n_comp);
00374
00375
00376
00377 for (unsigned int i=0;i<n_comp;++i)
00378 {
00379 if (probs[i].size()!=n_egs) probs[i].set_size(n_egs);
00380
00381
00382
00383 if (m_wts[i]<=0) continue;
00384
00385 double *p_data = probs[i].begin();
00386
00387 double log_wt_i = vcl_log(m_wts[i]);
00388
00389 for (unsigned int j=0;j<n_egs;++j)
00390 {
00391 p_data[j] = log_wt_i+model.components()[i]->log_p(data[j]);
00392 }
00393 }
00394
00395
00396
00397 for (unsigned int j=0;j<n_egs;++j)
00398 {
00399
00400 double max_log_p=0;
00401 for (unsigned int i=0;i<n_comp;++i)
00402 {
00403 if (m_wts[i]<=0) continue;
00404 if (i==0 || probs[i](j)>max_log_p) max_log_p = probs[i](j);
00405 }
00406
00407
00408 double sum = 0.0;
00409 for (unsigned int i=0;i<n_comp;++i)
00410 {
00411 if (m_wts[i]<=0) continue;
00412 double p = vcl_exp(probs[i](j)-max_log_p);
00413 probs[i](j) = p;
00414 sum+=p;
00415 }
00416
00417
00418 if (sum>0.0)
00419 for (unsigned int i=0;i<n_comp;++i)
00420 probs[i](j)/=sum;
00421
00422 if (sum<=0)
00423 vcl_cerr<<"vpdfl_mixture_builder::e_step() Zero sum for probs!\n";
00424 }
00425 }
00426
00427
00428
00429 double vpdfl_mixture_builder::m_step(vpdfl_mixture& model,
00430 const vcl_vector<vnl_vector<double> >& probs,
00431 const vnl_vector<double>* data,
00432 const vcl_vector<double>& wts) const
00433 {
00434 const unsigned int n_comp = builder_.size();
00435 const unsigned int n_egs = wts.size();
00436 vcl_vector<double> wts_i(n_egs);
00437
00438 mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_egs);
00439
00440 double move = 0.0;
00441 vnl_vector<double> old_mean;
00442
00443 if (!weights_fixed_)
00444 {
00445 double w_sum = 0.0;
00446
00447 for (unsigned int i=0;i<n_comp;++i)
00448 {
00449 model.weights()[i]=probs[i].mean();
00450
00451
00452 if (model.weights()[i]<min_wt) model.weights()[i]=0.0;
00453
00454 w_sum += model.weights()[i];
00455 }
00456
00457
00458 for (unsigned int i=0;i<n_comp;++i)
00459 model.weights()[i]/=w_sum;
00460 }
00461
00462 for (unsigned int i=0;i<n_comp;++i)
00463 {
00464
00465
00466 if (model.weights()[i]<=0.0) continue;
00467
00468
00469 const double* p = probs[i].begin();
00470 double w_sum = 0.0;
00471 for (unsigned int j=0;j<n_egs;++j)
00472 {
00473 wts_i[j] = wts[j]*p[j];
00474 w_sum += wts_i[j];
00475 }
00476
00477 if (w_sum<=0.0)
00478 vcl_cerr<<"m_step: Dubious weights. sum="<<w_sum<<'\n';
00479
00480 old_mean = model.components()[i]->mean();
00481 builder_[i]->weighted_build(*(model.components()[i]), data_array, wts_i);
00482
00483 move += vnl_vector_ssd(old_mean, model.components()[i]->mean());
00484 }
00485
00486
00487 return move;
00488 }
00489
00490
00491
00492
00493 static inline void incXbyYv(vnl_vector<double> *X, const vnl_vector<double> &Y, double v)
00494 {
00495 assert(X->size() == Y.size());
00496 int i = ((int)X->size()) - 1;
00497 double * const pX=X->data_block();
00498 while (i >= 0)
00499 {
00500 pX[i] += Y[i] * v;
00501 i--;
00502 }
00503 }
00504
00505
00506 static inline void incXbyYplusXXv(vnl_vector<double> *X, const vnl_vector<double> &Y,
00507 const vnl_vector<double> &Z, double v)
00508 {
00509 assert(X->size() == Y.size());
00510 int i = ((int)X->size()) - 1;
00511 double * const pX=X->data_block();
00512 while (i >= 0)
00513 {
00514 pX[i] += (Y[i] + vnl_math_sqr(Z[i]))* v;
00515 i--;
00516 }
00517 }
00518
00519
00520
00521 void vpdfl_mixture_builder::calc_mean_and_variance(vpdfl_mixture& model)
00522 {
00523 unsigned int n = model.component(0).mean().size();
00524 vnl_vector<double> mean(n, 0.0);
00525 vnl_vector<double> var(n, 0.0);
00526
00527 for (unsigned int i=0; i<model.n_components(); ++i)
00528 {
00529 incXbyYv(&mean, model.component(i).mean(), model.weight(i));
00530 incXbyYplusXXv(&var, model.component(i).variance(),
00531 model.component(i).mean(), model.weight(i));
00532 }
00533
00534 for (unsigned int i=0; i<n; ++i)
00535 var(i) -= vnl_math_sqr(mean(i));
00536
00537 model.set_mean_and_variance(mean, var);
00538 }
00539
00540
00541
00542 vcl_string vpdfl_mixture_builder::is_a() const
00543 {
00544 return vcl_string("vpdfl_mixture_builder");
00545 }
00546
00547
00548
00549 bool vpdfl_mixture_builder::is_class(vcl_string const& s) const
00550 {
00551 return vpdfl_builder_base::is_class(s) || s==vpdfl_mixture_builder::is_a();
00552 }
00553
00554
00555
00556 short vpdfl_mixture_builder::version_no() const
00557 {
00558 return 1;
00559 }
00560
00561
00562
00563 vpdfl_builder_base* vpdfl_mixture_builder::clone() const
00564 {
00565 return new vpdfl_mixture_builder(*this);
00566 }
00567
00568
00569
00570 void vpdfl_mixture_builder::print_summary(vcl_ostream& os) const
00571 {
00572 if (weights_fixed_) os<<vsl_indent()<<"Weights fixed"<<'\n';
00573 else os<<vsl_indent()<<"Weights may vary"<<'\n';
00574 os<<vsl_indent()<<"Max iterations: "<<max_its_<<'\n';
00575 for (unsigned int i=0;i<builder_.size();++i)
00576 {
00577 os<<vsl_indent()<<"Builder "<<i<<": ";
00578 vsl_print_summary(os, builder_[i]); os << '\n';
00579 }
00580 }
00581
00582
00583
00584 void vpdfl_mixture_builder::b_write(vsl_b_ostream& bfs) const
00585 {
00586 vsl_b_write(bfs,is_a());
00587 vsl_b_write(bfs,version_no());
00588 vsl_b_write(bfs,builder_);
00589 vsl_b_write(bfs,max_its_);
00590 vsl_b_write(bfs,weights_fixed_);
00591 }
00592
00593
00594
00595 void vpdfl_mixture_builder::b_read(vsl_b_istream& bfs)
00596 {
00597 if (!bfs) return;
00598
00599 vcl_string name;
00600 vsl_b_read(bfs,name);
00601 if (name != is_a())
00602 {
00603 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00604 << " Attempted to load object of type "
00605 << name <<" into object of type " << is_a() << '\n';
00606 bfs.is().clear(vcl_ios::badbit);
00607 return;
00608 }
00609
00610 delete_stuff();
00611
00612 short version;
00613 vsl_b_read(bfs,version);
00614 switch (version)
00615 {
00616 case (1):
00617 vsl_b_read(bfs,builder_);
00618 vsl_b_read(bfs,max_its_);
00619 vsl_b_read(bfs,weights_fixed_);
00620 break;
00621 default:
00622 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00623 << " Unknown version number "<< version << '\n';
00624 bfs.is().clear(vcl_ios::badbit);
00625 return;
00626 }
00627 }
00628
00629
00630
00631
00632
00633
00634
00635
00636
00637
00638
00639 void vpdfl_mixture_builder::config_from_stream(vcl_istream & is)
00640 {
00641 vcl_string s = mbl_parse_block(is);
00642
00643 vcl_istringstream ss(s);
00644 mbl_read_props_type props = mbl_read_props_ws(ss);
00645
00646 double mv=1.0e-6;
00647 if (props.find("min_var")!=props.end())
00648 {
00649 mv=vul_string_atof(props["min_var"]);
00650 props.erase("min_var");
00651 }
00652 set_min_var(mv);
00653
00654 unsigned n_pdfs = 2;
00655 if (props.find("n_pdfs")!=props.end())
00656 {
00657 n_pdfs=vul_string_atoi(props["n_pdfs"]);
00658 props.erase("n_pdfs");
00659 }
00660
00661 max_its_=10;
00662 if (props.find("max_its")!=props.end())
00663 {
00664 max_its_=vul_string_atoi(props["max_its"]);
00665 props.erase("max_its");
00666 }
00667
00668 weights_fixed_=false;
00669 if (props.find("weights_fixed")!=props.end())
00670 {
00671 weights_fixed_=vul_string_to_bool(props["weights_fixed"]);
00672 props.erase("weights_fixed");
00673 }
00674
00675 if (props.find("basis_pdf")!=props.end())
00676 {
00677 vcl_istringstream pdf_ss(props["basis_pdf"]);
00678 vcl_auto_ptr<vpdfl_builder_base>
00679 b = vpdfl_builder_base::new_pdf_builder_from_stream(pdf_ss);
00680 init(*b,n_pdfs);
00681 props.erase("basis_pdf");
00682 }
00683
00684 try
00685 {
00686 mbl_read_props_look_for_unused_props(
00687 "vpdfl_mixture_builder::config_from_stream", props);
00688 }
00689
00690 catch(mbl_exception_unused_props &e)
00691 {
00692 throw mbl_exception_parse_error(e.what());
00693 }
00694
00695 }
00696
00697