#include "common.h"
#include "chisquared.h"
#include <gsl/gsl_multifit_nlinear.h>
#include <gsl/gsl_blas.h>

/* Residual function for the least-χ² fit.
 *
 * The χ² is defined as
 *
 *   χ² = |f|²
 *      = Σi fi²
 *
 * where fi = (Oi - Ei)/√Ei are the residuals of the i-th
 * bin, Oi being the observed and Ei the expected counts.
 *
 * Parameters:
 *   par   - the parameters (α,β,γ) being fitted;
 *   data_ - pointer to the `struct hist` holding the 2D (θ,φ)
 *           histogram `h` and the total number of events `n`;
 *   f     - output vector that receives the fi, one per bin.
 *
 * Returns GSL_SUCCESS.
 */
int chisquared_f(const gsl_vector *par,
                 void *data_, gsl_vector *f) {
  const struct hist *data = data_;

  /* Non-positive expected counts (F ≤ 0) are replaced by this
   * small positive value: fi divides by √Ei, so Ei = 0 would be
   * a division by zero and Ei < 0 has no real square root.
   * Such a bin then contributes fi ≈ 1000⋅Oi, which penalises
   * the offending parameters whenever the bin is not empty.
   */
  const double min_expected = 1e-6;

  struct event event;
  double lower, upper;

  /* Loop over the (i,j) bins of the 2D histogram.
   * The index k of the k-th component of f is:
   *
   *   k = i * φbins + j
   *
   * i.e. the bins are laid out in row-major (C) order.
   */
  for (size_t i = 0; i < data->h->nx; i++)
    for (size_t j = 0; j < data->h->ny; j++) {
      /* Get the bin ranges and widths for θ,φ.
       * The events are supposed to happen at the
       * midpoint of the range.
       */
      gsl_histogram2d_get_xrange(data->h, i, &lower, &upper);
      event.th = (lower + upper)/2;
      const double delta_th = upper - lower;
      gsl_histogram2d_get_yrange(data->h, j, &lower, &upper);
      event.ph = (lower + upper)/2;
      const double delta_ph = upper - lower;

      /* O = observed number of events in the k-th bin
       * E = expected number of events in k-th bin
       *
       * E is given by:
       *
       *        /---> total number of events
       *       /
       *   E = N  F(α,β,γ; θ,φ) Δθ Δφ sin(θ)
       *         \__________________________/
       *                      |
       *                      `-> probability
       */
      const double observed = gsl_histogram2d_get(data->h, i, j);
      double expected = data->n * distr(par, &event)
                      * delta_th * delta_ph * sin(event.th);

      if (expected <= 0)
        expected = min_expected;

      gsl_vector_set(
        f, i * data->h->ny + j,
        (observed - expected)/sqrt(expected));
    }

  return GSL_SUCCESS;
}


/* Jacobian function
 *
 * The gradient of the χ² function is:
 *
 *   ∇χ² = ∇ |f|²
 *       = 2 J^T⋅f
 *
 * where J is the jacobian of the (vector) function f of
 * `chisquared_f`. Writing the expected count of bin i as
 * Ei = Wi Fi, with Wi = N Δθ Δφ sin(θi), its entries are
 *
 *   Jij = ∂fi/∂xj
 *       = -(Oi + Ei)/(2 Ei^(3/2)) ∂Ei/∂xj
 *       = -(Oi + Ei)/(2 Ei^(3/2)) Wi (∇Fi)j
 *
 * This equals -(Oi + Ei)/(2√Ei) 1/Fi (∇Fi)j, but it does
 * not divide by Fi, which may vanish.
 *
 * When Ei ≤ 0 the residual fi is not a differentiable
 * function of the parameters (√Ei vanishes or is undefined,
 * and `chisquared_f` replaces negative Ei by a constant),
 * so the corresponding row of J is set to zero.
 *
 * Parameters:
 *   par   - the parameters (α,β,γ);
 *   data_ - pointer to the `struct hist` holding the 2D (θ,φ)
 *           histogram `h` and the total number of events `n`;
 *   jac   - output matrix; row k = i * φbins + j receives the
 *           gradient of the residual of bin (i,j).
 *
 * Returns GSL_SUCCESS.
 */
int chisquared_df(const gsl_vector *par,
                  void *data_, gsl_matrix *jac) {
  const struct hist *data = data_;

  struct event event;
  double lower, upper;

  for (size_t i = 0; i < data->h->nx; i++)
    for (size_t j = 0; j < data->h->ny; j++) {
      /* Get the bin ranges and widths for θ,φ.
       * The events are supposed to happen at the
       * midpoint of the range.
       */
      gsl_histogram2d_get_xrange(data->h, i, &lower, &upper);
      event.th = (lower + upper)/2;
      const double delta_th = upper - lower;
      gsl_histogram2d_get_yrange(data->h, j, &lower, &upper);
      event.ph = (lower + upper)/2;
      const double delta_ph = upper - lower;

      /* Wi = ∂Ei/∂Fi, so that Ei = Wi Fi */
      const double weight = data->n * delta_th * delta_ph * sin(event.th);
      const double expected = weight * distr(par, &event);

      /* Oi is used unchanged, exactly as in the residual,
       * so that J is the true derivative of f.
       */
      const double observed = gsl_histogram2d_get(data->h, i, j);

      /* The gradient is written directly into the k-th row of J. */
      gsl_vector_view row = gsl_matrix_row(jac, i * data->h->ny + j);

      if (expected <= 0) {
        gsl_vector_set_zero(&row.vector);
        continue;
      }

      /* Compute ∇F(α,β,γ; θi,φi), then rescale it by
       * -(Oi + Ei)/(2 Ei^(3/2)) Wi.
       */
      grad_distr(par, &event, &row.vector);
      gsl_vector_scale(
        &row.vector,
        -0.5 * (observed + expected) / (expected * sqrt(expected)) * weight);
    }

  return GSL_SUCCESS;
}


