contrib/mul/mmn/mmn_dp_solver.cxx
Go to the documentation of this file.
00001 #include "mmn_dp_solver.h"
00002 //:
00003 // \file
00004 // \brief Solve restricted class of Markov problems (trees and tri-trees)
00005 // \author Tim Cootes
00006 
00007 #include <mmn/mmn_order_cost.h>
00008 #include <mmn/mmn_graph_rep1.h>
00009 #include <vcl_cassert.h>
00010 #include <vcl_cstdlib.h>
00011 
00012 #include <mbl/mbl_parse_block.h>
00013 #include <mbl/mbl_read_props.h>
00014 
00015 //: Default constructor
00016 mmn_dp_solver::mmn_dp_solver()
00017 {
00018 }
00019 
00020 //: Input the arcs that define the graph
00021 void mmn_dp_solver::set_arcs(unsigned num_nodes,
00022                              const vcl_vector<mmn_arc>& arcs)
00023 {
00024   // Copy in arcs, and ensure ordering v1<v2
00025   vcl_vector<mmn_arc> ordered_arcs(arcs.size());
00026   for (unsigned i=0;i<arcs.size();++i)
00027   {
00028     if (arcs[i].v1<arcs[i].v2)
00029       ordered_arcs[i]= arcs[i];
00030     else
00031       ordered_arcs[i]= mmn_arc(arcs[i].v2,arcs[i].v1);
00032   }
00033 
00034   mmn_graph_rep1 graph;
00035   graph.build(num_nodes,ordered_arcs);
00036   vcl_vector<mmn_dependancy> deps;
00037   if (!graph.compute_dependancies(deps,0))
00038   {
00039     vcl_cerr<<"Graph cannot be decomposed - too complex.\n"
00040             <<"Arc list: ";
00041     for (unsigned i=0;i<arcs.size();++i) vcl_cout<<arcs[i];
00042     vcl_cerr<<'\n';
00043     vcl_abort();
00044   }
00045 
00046   set_dependancies(deps,num_nodes,graph.max_n_arcs());
00047 }
00048 
00049 
00050 //: Index of root node
00051 unsigned mmn_dp_solver::root() const
00052 {
00053   if (deps_.size()==0) return 0;
00054   return deps_[deps_.size()-1].v1;
00055 }
00056 
00057 //: Define dependencies
00058 void mmn_dp_solver::set_dependancies(const vcl_vector<mmn_dependancy>& deps,
00059                                      unsigned n_nodes, unsigned max_n_arcs)
00060 {
00061   deps_ = deps;
00062   nc_.resize(n_nodes);
00063   pc_.resize(max_n_arcs);
00064   index1_.resize(n_nodes);
00065   index2_.resize(n_nodes);
00066 }
00067 
00068 void mmn_dp_solver::process_dep1(const mmn_dependancy& dep)
00069 {
00070   // dep->v0 depends on dep->v1 through arc dep->arc1
00071   const vnl_vector<double>& nc0 = nc_[dep.v0];
00072   vnl_vector<double>& nc1 = nc_[dep.v1];
00073   vnl_matrix<double>& p = pc_[dep.arc1];
00074   vcl_vector<unsigned>& i0 = index1_[dep.v0];
00075 
00076   // Check sizes of matrices
00077   if (dep.v0<dep.v1)
00078   {
00079     assert(p.rows()==nc0.size());
00080     assert(p.cols()==nc1.size());
00081   }
00082   else
00083   {
00084     if (p.rows()!=nc1.size())
00085     {
00086       vcl_cerr<<"p.rows()="<<p.rows()<<"p.cols()="<<p.cols()
00087               <<" nc0.size()="<<nc0.size()
00088               <<" nc1.size()="<<nc1.size()<<vcl_endl
00089               <<"dep: "<<dep<<vcl_endl;
00090     }
00091     assert(p.rows()==nc1.size());
00092     assert(p.cols()==nc0.size());
00093   }
00094 
00095   // Set i0[i1] to the optimal choice of node v0 if v1 is i1
00096   i0.resize(nc1.size());
00097   for (unsigned j=0;j<nc1.size();++j)
00098   {
00099     double min_v;
00100     unsigned best_i=0;
00101     if (dep.v0<dep.v1)
00102     {
00103       min_v=nc0[0]+p(0,j);
00104       for (unsigned i=1;i<nc0.size();++i)
00105       {
00106         double v=nc0[i]+p(i,j);
00107         if (v<min_v) { min_v=v; best_i=i; }
00108       }
00109     }
00110     else
00111     {
00112       min_v=nc0[0]+p(j,0);
00113       for (unsigned i=1;i<nc0.size();++i)
00114       {
00115         double v=nc0[i]+p(j,i);
00116         if (v<min_v) { min_v=v; best_i=i; }
00117       }
00118     }
00119     i0[j]=best_i;
00120     nc1[j]+=min_v;  // Update costs for node v1
00121   }
00122 }
00123 
00124 void mmn_dp_solver::process_dep2(const mmn_dependancy& dep)
00125 {
00126   // n_dep==2
00127   // dep->v0 depends on dep->v1 and dep->v2
00128   // dep->v0 depends on dep->v1 through arc dep->arc1
00129   const vnl_vector<double>& nc0 = nc_[dep.v0];
00130   const vnl_vector<double>& nc1 = nc_[dep.v1];
00131   const vnl_vector<double>& nc2 = nc_[dep.v2];
00132   const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00133   const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00134   vnl_matrix<double>& pa12 = pc_[dep.arc12];
00135   vnl_matrix<int>& ind0 = index2_[dep.v0];
00136 
00137   if (pa12.size()==0)
00138   {
00139     if (dep.v1<dep.v2)
00140       pa12.set_size(nc1.size(),nc2.size());
00141     else
00142       pa12.set_size(nc2.size(),nc1.size());
00143     pa12.fill(0.0);
00144   }
00145 
00146   // i0[i1,i2] to the optimal choice of node v0 if v1 is i1, v2 is i2
00147   ind0.set_size(nc1.size(),nc2.size());
00148 
00149   for (unsigned i1=0;i1<nc1.size();++i1)
00150   {
00151     vnl_vector<double> sum0(nc0);
00152     if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00153     else               sum0+=pa1.get_row(i1);
00154 
00155     for (unsigned i2=0;i2<nc2.size();++i2)
00156     {
00157       vnl_vector<double> sum(sum0);
00158       if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00159       else               sum+=pa2.get_row(i2);
00160 
00161       // sum[i] is the cost of choosing i, given (i1,i2)
00162       // Select minimum
00163       unsigned best_i=0;
00164       double min_v=sum[0];
00165       for (unsigned i=1;i<sum.size();++i)
00166         if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00167 
00168       // Record position of minima
00169       ind0(i1,i2)=best_i;
00170       // Update pairwise cost for arc between v1 and v2
00171       if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00172       else               { pa12(i2,i1)+=min_v; }
00173     }
00174   }
00175 }
00176 
00177 
00178 //: Compute optimal choice for dep.v0 given v1 and v2
00179 //  Includes cost depending on (v0,v1,v2) as well as pairwise and 
00180 //  node costs.
00181 // tri_cost(i,j,k) is cost of associating smallest node index
00182 // with i, next with j and largest node index with k.
00183 void mmn_dp_solver::process_dep2t(const mmn_dependancy& dep,
00184                                   const vil_image_view<double>& tri_cost)
00185 {
00186   // n_dep==2
00187   // dep->v0 depends on dep->v1 and dep->v2
00188   // dep->v0 depends on dep->v1 through arc dep->arc1
00189   const vnl_vector<double>& nc0 = nc_[dep.v0];
00190   const vnl_vector<double>& nc1 = nc_[dep.v1];
00191   const vnl_vector<double>& nc2 = nc_[dep.v2];
00192   const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00193   const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00194   vnl_matrix<double>& pa12 = pc_[dep.arc12];
00195   vnl_matrix<int>& ind0 = index2_[dep.v0];
00196 
00197   // Create a re-ordered view of tri_cost, so we can use tc(i1,i2,i3)
00198   vil_image_view<double> tc=mmn_unorder_cost(tri_cost,
00199                                              dep.v0,dep.v1,dep.v2);
00200   vcl_ptrdiff_t tc_step0=tc.istep();
00201 
00202   if (pa12.size()==0)
00203   {
00204     if (dep.v1<dep.v2)
00205       pa12.set_size(nc1.size(),nc2.size());
00206     else
00207       pa12.set_size(nc2.size(),nc1.size());
00208     pa12.fill(0.0);
00209   }
00210 
00211   // i0[i1,i2] to the optimal choice of node v0 if v1 is i1, v2 is i2
00212   ind0.set_size(nc1.size(),nc2.size());
00213 
00214   for (unsigned i1=0;i1<nc1.size();++i1)
00215   {
00216     vnl_vector<double> sum0(nc0);
00217     if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00218     else               sum0+=pa1.get_row(i1);
00219 
00220     for (unsigned i2=0;i2<nc2.size();++i2)
00221     {
00222       vnl_vector<double> sum(sum0);
00223       if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00224       else               sum+=pa2.get_row(i2);
00225 
00226       // sum[i] is the cost of choosing i, given (i1,i2)
00227       // Select minimum
00228       unsigned best_i=0;
00229       const double *tci=&tc(0,i1,i2);
00230       double min_v=sum[0]+tci[0];
00231       tci+=tc_step0; // move to element 1
00232       for (unsigned i=1;i<sum.size();++i,tci+=tc_step0)
00233       {
00234         sum[i]+=(*tci);
00235         if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00236       }
00237 
00238       // Record position of minima
00239       ind0(i1,i2)=best_i;
00240       // Update pairwise cost for arc between v1 and v2
00241       if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00242       else               { pa12(i2,i1)+=min_v; }
00243     }
00244   }
00245 }
00246 
00247 
00248 double mmn_dp_solver::solve(
00249                  const vcl_vector<vnl_vector<double> >& node_cost,
00250                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00251                  vcl_vector<unsigned>& x)
00252 {
00253   nc_ = node_cost;
00254   for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00255   for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00256 
00257   if (deps_.size()==0)
00258   {
00259     vcl_cerr<<"No dependencies.\n";
00260     return 999.99;
00261   }
00262 
00263   // Process dependencies in given order
00264   vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00265   for (;dep!=deps_.end();dep++)
00266   {
00267     if (dep->n_dep==1) process_dep1(*dep);
00268     else               process_dep2(*dep);
00269   }
00270 
00271   const vnl_vector<double>& root_cost = nc_[root()];
00272   unsigned best_i=0;
00273   double min_v=root_cost[0];
00274   for (unsigned i=1;i<root_cost.size();++i)
00275     if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00276 
00277   backtrace(best_i,x);
00278   return min_v;
00279 }
00280 
00281 double mmn_dp_solver::solve(
00282                  const vcl_vector<vnl_vector<double> >& node_cost,
00283                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00284                  const vcl_vector<vil_image_view<double> >& tri_cost,
00285                  vcl_vector<unsigned>& x)
00286 {
00287   nc_ = node_cost;
00288   for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00289   for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00290 
00291   if (deps_.size()==0)
00292   {
00293     vcl_cerr<<"No dependencies.\n";
00294     return 999.99;
00295   }
00296 
00297   // Process dependencies in given order
00298   vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00299   for (;dep!=deps_.end();dep++)
00300   {
00301     if (dep->n_dep==1) process_dep1(*dep);
00302     else
00303     {
00304       if (dep->tri1==mmn_no_tri) process_dep2(*dep);
00305       else
00306       {
00307         // dep->v0 depends on arcs and a triplet relationship
00308         assert(dep->tri1 < tri_cost.size());
00309         process_dep2t(*dep,tri_cost[dep->tri1]);
00310       }
00311     }
00312   }
00313 
00314   const vnl_vector<double>& root_cost = nc_[root()];
00315   unsigned best_i=0;
00316   double min_v=root_cost[0];
00317   for (unsigned i=1;i<root_cost.size();++i)
00318     if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00319 
00320   backtrace(best_i,x);
00321   return min_v;
00322 }
00323 
00324 
00325 //: Compute optimal values for x[i] given that root node is root_value
00326 //  Assumes that solve() has been already called.
00327 void mmn_dp_solver::backtrace(unsigned root_value,vcl_vector<unsigned>& x)
00328 {
00329   x.resize(nc_.size());
00330   x[root()]=root_value;
00331 
00332   // Perform backtracing to find optimal solution.
00333   for (int i=deps_.size()-1; i>=0; --i)
00334   {
00335     unsigned v0=deps_[i].v0;
00336     unsigned v1=deps_[i].v1;
00337     if (deps_[i].n_dep==1)
00338        x[v0]=index1_[v0][x[v1]];
00339     else
00340     {
00341       const vnl_matrix<int>& ind0 = index2_[v0];
00342       x[v0]=ind0(x[v1],x[deps_[i].v2]);
00343     }
00344   }
00345 }
00346 
00347 //=======================================================================
00348 // Method: set_from_stream
00349 //=======================================================================
00350 //: Initialise from a string stream
00351 bool mmn_dp_solver::set_from_stream(vcl_istream &is)
00352 {
00353   // Cycle through stream and produce a map of properties
00354   vcl_string s = mbl_parse_block(is);
00355   vcl_istringstream ss(s);
00356   mbl_read_props_type props = mbl_read_props_ws(ss);
00357 
00358   // No properties expected.
00359 
00360   // Check for unused props
00361   mbl_read_props_look_for_unused_props(
00362       "mmn_dp_solver::set_from_stream", props, mbl_read_props_type());
00363   return true;
00364 }
00365 
00366 
00367 //=======================================================================
00368 // Method: version_no
00369 //=======================================================================
00370 
00371 short mmn_dp_solver::version_no() const
00372 {
00373   return 1;
00374 }
00375 
00376 //=======================================================================
00377 // Method: is_a
00378 //=======================================================================
00379 
00380 vcl_string mmn_dp_solver::is_a() const
00381 {
00382   return vcl_string("mmn_dp_solver");
00383 }
00384 
00385 //: Create a copy on the heap and return base class pointer
00386 mmn_solver* mmn_dp_solver::clone() const
00387 {
00388   return new mmn_dp_solver(*this);
00389 }
00390 
00391 //=======================================================================
00392 // Method: print
00393 //=======================================================================
00394 
00395 void mmn_dp_solver::print_summary(vcl_ostream& /*os*/) const
00396 {
00397 }
00398 
00399 //=======================================================================
00400 // Method: save
00401 //=======================================================================
00402 
00403 void mmn_dp_solver::b_write(vsl_b_ostream& bfs) const
00404 {
00405   vsl_b_write(bfs,version_no());
00406   vsl_b_write(bfs,unsigned(deps_.size()));
00407   for (unsigned i=0;i<deps_.size();++i)
00408     vsl_b_write(bfs,deps_[i]);
00409 }
00410 
00411 //=======================================================================
00412 // Method: load
00413 //=======================================================================
00414 
00415 void mmn_dp_solver::b_read(vsl_b_istream& bfs)
00416 {
00417   if (!bfs) return;
00418   short version;
00419   unsigned n;
00420   vsl_b_read(bfs,version);
00421   switch (version)
00422   {
00423     case (1):
00424       vsl_b_read(bfs,n);
00425       deps_.resize(n);
00426       for (unsigned i=0;i<n;++i) vsl_b_read(bfs,deps_[i]);
00427       break;
00428     default:
00429       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00430                << "           Unknown version number "<< version << vcl_endl;
00431       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00432       return;
00433   }
00434 }
00435