//
// source code: least_square.cc 
// purpose:     least square(s) fit
// author:      T. Hebbeker   
// version:     1.1, 2006-11-06
//
//

#include <iostream>
#include <fstream>      // for file handling
#include <string>       // for string handling
#include <iomanip>      // for formatting printout
#include <math.h>
using namespace std;

// global data, visible to main and to every function below

// capacity of the data arrays: at most this many (x,y) pairs are kept
const int ndata_max = 7;

// number of (x,y) pairs, set by read_input_file
int ndata = 0;

// data arrays, used with indices 1..ndata_max (element 0 is not used)
double x_data[1 + ndata_max];
double y_data[1 + ndata_max];

// forward declarations -- the definitions follow after main()

// read (x,y) pairs from the named file into x_data / y_data
int read_input_file(string filename);
// fit y = a * x + b; a and b are returned through the two references
int least_square_fit(double &a_out, double &b_out);
// print the number of points and the fitted a and b
void print_fit_results(double a_in, double b_in);

// main: interactive driver of the program.
//   1. prints the expected format of the data file,
//   2. asks for the file name and reads the (x,y) pairs (read_input_file),
//   3. fits y = a * x + b (least_square_fit) and prints a and b.
//
// exit code: 0 - fit printed, or fit skipped because of too few points
//            1 - input file missing or bad
//            2 - numerical problem encountered in the fit
int main()
{

// define local variables (validity restricted to main)
// can also be defined later inside main, just where they are used.

  double a   = -999.;   // placeholder values, overwritten by a successful fit
  double b   = - 99.;
  int result = -1;
  string input_file;

  cout << endl
       << " program   === least_square ===   version 1.1 "
       << endl;

  cout << endl
       << " data input file should have the following format: "
       << endl
       << "   x1      y1 " << endl
       << "   x2      y2 " << endl
       << "  ...     ... " << endl
       << "   xn      yn " << endl << endl;

  cout << " fit result: parameters a and b in y = a * x + b "
       << endl;

  cout << endl
       << " enter input file name: ";
  cin >> input_file;
  cout << " input file = " << input_file << endl;

  result = read_input_file(input_file);

  if(1 == result)
  {
    cout << " ERROR: missing or bad input file = "
         << input_file << endl;
    return 1;
  }
  else if (2 == result)
  {
    // read_input_file left ndata at the total number of pairs found,
    // but only the first ndata_max of them were stored
    cout << " WARNING: too many input data pairs = "
         << ndata << endl;
    cout << "          only first " <<  ndata_max
         << " values used in fit " << endl;
    ndata = ndata_max;
  }

  result = least_square_fit(a,b);

  if(0 == result)
  {
    print_fit_results(a,b);
  }
  else if (9 == result)
  {
    cout << " WARNING: too few input data pairs = " << ndata << endl;
    cout << "          fit not possible " << endl;
  }
  else if (99 == result)
  {
    cout << " ERROR: numerical problem encountered in fit " << endl;
    return 2;
  }

  return 0;

}


// read_input_file: read whitespace-separated (x,y) number pairs from a
// text file into the global arrays x_data[1..ndata_max], y_data[1..ndata_max].
//
// parameter: filename - path of the input file
// return:    0 - success; ndata = number of pairs read
//            1 - file cannot be opened, or it contains a token that is
//                not a number
//            2 - more than ndata_max pairs found; ndata = total number of
//                pairs found, but only the first ndata_max are stored
//
// note: the arrays are filled starting at index 1 (element 0 is unused),
//       matching the loop in least_square_fit. A trailing incomplete pair
//       (an x without its y at the end of the file) is ignored.
int read_input_file(string filename)
{
  double x_file;    // x value found on file
  double y_file;    // y value found on file

  ndata = 0;

// define input stream "in", open input file

  ifstream in(filename.c_str());
  if (!in)
  {
    return 1;
  }

// read all data and store numbers in arrays.
// The stream state is tested right after each extraction. This keeps the
// last pair even if the file does not end with a newline (an eof() check
// after the read would drop it). It also stops at a non-numeric token
// instead of looping forever.

  while (in >> x_file >> y_file)
  {
    ndata++;
    if (ndata <= ndata_max)
    {
      x_data[ndata] = x_file;
      y_data[ndata] = y_file;
    }
  }

// extraction stopped before the end of the file: unparsable data

  if (!in.eof())
  {
    return 1;
  }

  if (ndata > ndata_max)
  {
    return 2;
  }

// the input file is closed by the ifstream destructor

  return 0;
}
 

// print_fit_results: write the fit summary to standard output: the number
// of points used (global ndata), the slope a and the offset b, each value
// printed in a field of width 9 (setw).
void print_fit_results(double a_in,double b_in)
{
  cout << endl;
  cout << " Fit result: " << endl;
  cout << "              # points = " << setw(9) << ndata << endl;
  cout << "              slope  a = " << setw(9) << a_in << endl;
  cout << "              offset b = " << setw(9) << b_in << endl;
}


// least_square_fit: least squares fit of the straight line y = a * x + b
// to the points x_data[1..ndata], y_data[1..ndata] (global arrays).
//
// parameters (output): a_out - fitted slope a
//                      b_out - fitted offset b
//                      both are left unchanged if the fit is not done
// return: 0  - success
//         9  - too few data points (ndata <= 2)
//         99 - the normal equations are (numerically) singular,
//              e.g. all x values are (nearly) identical
int least_square_fit(double &a_out, double &b_out)
{

// define local variables (are reinitialized for each function call !)

  double sum_x  = 0.;
  double sum_y  = 0.;
  double sum_x2 = 0.;
  double sum_xy = 0.;

  // NOTE(review): two points with different x already determine a line;
  //   the requirement of at least three points is kept from the original.
  if(ndata <=2)
  {
    return 9;
  }

  for (int k=1; k<=ndata; k++)
  {
    sum_x  += x_data[k];
    sum_x2 += x_data[k]*x_data[k];
    sum_y  += y_data[k];
    sum_xy += x_data[k]*y_data[k];
  }

  const double epsilon = 1.E-9;
  double denominator = ndata*sum_x2-sum_x*sum_x;


  // note: in the following we do not simply write epsilon,
  //       but use epsilon * ndata * sum_x2
  //       so that we have a positive reference number
  //       (ndata * sum_x2) which scales with the size of x
  //       and ndata the same way as denominator does !
  //       "<=" (not "<") is required: if all x are 0, both the
  //       denominator and the reference are exactly 0, and "<" would
  //       let the code divide by zero below.

  if(fabs(denominator)<=epsilon*ndata*sum_x2)
  {                         // important: use fabs and not abs
    return 99;
  }

  a_out = (ndata*sum_xy - sum_x*sum_y)/denominator;
  b_out = (sum_x2*sum_y - sum_x*sum_xy)/denominator;

  return 0;

}



