00001
00002
00003
00004
00005
00006 #include <mmn/mmn_dp_solver.h>
00007 #include <vcl_cassert.h>
00008
00009
00010 mmn_dp_solver::mmn_dp_solver()
00011 {
00012 }
00013
00014
00015 unsigned mmn_dp_solver::root() const
00016 {
00017 if (deps_.size()==0) return 0;
00018 return deps_[deps_.size()-1].v1;
00019 }
00020
00021
00022 void mmn_dp_solver::set_dependancies(const vcl_vector<mmn_dependancy>& deps,
00023 unsigned n_nodes, unsigned max_n_arcs)
00024 {
00025 deps_ = deps;
00026 nc_.resize(n_nodes);
00027 pc_.resize(max_n_arcs);
00028 index1_.resize(n_nodes);
00029 index2_.resize(n_nodes);
00030 }
00031
00032 void mmn_dp_solver::process_dep1(const mmn_dependancy& dep)
00033 {
00034
00035 const vnl_vector<double>& nc0 = nc_[dep.v0];
00036 vnl_vector<double>& nc1 = nc_[dep.v1];
00037 vnl_matrix<double>& p = pc_[dep.arc1];
00038 vcl_vector<unsigned>& i0 = index1_[dep.v0];
00039
00040
00041 if (dep.v0<dep.v1)
00042 {
00043 assert(p.rows()==nc0.size());
00044 assert(p.cols()==nc1.size());
00045 }
00046 else
00047 {
00048 assert(p.rows()==nc1.size());
00049 assert(p.cols()==nc0.size());
00050 }
00051
00052
00053 i0.resize(nc1.size());
00054 for (unsigned j=0;j<nc1.size();++j)
00055 {
00056 double min_v;
00057 unsigned best_i=0;
00058 if (dep.v0<dep.v1)
00059 {
00060 min_v=nc0[0]+p(0,j);
00061 for (unsigned i=1;i<nc0.size();++i)
00062 {
00063 double v=nc0[i]+p(i,j);
00064 if (v<min_v) { min_v=v; best_i=i; }
00065 }
00066 }
00067 else
00068 {
00069 min_v=nc0[0]+p(j,0);
00070 for (unsigned i=1;i<nc0.size();++i)
00071 {
00072 double v=nc0[i]+p(j,i);
00073 if (v<min_v) { min_v=v; best_i=i; }
00074 }
00075 }
00076 i0[j]=best_i;
00077 nc1[j]+=min_v;
00078 }
00079 }
00080
00081 void mmn_dp_solver::process_dep2(const mmn_dependancy& dep)
00082 {
00083
00084
00085
00086 const vnl_vector<double>& nc0 = nc_[dep.v0];
00087 const vnl_vector<double>& nc1 = nc_[dep.v1];
00088 const vnl_vector<double>& nc2 = nc_[dep.v2];
00089 const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00090 const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00091 vnl_matrix<double>& pa12 = pc_[dep.arc12];
00092 vnl_matrix<int>& ind0 = index2_[dep.v0];
00093
00094 if (pa12.size()==0)
00095 {
00096 if (dep.v1<dep.v2)
00097 pa12.set_size(nc1.size(),nc2.size());
00098 else
00099 pa12.set_size(nc2.size(),nc1.size());
00100 pa12.fill(0.0);
00101 }
00102
00103
00104 ind0.set_size(nc1.size(),nc2.size());
00105
00106 for (unsigned i1=0;i1<nc1.size();++i1)
00107 {
00108 vnl_vector<double> sum0(nc0);
00109 if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00110 else sum0+=pa1.get_row(i1);
00111
00112 for (unsigned i2=0;i2<nc2.size();++i2)
00113 {
00114 vnl_vector<double> sum(sum0);
00115 if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00116 else sum+=pa2.get_row(i2);
00117
00118
00119
00120 unsigned best_i=0;
00121 double min_v=sum[0];
00122 for (unsigned i=1;i<sum.size();++i)
00123 if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00124
00125
00126 ind0(i1,i2)=best_i;
00127
00128 if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00129 else { pa12(i2,i1)+=min_v; }
00130 }
00131 }
00132 }
00133
00134 double mmn_dp_solver::solve(const vcl_vector<vnl_vector<double> >& node_cost,
00135 const vcl_vector<vnl_matrix<double> >& pair_cost,
00136 vcl_vector<unsigned>& x)
00137 {
00138 nc_ = node_cost;
00139 for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00140 for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00141
00142 if (deps_.size()==0)
00143 {
00144 vcl_cerr<<"No dependencies.\n";
00145 return 999.99;
00146 }
00147
00148
00149 vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00150 for (;dep!=deps_.end();dep++)
00151 {
00152 if (dep->n_dep==1) process_dep1(*dep);
00153 else process_dep2(*dep);
00154 }
00155
00156 const vnl_vector<double>& root_cost = nc_[root()];
00157 unsigned best_i=0;
00158 double min_v=root_cost[0];
00159 for (unsigned i=1;i<root_cost.size();++i)
00160 if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00161
00162 backtrace(best_i,x);
00163 return min_v;
00164 }
00165
00166
00167
00168 void mmn_dp_solver::backtrace(unsigned root_value,vcl_vector<unsigned>& x)
00169 {
00170 x.resize(nc_.size());
00171 x[root()]=root_value;
00172
00173
00174 for (int i=deps_.size()-1; i>=0; --i)
00175 {
00176 unsigned v0=deps_[i].v0;
00177 unsigned v1=deps_[i].v1;
00178 if (deps_[i].n_dep==1)
00179 x[v0]=index1_[v0][x[v1]];
00180 else
00181 {
00182 const vnl_matrix<int>& ind0 = index2_[v0];
00183 x[v0]=ind0(x[v1],x[deps_[i].v2]);
00184 }
00185 }
00186 }