00001 #include "mmn_dp_solver.h"
00002
00003
00004
00005
00006
00007 #include <mmn/mmn_order_cost.h>
00008 #include <mmn/mmn_graph_rep1.h>
00009 #include <vcl_cassert.h>
00010 #include <vcl_cstdlib.h>
00011
00012 #include <mbl/mbl_parse_block.h>
00013 #include <mbl/mbl_read_props.h>
00014
00015
00016 mmn_dp_solver::mmn_dp_solver()
00017 {
00018 }
00019
00020
00021 void mmn_dp_solver::set_arcs(unsigned num_nodes,
00022 const vcl_vector<mmn_arc>& arcs)
00023 {
00024
00025 vcl_vector<mmn_arc> ordered_arcs(arcs.size());
00026 for (unsigned i=0;i<arcs.size();++i)
00027 {
00028 if (arcs[i].v1<arcs[i].v2)
00029 ordered_arcs[i]= arcs[i];
00030 else
00031 ordered_arcs[i]= mmn_arc(arcs[i].v2,arcs[i].v1);
00032 }
00033
00034 mmn_graph_rep1 graph;
00035 graph.build(num_nodes,ordered_arcs);
00036 vcl_vector<mmn_dependancy> deps;
00037 if (!graph.compute_dependancies(deps,0))
00038 {
00039 vcl_cerr<<"Graph cannot be decomposed - too complex.\n"
00040 <<"Arc list: ";
00041 for (unsigned i=0;i<arcs.size();++i) vcl_cout<<arcs[i];
00042 vcl_cerr<<'\n';
00043 vcl_abort();
00044 }
00045
00046 set_dependancies(deps,num_nodes,graph.max_n_arcs());
00047 }
00048
00049
00050
00051 unsigned mmn_dp_solver::root() const
00052 {
00053 if (deps_.size()==0) return 0;
00054 return deps_[deps_.size()-1].v1;
00055 }
00056
00057
00058 void mmn_dp_solver::set_dependancies(const vcl_vector<mmn_dependancy>& deps,
00059 unsigned n_nodes, unsigned max_n_arcs)
00060 {
00061 deps_ = deps;
00062 nc_.resize(n_nodes);
00063 pc_.resize(max_n_arcs);
00064 index1_.resize(n_nodes);
00065 index2_.resize(n_nodes);
00066 }
00067
00068 void mmn_dp_solver::process_dep1(const mmn_dependancy& dep)
00069 {
00070
00071 const vnl_vector<double>& nc0 = nc_[dep.v0];
00072 vnl_vector<double>& nc1 = nc_[dep.v1];
00073 vnl_matrix<double>& p = pc_[dep.arc1];
00074 vcl_vector<unsigned>& i0 = index1_[dep.v0];
00075
00076
00077 if (dep.v0<dep.v1)
00078 {
00079 assert(p.rows()==nc0.size());
00080 assert(p.cols()==nc1.size());
00081 }
00082 else
00083 {
00084 if (p.rows()!=nc1.size())
00085 {
00086 vcl_cerr<<"p.rows()="<<p.rows()<<"p.cols()="<<p.cols()
00087 <<" nc0.size()="<<nc0.size()
00088 <<" nc1.size()="<<nc1.size()<<vcl_endl
00089 <<"dep: "<<dep<<vcl_endl;
00090 }
00091 assert(p.rows()==nc1.size());
00092 assert(p.cols()==nc0.size());
00093 }
00094
00095
00096 i0.resize(nc1.size());
00097 for (unsigned j=0;j<nc1.size();++j)
00098 {
00099 double min_v;
00100 unsigned best_i=0;
00101 if (dep.v0<dep.v1)
00102 {
00103 min_v=nc0[0]+p(0,j);
00104 for (unsigned i=1;i<nc0.size();++i)
00105 {
00106 double v=nc0[i]+p(i,j);
00107 if (v<min_v) { min_v=v; best_i=i; }
00108 }
00109 }
00110 else
00111 {
00112 min_v=nc0[0]+p(j,0);
00113 for (unsigned i=1;i<nc0.size();++i)
00114 {
00115 double v=nc0[i]+p(j,i);
00116 if (v<min_v) { min_v=v; best_i=i; }
00117 }
00118 }
00119 i0[j]=best_i;
00120 nc1[j]+=min_v;
00121 }
00122 }
00123
00124 void mmn_dp_solver::process_dep2(const mmn_dependancy& dep)
00125 {
00126
00127
00128
00129 const vnl_vector<double>& nc0 = nc_[dep.v0];
00130 const vnl_vector<double>& nc1 = nc_[dep.v1];
00131 const vnl_vector<double>& nc2 = nc_[dep.v2];
00132 const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00133 const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00134 vnl_matrix<double>& pa12 = pc_[dep.arc12];
00135 vnl_matrix<int>& ind0 = index2_[dep.v0];
00136
00137 if (pa12.size()==0)
00138 {
00139 if (dep.v1<dep.v2)
00140 pa12.set_size(nc1.size(),nc2.size());
00141 else
00142 pa12.set_size(nc2.size(),nc1.size());
00143 pa12.fill(0.0);
00144 }
00145
00146
00147 ind0.set_size(nc1.size(),nc2.size());
00148
00149 for (unsigned i1=0;i1<nc1.size();++i1)
00150 {
00151 vnl_vector<double> sum0(nc0);
00152 if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00153 else sum0+=pa1.get_row(i1);
00154
00155 for (unsigned i2=0;i2<nc2.size();++i2)
00156 {
00157 vnl_vector<double> sum(sum0);
00158 if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00159 else sum+=pa2.get_row(i2);
00160
00161
00162
00163 unsigned best_i=0;
00164 double min_v=sum[0];
00165 for (unsigned i=1;i<sum.size();++i)
00166 if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00167
00168
00169 ind0(i1,i2)=best_i;
00170
00171 if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00172 else { pa12(i2,i1)+=min_v; }
00173 }
00174 }
00175 }
00176
00177
00178
00179
00180
00181
00182
00183 void mmn_dp_solver::process_dep2t(const mmn_dependancy& dep,
00184 const vil_image_view<double>& tri_cost)
00185 {
00186
00187
00188
00189 const vnl_vector<double>& nc0 = nc_[dep.v0];
00190 const vnl_vector<double>& nc1 = nc_[dep.v1];
00191 const vnl_vector<double>& nc2 = nc_[dep.v2];
00192 const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00193 const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00194 vnl_matrix<double>& pa12 = pc_[dep.arc12];
00195 vnl_matrix<int>& ind0 = index2_[dep.v0];
00196
00197
00198 vil_image_view<double> tc=mmn_unorder_cost(tri_cost,
00199 dep.v0,dep.v1,dep.v2);
00200 vcl_ptrdiff_t tc_step0=tc.istep();
00201
00202 if (pa12.size()==0)
00203 {
00204 if (dep.v1<dep.v2)
00205 pa12.set_size(nc1.size(),nc2.size());
00206 else
00207 pa12.set_size(nc2.size(),nc1.size());
00208 pa12.fill(0.0);
00209 }
00210
00211
00212 ind0.set_size(nc1.size(),nc2.size());
00213
00214 for (unsigned i1=0;i1<nc1.size();++i1)
00215 {
00216 vnl_vector<double> sum0(nc0);
00217 if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00218 else sum0+=pa1.get_row(i1);
00219
00220 for (unsigned i2=0;i2<nc2.size();++i2)
00221 {
00222 vnl_vector<double> sum(sum0);
00223 if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00224 else sum+=pa2.get_row(i2);
00225
00226
00227
00228 unsigned best_i=0;
00229 const double *tci=&tc(0,i1,i2);
00230 double min_v=sum[0]+tci[0];
00231 tci+=tc_step0;
00232 for (unsigned i=1;i<sum.size();++i,tci+=tc_step0)
00233 {
00234 sum[i]+=(*tci);
00235 if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00236 }
00237
00238
00239 ind0(i1,i2)=best_i;
00240
00241 if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00242 else { pa12(i2,i1)+=min_v; }
00243 }
00244 }
00245 }
00246
00247
00248 double mmn_dp_solver::solve(
00249 const vcl_vector<vnl_vector<double> >& node_cost,
00250 const vcl_vector<vnl_matrix<double> >& pair_cost,
00251 vcl_vector<unsigned>& x)
00252 {
00253 nc_ = node_cost;
00254 for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00255 for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00256
00257 if (deps_.size()==0)
00258 {
00259 vcl_cerr<<"No dependencies.\n";
00260 return 999.99;
00261 }
00262
00263
00264 vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00265 for (;dep!=deps_.end();dep++)
00266 {
00267 if (dep->n_dep==1) process_dep1(*dep);
00268 else process_dep2(*dep);
00269 }
00270
00271 const vnl_vector<double>& root_cost = nc_[root()];
00272 unsigned best_i=0;
00273 double min_v=root_cost[0];
00274 for (unsigned i=1;i<root_cost.size();++i)
00275 if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00276
00277 backtrace(best_i,x);
00278 return min_v;
00279 }
00280
00281 double mmn_dp_solver::solve(
00282 const vcl_vector<vnl_vector<double> >& node_cost,
00283 const vcl_vector<vnl_matrix<double> >& pair_cost,
00284 const vcl_vector<vil_image_view<double> >& tri_cost,
00285 vcl_vector<unsigned>& x)
00286 {
00287 nc_ = node_cost;
00288 for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00289 for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00290
00291 if (deps_.size()==0)
00292 {
00293 vcl_cerr<<"No dependencies.\n";
00294 return 999.99;
00295 }
00296
00297
00298 vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00299 for (;dep!=deps_.end();dep++)
00300 {
00301 if (dep->n_dep==1) process_dep1(*dep);
00302 else
00303 {
00304 if (dep->tri1==mmn_no_tri) process_dep2(*dep);
00305 else
00306 {
00307
00308 assert(dep->tri1 < tri_cost.size());
00309 process_dep2t(*dep,tri_cost[dep->tri1]);
00310 }
00311 }
00312 }
00313
00314 const vnl_vector<double>& root_cost = nc_[root()];
00315 unsigned best_i=0;
00316 double min_v=root_cost[0];
00317 for (unsigned i=1;i<root_cost.size();++i)
00318 if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00319
00320 backtrace(best_i,x);
00321 return min_v;
00322 }
00323
00324
00325
00326
00327 void mmn_dp_solver::backtrace(unsigned root_value,vcl_vector<unsigned>& x)
00328 {
00329 x.resize(nc_.size());
00330 x[root()]=root_value;
00331
00332
00333 for (int i=deps_.size()-1; i>=0; --i)
00334 {
00335 unsigned v0=deps_[i].v0;
00336 unsigned v1=deps_[i].v1;
00337 if (deps_[i].n_dep==1)
00338 x[v0]=index1_[v0][x[v1]];
00339 else
00340 {
00341 const vnl_matrix<int>& ind0 = index2_[v0];
00342 x[v0]=ind0(x[v1],x[deps_[i].v2]);
00343 }
00344 }
00345 }
00346
00347
00348
00349
00350
00351 bool mmn_dp_solver::set_from_stream(vcl_istream &is)
00352 {
00353
00354 vcl_string s = mbl_parse_block(is);
00355 vcl_istringstream ss(s);
00356 mbl_read_props_type props = mbl_read_props_ws(ss);
00357
00358
00359
00360
00361 mbl_read_props_look_for_unused_props(
00362 "mmn_dp_solver::set_from_stream", props, mbl_read_props_type());
00363 return true;
00364 }
00365
00366
00367
00368
00369
00370
00371 short mmn_dp_solver::version_no() const
00372 {
00373 return 1;
00374 }
00375
00376
00377
00378
00379
00380 vcl_string mmn_dp_solver::is_a() const
00381 {
00382 return vcl_string("mmn_dp_solver");
00383 }
00384
00385
00386 mmn_solver* mmn_dp_solver::clone() const
00387 {
00388 return new mmn_dp_solver(*this);
00389 }
00390
00391
00392
00393
00394
00395 void mmn_dp_solver::print_summary(vcl_ostream& ) const
00396 {
00397 }
00398
00399
00400
00401
00402
00403 void mmn_dp_solver::b_write(vsl_b_ostream& bfs) const
00404 {
00405 vsl_b_write(bfs,version_no());
00406 vsl_b_write(bfs,unsigned(deps_.size()));
00407 for (unsigned i=0;i<deps_.size();++i)
00408 vsl_b_write(bfs,deps_[i]);
00409 }
00410
00411
00412
00413
00414
00415 void mmn_dp_solver::b_read(vsl_b_istream& bfs)
00416 {
00417 if (!bfs) return;
00418 short version;
00419 unsigned n;
00420 vsl_b_read(bfs,version);
00421 switch (version)
00422 {
00423 case (1):
00424 vsl_b_read(bfs,n);
00425 deps_.resize(n);
00426 for (unsigned i=0;i<n;++i) vsl_b_read(bfs,deps_[i]);
00427 break;
00428 default:
00429 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00430 << " Unknown version number "<< version << vcl_endl;
00431 bfs.is().clear(vcl_ios::badbit);
00432 return;
00433 }
00434 }
00435