/* This is a callback function called during
 * the minimisation to show the current progress.
 * It prints:
 *   1. the condition number cond(J) of the jacobian
 *   2. the reduced χ² value
 *   3. the current parameters
 */
void callback(const size_t iter, void *params,
              const gsl_multifit_nlinear_workspace *w) {
  gsl_vector *f = gsl_multifit_nlinear_residual(w);
  gsl_vector *x = gsl_multifit_nlinear_position(w);

  //if (iter % 4 != 0)
  //  return;

  /* Compute the condition number of the
   * jacobian and the reduced χ² (χ²/d).
   */
  double rcond, chi2;
  int d = w->fdf->n - w->fdf->p;
  gsl_multifit_nlinear_rcond(&rcond, w);
  gsl_blas_ddot(f, f, &chi2);

  fprintf(stderr, "%2ld\t", iter);
  vector_fprint(stderr, x);
  fprintf(
    stderr, "\tcond(J)=%.4g, χ²/d=%.4g\n\n",
    1.0/rcond,
    chi2/d);
}


/* Minimum χ² estimation of of the parameters
 * α,β,γ.
 */
min_result minChiSquared(struct hist data) {
  /* Initialise the function to be minimised */
  gsl_multifit_nlinear_fdf chisquared;
  chisquared.f      = chisquared_f;             // function
  chisquared.df     = chisquared_df;            // gradient
  chisquared.fvv    = NULL;                     // something for geodesic accel.
  chisquared.n      = data.h->nx * data.h->ny;  // numeber of data points
  chisquared.p      = 3;                        // numeber of data points
  chisquared.params = &data;                    // histogram data

  /* Initialise the minimisation workspace */
  gsl_multifit_nlinear_parameters options =
    gsl_multifit_nlinear_default_parameters();

  gsl_multifit_nlinear_workspace *w = 
    gsl_multifit_nlinear_alloc(
      gsl_multifit_nlinear_trust,  // minimisation method
      &options,                    // minimisation options
      data.h->nx * data.h->ny,     // number of data points
      3);                          // number of (unknown) params

  /* Set the starting point of the 
   * minimisation.
   */
  gsl_vector *par = gsl_vector_alloc(3);
  gsl_vector_set(par, 0,  0.50);  // α
  gsl_vector_set(par, 1,  0.01);  // β
  gsl_vector_set(par, 2, -0.10);  // γ
  gsl_multifit_nlinear_init(par, &chisquared, w);

  /* Configure the solver and run the minimisation
   * using the high-level driver.
   */
  fputs("\n# least χ²\n\n", stderr);
  int status, info;
  status = gsl_multifit_nlinear_driver(
    100,      // max number of iterations
    1e-8,     // tolerance for step test: |δ| ≤ xtol(|x| + xtol)
    1e-8,     // tolerance for gradient test: max |∇i⋅ xi| ≤ gtol⋅xi
    1e-8,     // tolerance for norm test
    callback, // function called on each iteration
    NULL,     // callback parameters 
    &info,    // stores convergence information
    w);       // minimisation workspace

  fprintf(stderr, "status: %s\n", gsl_strerror(status));
  if (status != GSL_SUCCESS)
    fprintf(stderr, "info: %s\n", gsl_strerror(info));
  fprintf(stderr, "iterations: %ld\n", gsl_multifit_nlinear_niter(w));

  /* Store results in the min_result type.
   * Note: We allocate a new vector/matrix for the
   * parameters because `w->x` will be freed
   * along with the workspace `w`.
   */
  min_result res;
  res.par = gsl_vector_alloc(3);
  res.err = gsl_vector_alloc(3);
  res.cov = gsl_matrix_alloc(3, 3);
  gsl_vector_memcpy(res.par, gsl_multifit_nlinear_position(w));

  /* Compute the covariance of the fit parameters.
   * The covariance Σ is estimated by
   *
   *   Σ = H⁻¹
   *
   * where H is the hessian of the χ², which is
   * itself approximated by the jacobian as
   *
   *   H = J^T⋅J
   *
   */
  gsl_matrix *jac = gsl_multifit_nlinear_jac(w);
  gsl_multifit_nlinear_covar(jac, 0.0, res.cov);

  /* Compute the standard errors
   * from the covariance matrix.
   */
  gsl_vector_const_view diagonal = gsl_matrix_const_diagonal(res.cov);
  gsl_vector_memcpy(res.err, &diagonal.vector);
  vector_map(&sqrt, res.err);

  /* Compute the reduced χ² */
  double chi2;
  gsl_vector *f = gsl_multifit_nlinear_residual(w);
  gsl_blas_ddot(f, f, &chi2);
  chi2 /= chisquared.n - chisquared.p;

  /* Print the results */
  fputs("\n## results\n", stderr);

  fprintf(stderr, "\n* χ²/d: %.3f\n", chi2);

  fputs("\n* parameters:\n    ", stderr);
  vector_fprint(stderr, res.par);

  fputs("\n* covariance:\n", stderr);
  matrix_fprint(stderr, res.cov);

  // free memory
  gsl_multifit_nlinear_free(w);
  gsl_vector_free(par);

  return res;
}