contrib/mul/mbl/mbl_matrix_products.cxx
Go to the documentation of this file.
00001 //:
00002 // \file
00003 // \author Tim Cootes
00004 // \date 25-Apr-2001
00005 // \brief Various specialised versions of matrix product operations
00006 
00007 #include "mbl_matrix_products.h"
00008 #include <vnl/vnl_vector.h>
00009 #include <vnl/vnl_matrix.h>
00010 #include <vcl_cassert.h>
00011 #include <vcl_cstdlib.h> // for vcl_abort()
00012 #include <vcl_iostream.h>
00013 
00014 //=======================================================================
00015 //: Compute product AB = A * B
00016 //=======================================================================
00017 void mbl_matrix_product(vnl_matrix<double>& AB, const vnl_matrix<double>& A,
00018                         const vnl_matrix<double>& B)
00019 {
00020    unsigned int nr1 = A.rows();
00021    unsigned int nc1 = A.cols();
00022    unsigned int nr2 = B.rows();
00023    unsigned int nc2 = B.cols();
00024 
00025    if ( nr2 != nc1 )
00026    {
00027       vcl_cerr<<"Product : B.rows != A.cols\n";
00028       vcl_abort() ;
00029    }
00030 
00031    if ( (AB.rows()!=nr1) || (AB.cols()!= nc2) )
00032     AB.set_size( nr1, nc2 ) ;
00033 
00034   const double *const * AData = A.data_array();
00035   const double *const * BData = B.data_array();
00036   double ** RData = AB.data_array();
00037 
00038   // Zero the elements of AB
00039   AB.fill(0);
00040 
00041   for (unsigned int r=0; r < nr1; ++r)
00042   {
00043     const double* A_row = AData[r];
00044     double* R_row = RData[r]-1;
00045     for (unsigned int c=0; c < nc1 ; ++c )
00046     {
00047       double a = A_row[c];
00048       if (a==0.0) continue;
00049 
00050       const double* B_row = BData[c]-1;
00051       int i = nc2+1;
00052       while (--i)
00053       {
00054         R_row[i] += a * B_row[i];
00055       }
00056     }
00057   }
00058 }
00059 
00060 //=======================================================================
00061 //: Compute product ABt = A * B.transpose()
00062 //=======================================================================
00063 void mbl_matrix_product_a_bt(vnl_matrix<double>& ABt,
00064                              const vnl_matrix<double>& A,
00065                              const vnl_matrix<double>& B)
00066 {
00067   int nc1 = A.columns();
00068 #ifndef NDEBUG
00069   int nc2 = B.columns();
00070   if ( nc2 != nc1 )
00071   {
00072     vcl_cerr<<"mbl_matrix_product_a_bt : B.columns != A.columns\n";
00073     vcl_abort();
00074   }
00075 #endif //!NDEBUG
00076 
00077   mbl_matrix_product_a_bt(ABt,A,B,nc1);
00078 }
00079 
00080 //=======================================================================
00081 //: Compute product ABt = A * B.transpose(), using only nc cols of A and B
00082 //=======================================================================
00083 void mbl_matrix_product_a_bt(vnl_matrix<double>& ABt,
00084                              const vnl_matrix<double>& A,
00085                              const vnl_matrix<double>& B,
00086                              int nc)
00087 {
00088   unsigned int nr1 = A.rows();
00089   unsigned int nr2 = B.rows();
00090 
00091   assert(A.columns()>=(unsigned int)nc);
00092   assert(B.columns()>=(unsigned int)nc);
00093 
00094   if ( (ABt.rows()!=nr1) || (ABt.columns()!= nr2) )
00095     ABt.set_size( nr1, nr2 ) ;
00096 
00097   double const *const * A_data = A.data_array();
00098   double const *const * B_data = B.data_array();
00099   double ** R_data = ABt.data_array();
00100 
00101   for (unsigned int r=0;r<nr1;++r)
00102   {
00103     const double* A_row = A_data[r];
00104     double* R_row = R_data[r];
00105     for (unsigned int c=0;c<nr2;++c)
00106     {
00107       const double* B_row = B_data[c];
00108       R_row[c] = vnl_c_vector<double>::dot_product(A_row,B_row,nc);
00109     }
00110   }
00111 }
00112 
00113 //=======================================================================
00114 //: Compute product AtB = A.transpose() * B
00115 //=======================================================================
00116 void mbl_matrix_product_at_b(vnl_matrix<double>& AtB,
00117                              const vnl_matrix<double>& A,
00118                              const vnl_matrix<double>& B)
00119 {
00120   mbl_matrix_product_at_b(AtB,A,B,A.columns());
00121 }
00122 
00123 //=======================================================================
00124 //: Compute AAt = A * A.transpose(), using only first nr x nc partition of A
00125 //  Uses symmetry of result to improve speed
00126 //=======================================================================
00127 void mbl_matrix_product_a_at(vnl_matrix<double>& AAt,
00128                              const vnl_matrix<double>& A,
00129                              unsigned nr, unsigned nc)
00130 {
00131   assert(nr<=A.rows());
00132   assert(nc<=A.columns());
00133 
00134   if ( (AAt.rows()!=nr) || (AAt.columns()!= nr) )
00135     AAt.set_size( nr, nr ) ;
00136 
00137   double const *const * A_data = A.data_array();
00138   double ** R_data = AAt.data_array();
00139 
00140   // Fill in upper triangle of symmetric matrix
00141   for (unsigned int r=0;r<nr;++r)
00142   {
00143     const double* A_row = A_data[r];
00144     double* R_row = R_data[r];
00145     for (unsigned int c=r;c<nr;++c)
00146     {
00147       const double* B_row = A_data[c];
00148       R_row[c] = vnl_c_vector<double>::dot_product(A_row,B_row,nc);
00149     }
00150   }
00151 
00152   // Copy upper triangle to lower triangle
00153   for (unsigned int r=1;r<nr;++r)
00154     for (unsigned int c=0;c<r;++c)
00155       AAt(r,c)=AAt(c,r);
00156 }
00157 
00158 //=======================================================================
00159 //: Compute product AAt = A * A.transpose()
00160 //  Uses symmetry of result to be approx twice as fast as
00161 //  mbl_matrix_product_a_bt(AAt,A,A)
00162 //=======================================================================
00163 void mbl_matrix_product_a_at(vnl_matrix<double>& AAt,
00164                              const vnl_matrix<double>& A)
00165 {
00166   mbl_matrix_product_a_at(AAt,A,A.rows(),A.columns());
00167 }
00168 //=======================================================================
00169 //: Compute product AtB = A.transpose() * B, using nc_a cols of A
00170 //=======================================================================
00171 void mbl_matrix_product_at_b(vnl_matrix<double>& AtB,
00172                              const vnl_matrix<double>& A,
00173                              const vnl_matrix<double>& B,
00174                              int nc_a)
00175 {
00176   assert(nc_a >= 0 && A.columns()>=(unsigned int)nc_a);
00177   unsigned int nr1 = A.rows();
00178   unsigned int nr2 = B.rows();
00179   unsigned int nc2 = B.columns();
00180 
00181   if ( nr2 != nr1 )
00182   {
00183     vcl_cerr<<"TC_ProductAtB : B.rows != A.rows\n";
00184     vcl_abort();
00185   }
00186 
00187   if ( (AtB.rows()!=(unsigned int)nc_a) || (AtB.columns()!= nc2) )
00188     AtB.set_size( nc_a, nc2 ) ;
00189 
00190   double const *const * A_data = A.data_array();
00191   double const *const * B_data = B.data_array();
00192   double ** R_data = AtB.data_array()-1;
00193 
00194   AtB.fill(0);
00195 
00196   for (unsigned int r1 = 0; r1<nr1; ++r1)
00197   {
00198     const double* A_row = A_data[r1]-1;
00199     const double* B_row = B_data[r1]-1;
00200     double a;
00201     int c1 =  nc_a+1;
00202     while (--c1)
00203     {
00204       double *R_row = R_data[c1]-1;
00205       a = A_row[c1];
00206       int c2 = nc2+1;
00207       while (--c2)
00208       {
00209          R_row[c2] +=a*B_row[c2];
00210       }
00211     }
00212   }
00213 }
00214 
00215 
00216 //=======================================================================
00217 //: Compute product AtA = A.transpose() * A using nc cols of A
00218 //=======================================================================
00219 void mbl_matrix_product_at_a(vnl_matrix<double>& AtA,
00220                              const vnl_matrix<double>& A,
00221                              unsigned nc)
00222 {
00223   assert(A.columns()>=nc);
00224   unsigned int nr = A.rows();
00225 
00226   if ( AtA.rows()!=nr || (AtA.columns()!= nc) )
00227     AtA.set_size( nc, nc ) ;
00228 
00229   double const *const * A_data = A.data_array();
00230   double ** R_data = AtA.data_array()-1;
00231 
00232   AtA.fill(0);
00233 
00234   for (unsigned int r = 0; r<nr; ++r)
00235   {
00236     const double* A_row = A_data[r]-1;
00237     double a;
00238     int c1 =  nc+1;
00239     while (--c1)
00240     {
00241       double *R_row = R_data[c1]-1;
00242       a = A_row[c1];
00243       int c2 = nc+1;
00244       while (--c2)
00245       {
00246          R_row[c2] +=a*A_row[c2];
00247       }
00248     }
00249   }
00250 }
00251 
00252 //=======================================================================
00253 //: Compute product AtA = A.transpose() * A
00254 //=======================================================================
00255 void mbl_matrix_product_at_a(vnl_matrix<double>& AtA,
00256                              const vnl_matrix<double>& A)
00257 {
00258   mbl_matrix_product_at_a(AtA,A,A.columns());
00259 }
00260 
00261 //: Returns ADB = A * D * B
00262 //  where D is diagonal with elements d
00263 void mbl_matrix_product_adb(vnl_matrix<double>& ADB,
00264                             const vnl_matrix<double>& A,
00265                             const vnl_vector<double>& d,
00266                             const vnl_matrix<double>& B)
00267 {
00268   unsigned int nr1 = A.rows();
00269   unsigned int nc1 = A.cols();
00270   unsigned int nc2 = B.cols();
00271 
00272   assert ( B.rows() == nc1 ); //Product : B.nrows != A.ncols
00273 
00274   assert ( B.rows() == d.size() ); // Product : d.nelems != A.ncols
00275 
00276   if ( (ADB.rows()!=nr1) || (ADB.cols()!= nc2) )
00277     ADB.set_size( nr1, nc2 ) ;
00278 
00279   const double * const* AData = A.data_array();
00280   const double * const* BData = B.data_array();
00281   const double *  d_data = d.data_block();
00282   double ** ADBdata = ADB.data_array();
00283 
00284   ADB.fill(0);
00285 
00286   for (unsigned int r=0; r < nr1; ++r)
00287   {
00288     const double* A_row = AData[r];
00289     double* ADB_row = ADBdata[r]-1;
00290     for (unsigned int c=0; c < nc1 ; ++c )
00291     {
00292       double ad = A_row[c] * d_data[c];
00293       if (ad==0.0) continue;
00294 
00295       const double* B_row = BData[c]-1;
00296       int i = nc2+1;
00297       while (--i)
00298       {
00299         ADB_row[i] += ad * B_row[i];
00300       }
00301     }
00302   }
00303 }