contrib/rpl/rrel/rrel_irls.cxx
Go to the documentation of this file.
00001 // This is rpl/rrel/rrel_irls.cxx
00002 #include "rrel_irls.h"
00003 //:
00004 // \file
00005 
00006 #include <rrel/rrel_estimation_problem.h>
00007 #include <rrel/rrel_wls_obj.h>
00008 #include <rrel/rrel_util.h>
00009 
00010 #include <vnl/vnl_math.h>
00011 #include <vnl/vnl_vector.h>
00012 #include <vnl/vnl_matrix.h>
00013 
00014 #include <vcl_iostream.h>
00015 #include <vcl_vector.h>
00016 #include <vcl_cassert.h>
00017 
00018 const double rrel_irls::dflt_convergence_tol_ = 1e-4;
00019 const int rrel_irls::dflt_max_iterations_ = 25;
00020 const int rrel_irls::dflt_iterations_for_scale_ = 1;
00021 
00022 // -------------------------------------------------------------------------
00023 rrel_irls::rrel_irls( int max_iterations )
00024   : max_iterations_(max_iterations), test_converge_(true),
00025     convergence_tol_(dflt_convergence_tol_), est_scale_during_(true),
00026     iterations_for_scale_est_(dflt_iterations_for_scale_),
00027     scale_lower_bound_( -1.0 ),
00028     trace_level_(0), params_initialized_(false), scale_initialized_(false),
00029     obj_fcn_(1e256), prev_obj_fcn_(1e256),
00030     converged_(false), iteration_(0)
00031 {
00032   assert( max_iterations > 0 );
00033 }
00034 
00035 // -------------------------------------------------------------------------
00036 void
00037 rrel_irls::set_est_scale( int iterations_for_scale_est,
00038                           bool use_weighted_scale )
00039 {
00040   est_scale_during_ = true;
00041   use_weighted_scale_ = use_weighted_scale;
00042   iterations_for_scale_est_ = iterations_for_scale_est;
00043   if ( iterations_for_scale_est_ < 0 )
00044     vcl_cerr << "rrel_irls::est_scale_during WARNING last_scale_est_iter is\n"
00045              << "negative, so scale will not be estimated!\n";
00046 }
00047 
00048 // -------------------------------------------------------------------------
00049 //: Set lower bound of scale estimate
00050 void
00051 rrel_irls::set_scale_lower_bound( double lower_scale )
00052 {
00053   scale_lower_bound_ = lower_scale;
00054 }
00055 
00056 // -------------------------------------------------------------------------
00057 void
00058 rrel_irls::set_no_scale_est()
00059 {
00060   est_scale_during_ = false;
00061 }
00062 
00063 // -------------------------------------------------------------------------
00064 void
00065 rrel_irls::initialize_scale( double scale )
00066 {
00067   scale_ = scale;
00068   scale_initialized_ = true;
00069 }
00070 
00071 // -------------------------------------------------------------------------
00072 double
00073 rrel_irls::scale() const
00074 {
00075   assert( scale_initialized_ );
00076   return scale_;
00077 }
00078 
00079 
00080 // -------------------------------------------------------------------------
00081 void
00082 rrel_irls::set_max_iterations( int max_iterations )
00083 {
00084   max_iterations_ = max_iterations;
00085   assert( max_iterations_ > 0 );
00086 }
00087 
00088 
00089 // -------------------------------------------------------------------------
00090 void
00091 rrel_irls::set_convergence_test( double convergence_tol )
00092 {
00093   test_converge_ = true;
00094   convergence_tol_ = convergence_tol;
00095   assert( convergence_tol_ > 0 );
00096 }
00097 
00098 
00099 // -------------------------------------------------------------------------
00100 void
00101 rrel_irls::set_no_convergence_test( )
00102 {
00103   test_converge_ = false;
00104 }
00105 
00106 
00107 // -------------------------------------------------------------------------
00108 void
00109 rrel_irls::initialize_params( const vnl_vector<double>& init_params )
00110 {
00111   params_ = init_params;
00112   params_initialized_ = true;
00113 }
00114 
00115 
//: Run iteratively reweighted least squares on \a problem with objective \a obj.
//
//  If the parameters were not initialized, they are first obtained from a
//  weighted_least_squares_fit call with no weight vector.  If the problem
//  provides no prior scale and the scale was not initialized, the scale is
//  set to the median-absolute-deviation scale of the initial residuals.
//  The loop then computes residuals, optionally tests for convergence,
//  computes weights, optionally re-estimates scale, and refits.
//
//  \return false if the initial or any loop fit fails, or if \a obj
//          requires a prior scale that \a problem does not provide.
//          Returns true otherwise, including when the loop stops at the
//          iteration limit without converging (converged_ stays false then).
bool
rrel_irls::estimate( const rrel_estimation_problem* problem,
                     const rrel_wls_obj* obj )
{
  iteration_ = 0;
  obj_fcn_ = 1e256;   // effectively "infinite", so the first convergence test cannot pass on the relative-change criterion
  unsigned int num_for_fit = problem->num_samples_to_instantiate();
  // Cleared whenever params/scale were just (re)estimated, so convergence
  // is only tested once the scale-estimation iterations are over.
  bool allow_convergence_test = true;
  vcl_vector<double> residuals( problem->num_samples() );
  vcl_vector<double> weights( problem->num_samples() );
  bool failed = false;

  //  Parameter initialization, if necessary
  if ( ! params_initialized_ )
  {
    if ( ! problem->weighted_least_squares_fit( params_, cofact_ ) )
      return false;
    allow_convergence_test = false;
    params_initialized_ = true;
  }


  //  Scale initialization, if necessary
  if ( obj->requires_prior_scale() && problem->scale_type() == rrel_estimation_problem::NONE ) {
    vcl_cerr << "irls::estimate: Objective function requires a prior scale, and the problem does not provide one.\n"
             << "                Aborting estimation.\n";
    return false;
  } else {
    // Only a problem without prior scale needs an estimated initial scale.
    if ( problem->scale_type() == rrel_estimation_problem::NONE && ! scale_initialized_ ) {
      problem->compute_residuals( params_, residuals );
      scale_ = rrel_util_median_abs_dev_scale( residuals.begin(), residuals.end(), num_for_fit );
      allow_convergence_test = false;
      scale_initialized_ = true;
    }
  }

  if ( trace_level_ >= 1 )
    vcl_cout << "Initial estimate: " << params_ << ", scale = " << scale_ <<  vcl_endl;

  // NOTE(review): when the problem supplies a prior scale (SINGLE/MULTIPLE)
  // and the caller never called initialize_scale(), scale_initialized_ is
  // still false here, so this assert fires in debug builds and, with NDEBUG,
  // the scale_ test below reads a value that was never set -- confirm the
  // intended contract with callers.
  assert( params_initialized_ && scale_initialized_ );
  if ( scale_ <= 1e-8 ) {
    // Degenerate (near-zero) scale: report a tiny cofactor matrix and stop.
    unsigned int dof = problem-> param_dof();
    cofact_ = 1e-16 * vnl_matrix<double>(dof, dof, vnl_matrix_identity);
    scale_ = 0.0;
    converged_ = true;
    vcl_cerr << "rrel_irls::estimate: initial scale is zero - cannot estimate\n";
    // Usually this means the fit is already exact,
    // so returning true does no harm.
    return true;
  }


  //  Basic loop:
  //  1. Calculate residuals
  //  2. Test for convergence, if desired.
  //  3. Calculate weights
  //  4. Calculate scale
  //  5. Calculate new estimate
  //
  //  The loop exits on convergence, when iteration_ exceeds max_iterations_
  //  (so at most max_iterations_ refits), when the re-estimated scale
  //  collapses to zero, or when a weighted fit fails.

  converged_ = false;
  while ( true ) {
    //  Step 1.  Residuals
    problem->compute_residuals( params_, residuals );
    if ( trace_level_ >= 2 ) trace_residuals( residuals );

    //  Step 2.  Convergence.  The allow_convergence_test parameter
    //  prevents use of the convergence test until after the
    //  iterations involving scale estimation are finished.
    if ( test_converge_ && allow_convergence_test &&
         has_converged( residuals, obj, problem, &params_ ) ) {
      converged_ = true;
      break;
    }
    ++ iteration_;
    if ( iteration_ > max_iterations_ ) break;
    if ( trace_level_ >= 1 ) vcl_cout << "\nIteration: " << iteration_ << '\n';

    //  Step 3. Weights
    problem->compute_weights( residuals, obj, scale_, weights );
    if ( trace_level_ >= 2 ) trace_weights( weights );

    //  Step 4.  Scale.  Note: the residuals are reordered and therefore useless after
    //  rrel_util_median_abs_dev_scale.
    if ( est_scale_during_ && iteration_ <= iterations_for_scale_est_ ) {
      allow_convergence_test = false;
      if ( trace_level_ >= 1 ) vcl_cout << "num samples for fit = " << num_for_fit << '\n';
      if ( use_weighted_scale_ ) {
        assert( residuals.size() == weights.size() );
        scale_ = rrel_util_weighted_scale( residuals.begin(), residuals.end(),
                                           weights.begin(), num_for_fit, (double*)0 );
      }
      else {
        scale_ = rrel_util_median_abs_dev_scale( residuals.begin(), residuals.end(), num_for_fit );
      }
      if ( trace_level_ >= 1 ) vcl_cout << "Scale estimated: " << scale_ << vcl_endl;
      if ( scale_ <= 1e-8 ) {  //  fit exact enough to yield 0 scale estimate
        unsigned int dof = problem-> param_dof();
        cofact_ = 1e-16 * vnl_matrix<double>(dof, dof, vnl_matrix_identity);
        scale_ = 0.0;
        converged_ = true;
        vcl_cerr << "rrel_irls::estimate:  scale has gone to 0.\n";
        break;
      }

      // check lower bound (a non-positive bound means "no bound")
      if ( scale_lower_bound_ > 0 && scale_ < scale_lower_bound_ )
        scale_ = scale_lower_bound_;
    }
    else
      allow_convergence_test = true;

    // Step 5.  Weighted least-squares
    if ( !problem->weighted_least_squares_fit( params_, cofact_, &weights ) ) {
      failed = true;
      break;
    }
    if ( trace_level_ >= 1 ) vcl_cout << "Fit: " << params_ << vcl_endl;
  }

  return !failed;
}
00238 
00239 
00240 // -------------------------------------------------------------------------
00241 const vnl_vector<double>&
00242 rrel_irls::params() const
00243 {
00244   assert( params_initialized_ );
00245   return params_;
00246 }
00247 
00248 
00249 // -------------------------------------------------------------------------
00250 const vnl_matrix<double>&
00251 rrel_irls::cofactor() const
00252 {
00253   assert( params_initialized_ );
00254   return cofact_;
00255 }
00256 
00257 
00258 // -------------------------------------------------------------------------
00259 int
00260 rrel_irls::iterations_used() const
00261 {
00262   return iteration_-1;
00263 }
00264 
00265 
00266 // -------------------------------------------------------------------------
00267 bool
00268 rrel_irls::has_converged( const vcl_vector<double>& residuals,
00269                           const rrel_wls_obj* obj,
00270                           const rrel_estimation_problem* problem,
00271                           vnl_vector<double>* params )
00272 {
00273   prev_obj_fcn_ = obj_fcn_;
00274   switch ( problem->scale_type() )
00275   {
00276    case rrel_estimation_problem::NONE:
00277     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), scale_, params );
00278     break;
00279    case rrel_estimation_problem::SINGLE:
00280     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), problem->prior_scale(), params );
00281     break;
00282    case rrel_estimation_problem::MULTIPLE:
00283     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), problem->prior_multiple_scales().begin(), params );
00284     break;
00285    default:
00286     assert(!"invalid scale_type");
00287   }
00288 
00289   if ( trace_level_ >= 1 )
00290     vcl_cout << "  prev obj fcn = " << prev_obj_fcn_
00291              << ",  new obj fcn = " << obj_fcn_ << vcl_endl;
00292 
00293   return vnl_math_abs( obj_fcn_ ) < convergence_tol_  ||
00294     vnl_math_abs(obj_fcn_ - prev_obj_fcn_) < convergence_tol_ * obj_fcn_;
00295 }
00296 
00297 
00298 // -------------------------------------------------------------------------
00299 void
00300 rrel_irls::trace_residuals( const vcl_vector<double>& residuals ) const
00301 {
00302   vcl_cout << "Residuals:\n";
00303   for ( unsigned int i=0; i<residuals.size(); ++i )
00304     vcl_cout << "  " << i << ": " << residuals[i] << '\n';
00305 }
00306 
00307 
00308 // -------------------------------------------------------------------------
00309 void
00310 rrel_irls::trace_weights( const vcl_vector<double>& weights ) const
00311 {
00312   vcl_cout << "Weights:\n";
00313   for ( unsigned int i=0; i<weights.size(); ++i )
00314     vcl_cout << "  " << i << ": " << weights[i] << '\n';
00315 }
00316 
00317 
00318 // -------------------------------------------------------------------------
00319 void
00320 rrel_irls::print_params() const
00321 {
00322   vcl_cout << "  max_iterations_ = " << max_iterations_ << '\n'
00323            << "  test_converge_ = " << test_converge_ << '\n'
00324            << "  convergence_tol_ = " << convergence_tol_ << '\n'
00325            << "  est_scale_during_ = " << est_scale_during_ << '\n'
00326            << "  iterations_for_scale_est_ = " << iterations_for_scale_est_
00327            << vcl_endl;
00328 }