00001
00002
00003
00004
00005
00006
00007 #include "mbl_matrix_products.h"
00008 #include <vnl/vnl_vector.h>
00009 #include <vnl/vnl_matrix.h>
00010 #include <vcl_cassert.h>
00011 #include <vcl_cstdlib.h>
00012 #include <vcl_iostream.h>
00013
00014
00015
00016
00017 void mbl_matrix_product(vnl_matrix<double>& AB, const vnl_matrix<double>& A,
00018 const vnl_matrix<double>& B)
00019 {
00020 unsigned int nr1 = A.rows();
00021 unsigned int nc1 = A.cols();
00022 unsigned int nr2 = B.rows();
00023 unsigned int nc2 = B.cols();
00024
00025 if ( nr2 != nc1 )
00026 {
00027 vcl_cerr<<"Product : B.rows != A.cols\n";
00028 vcl_abort() ;
00029 }
00030
00031 if ( (AB.rows()!=nr1) || (AB.cols()!= nc2) )
00032 AB.set_size( nr1, nc2 ) ;
00033
00034 const double *const * AData = A.data_array();
00035 const double *const * BData = B.data_array();
00036 double ** RData = AB.data_array();
00037
00038
00039 AB.fill(0);
00040
00041 for (unsigned int r=0; r < nr1; ++r)
00042 {
00043 const double* A_row = AData[r];
00044 double* R_row = RData[r]-1;
00045 for (unsigned int c=0; c < nc1 ; ++c )
00046 {
00047 double a = A_row[c];
00048 if (a==0.0) continue;
00049
00050 const double* B_row = BData[c]-1;
00051 int i = nc2+1;
00052 while (--i)
00053 {
00054 R_row[i] += a * B_row[i];
00055 }
00056 }
00057 }
00058 }
00059
00060
00061
00062
00063 void mbl_matrix_product_a_bt(vnl_matrix<double>& ABt,
00064 const vnl_matrix<double>& A,
00065 const vnl_matrix<double>& B)
00066 {
00067 int nc1 = A.columns();
00068 #ifndef NDEBUG
00069 int nc2 = B.columns();
00070 if ( nc2 != nc1 )
00071 {
00072 vcl_cerr<<"mbl_matrix_product_a_bt : B.columns != A.columns\n";
00073 vcl_abort();
00074 }
00075 #endif //!NDEBUG
00076
00077 mbl_matrix_product_a_bt(ABt,A,B,nc1);
00078 }
00079
00080
00081
00082
00083 void mbl_matrix_product_a_bt(vnl_matrix<double>& ABt,
00084 const vnl_matrix<double>& A,
00085 const vnl_matrix<double>& B,
00086 int nc)
00087 {
00088 unsigned int nr1 = A.rows();
00089 unsigned int nr2 = B.rows();
00090
00091 assert(A.columns()>=(unsigned int)nc);
00092 assert(B.columns()>=(unsigned int)nc);
00093
00094 if ( (ABt.rows()!=nr1) || (ABt.columns()!= nr2) )
00095 ABt.set_size( nr1, nr2 ) ;
00096
00097 double const *const * A_data = A.data_array();
00098 double const *const * B_data = B.data_array();
00099 double ** R_data = ABt.data_array();
00100
00101 for (unsigned int r=0;r<nr1;++r)
00102 {
00103 const double* A_row = A_data[r];
00104 double* R_row = R_data[r];
00105 for (unsigned int c=0;c<nr2;++c)
00106 {
00107 const double* B_row = B_data[c];
00108 R_row[c] = vnl_c_vector<double>::dot_product(A_row,B_row,nc);
00109 }
00110 }
00111 }
00112
00113
00114
00115
00116 void mbl_matrix_product_at_b(vnl_matrix<double>& AtB,
00117 const vnl_matrix<double>& A,
00118 const vnl_matrix<double>& B)
00119 {
00120 mbl_matrix_product_at_b(AtB,A,B,A.columns());
00121 }
00122
00123
00124
00125
00126
00127 void mbl_matrix_product_a_at(vnl_matrix<double>& AAt,
00128 const vnl_matrix<double>& A,
00129 unsigned nr, unsigned nc)
00130 {
00131 assert(nr<=A.rows());
00132 assert(nc<=A.columns());
00133
00134 if ( (AAt.rows()!=nr) || (AAt.columns()!= nr) )
00135 AAt.set_size( nr, nr ) ;
00136
00137 double const *const * A_data = A.data_array();
00138 double ** R_data = AAt.data_array();
00139
00140
00141 for (unsigned int r=0;r<nr;++r)
00142 {
00143 const double* A_row = A_data[r];
00144 double* R_row = R_data[r];
00145 for (unsigned int c=r;c<nr;++c)
00146 {
00147 const double* B_row = A_data[c];
00148 R_row[c] = vnl_c_vector<double>::dot_product(A_row,B_row,nc);
00149 }
00150 }
00151
00152
00153 for (unsigned int r=1;r<nr;++r)
00154 for (unsigned int c=0;c<r;++c)
00155 AAt(r,c)=AAt(c,r);
00156 }
00157
00158
00159
00160
00161
00162
00163 void mbl_matrix_product_a_at(vnl_matrix<double>& AAt,
00164 const vnl_matrix<double>& A)
00165 {
00166 mbl_matrix_product_a_at(AAt,A,A.rows(),A.columns());
00167 }
00168
00169
00170
00171 void mbl_matrix_product_at_b(vnl_matrix<double>& AtB,
00172 const vnl_matrix<double>& A,
00173 const vnl_matrix<double>& B,
00174 int nc_a)
00175 {
00176 assert(nc_a >= 0 && A.columns()>=(unsigned int)nc_a);
00177 unsigned int nr1 = A.rows();
00178 unsigned int nr2 = B.rows();
00179 unsigned int nc2 = B.columns();
00180
00181 if ( nr2 != nr1 )
00182 {
00183 vcl_cerr<<"TC_ProductAtB : B.rows != A.rows\n";
00184 vcl_abort();
00185 }
00186
00187 if ( (AtB.rows()!=(unsigned int)nc_a) || (AtB.columns()!= nc2) )
00188 AtB.set_size( nc_a, nc2 ) ;
00189
00190 double const *const * A_data = A.data_array();
00191 double const *const * B_data = B.data_array();
00192 double ** R_data = AtB.data_array()-1;
00193
00194 AtB.fill(0);
00195
00196 for (unsigned int r1 = 0; r1<nr1; ++r1)
00197 {
00198 const double* A_row = A_data[r1]-1;
00199 const double* B_row = B_data[r1]-1;
00200 double a;
00201 int c1 = nc_a+1;
00202 while (--c1)
00203 {
00204 double *R_row = R_data[c1]-1;
00205 a = A_row[c1];
00206 int c2 = nc2+1;
00207 while (--c2)
00208 {
00209 R_row[c2] +=a*B_row[c2];
00210 }
00211 }
00212 }
00213 }
00214
00215
00216
00217
00218
00219 void mbl_matrix_product_at_a(vnl_matrix<double>& AtA,
00220 const vnl_matrix<double>& A,
00221 unsigned nc)
00222 {
00223 assert(A.columns()>=nc);
00224 unsigned int nr = A.rows();
00225
00226 if ( AtA.rows()!=nr || (AtA.columns()!= nc) )
00227 AtA.set_size( nc, nc ) ;
00228
00229 double const *const * A_data = A.data_array();
00230 double ** R_data = AtA.data_array()-1;
00231
00232 AtA.fill(0);
00233
00234 for (unsigned int r = 0; r<nr; ++r)
00235 {
00236 const double* A_row = A_data[r]-1;
00237 double a;
00238 int c1 = nc+1;
00239 while (--c1)
00240 {
00241 double *R_row = R_data[c1]-1;
00242 a = A_row[c1];
00243 int c2 = nc+1;
00244 while (--c2)
00245 {
00246 R_row[c2] +=a*A_row[c2];
00247 }
00248 }
00249 }
00250 }
00251
00252
00253
00254
00255 void mbl_matrix_product_at_a(vnl_matrix<double>& AtA,
00256 const vnl_matrix<double>& A)
00257 {
00258 mbl_matrix_product_at_a(AtA,A,A.columns());
00259 }
00260
00261
00262
00263 void mbl_matrix_product_adb(vnl_matrix<double>& ADB,
00264 const vnl_matrix<double>& A,
00265 const vnl_vector<double>& d,
00266 const vnl_matrix<double>& B)
00267 {
00268 unsigned int nr1 = A.rows();
00269 unsigned int nc1 = A.cols();
00270 unsigned int nc2 = B.cols();
00271
00272 assert ( B.rows() == nc1 );
00273
00274 assert ( B.rows() == d.size() );
00275
00276 if ( (ADB.rows()!=nr1) || (ADB.cols()!= nc2) )
00277 ADB.set_size( nr1, nc2 ) ;
00278
00279 const double * const* AData = A.data_array();
00280 const double * const* BData = B.data_array();
00281 const double * d_data = d.data_block();
00282 double ** ADBdata = ADB.data_array();
00283
00284 ADB.fill(0);
00285
00286 for (unsigned int r=0; r < nr1; ++r)
00287 {
00288 const double* A_row = AData[r];
00289 double* ADB_row = ADBdata[r]-1;
00290 for (unsigned int c=0; c < nc1 ; ++c )
00291 {
00292 double ad = A_row[c] * d_data[c];
00293 if (ad==0.0) continue;
00294
00295 const double* B_row = BData[c]-1;
00296 int i = nc2+1;
00297 while (--i)
00298 {
00299 ADB_row[i] += ad * B_row[i];
00300 }
00301 }
00302 }
00303 }