/* This file contains functions to perform
 * statistical tests on the points sampled
 * from a Landau distribution.
 */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <gsl/gsl_min.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_roots.h>
#include <gsl/gsl_sum.h>

#include "landau.h"


/* Kolmogorov distribution CDF for sample size `n` and
 * Kolmogorov-Smirnov statistic `D`.
 *
 * With x = √n·D, the CDF is evaluated from the series
 *
 *    K(x) = √(2π)/x · Σ_{k≥1} exp(-(2k-1)²π²/(8x²))
 *
 * truncated to a fixed number of terms and accelerated with
 * the Levin u-transform. Before evaluating it, x is shifted by
 *
 *    x → x + 1/(6√n) + (x-1)/(4n)
 *
 * a correction that reduces the error of the formula at finite n.
 *
 * The accelerated sum, the plain sum and the error estimate are
 * printed to stderr as diagnostics.
 *
 * Returns 0 when the corrected x is not positive (the CDF vanishes
 * there and the series above is undefined), or NAN if the GSL
 * workspace cannot be allocated.
 */
double kolmogorov_cdf(double D, int n) {
  const double sqrt_n = sqrt(n);
  double x = sqrt_n * D;

  // correction to reduce the estimate error
  // (4.0 * n: avoid int overflow for very large samples)
  x += 1/(6 * sqrt_n) + (x - 1)/(4.0 * n);

  // K(x) → 0 as x → 0⁺; for x ≤ 0 the series would divide
  // by zero or yield a negative value
  if (x <= 0)
    return 0;

  // first N_TERMS terms of Σ_k=1 exp(-(2k - 1)²π²/8x²);
  // the count is fixed, so no heap allocation is needed
  enum { N_TERMS = 30 };
  double terms[N_TERMS];
  for (size_t k = 1; k <= N_TERMS; k++) {
    const double a = (2.0 * k - 1) * M_PI / x;
    terms[k - 1] = exp(-a * a / 8);
  }

  // do a transform to accelerate the convergence
  gsl_sum_levin_utrunc_workspace *w = gsl_sum_levin_utrunc_alloc(N_TERMS);
  if (w == NULL)
    return NAN;

  double sum, abserr;
  gsl_sum_levin_utrunc_accel(terms, N_TERMS, w, &sum, &abserr);

  fprintf(stderr, "\n## Kolmogorov CDF\n");
  fprintf(stderr, "accel sum: %g\n", sum);
  fprintf(stderr, "plain sum: %g\n", w->sum_plain);
  fprintf(stderr, "err: %g\n",   abserr);

  gsl_sum_levin_utrunc_free(w);

  return sqrt(2*M_PI)/x * sum;
}



/* A higher-order function (i.e. a function that operates on
 * functions) in disguise: given a function f it evaluates -f.
 * In lambda calculus it would be the map λf.λx.-f(x).
 *
 * Standard C has no lambdas, so the function to negate travels
 * through the usual `void *params` slot of gsl_function: `fp`
 * must point to a gsl_function `f`, and negate_func(x, fp) calls
 * `f.function` with `x` and `f.params` and returns the negated
 * result.
 *
 * So, given a `gsl_function f`, its negation is constructed as:
 *
 *    gsl_function nf;
 *    nf.function = &negate_func;
 *    nf.params = &f;
 *
 * `nf` stores the address of `f`, so `f` must stay alive as long
 * as `nf` is used.
 */
double negate_func(double x, void * fp) {
  const gsl_function *inner = (const gsl_function *) fp;
  const double value = inner->function(x, inner->params);
  return -value;
}


/* Numerically computes the mode of a PDF (here, of a Landau
 * distribution) by maximising the PDF itself with Brent's method.
 *
 * `min`, `max` are the initial search interval for the
 * optimisation. Brent's method needs an initial guess strictly
 * inside (min, max): 0 is used when it lies in the interval,
 * otherwise the midpoint.
 *
 * If `err` is true, the width of the final interval (the error
 * estimate) is printed to stderr.
 *
 * Returns the estimated mode, or NAN if the minimiser could not be
 * allocated or initialised (reachable only when GSL's error
 * handler does not abort).
 */
double numeric_mode(double min, double max,
                    gsl_function *pdf,
                    int err) {

  /* Negate the PDF to maximise it by
   * using a GSL minimisation method.
   * (There usually are no maximisation methods)
   */
  gsl_function npdf;
  npdf.function = &negate_func;
  npdf.params   = pdf;

  const int    max_iter = 100;
  const double prec = 1e-7;

  // initial guess: must lie strictly inside (min, max)
  double x = (min < 0 && 0 < max) ? 0 : (min + max) / 2;

  // initialize minimisation
  gsl_min_fminimizer *s = gsl_min_fminimizer_alloc(gsl_min_fminimizer_brent);
  if (s == NULL)
    return NAN;
  if (gsl_min_fminimizer_set(s, &npdf, x, min, max) != GSL_SUCCESS) {
    gsl_min_fminimizer_free(s);
    return NAN;
  }

  // minimisation; `status` must start as GSL_CONTINUE or the loop
  // condition would read an indeterminate value
  int status = GSL_CONTINUE;
  for (int iter = 0; status == GSL_CONTINUE && iter < max_iter; iter++) {
    status = gsl_min_fminimizer_iterate(s);
    if (status != GSL_SUCCESS)
      break;  // keep the last good estimate
    x      = gsl_min_fminimizer_x_minimum(s);
    min    = gsl_min_fminimizer_x_lower(s);
    max    = gsl_min_fminimizer_x_upper(s);
    status = gsl_min_test_interval(min, max, 0, prec);
  }

  /* The error is simply given by the width of
   * the final interval containing the solution
   */
  if (err)
    fprintf(stderr, "mode error: %.3g\n", max - min);

  // free memory
  gsl_min_fminimizer_free(s);
  return x;
}


/* Parameters passed (through the `void *params` slot of a
 * gsl_function) to `fwhm_equation`: the half-maximum of a PDF
 * and the PDF itself. Used by `numeric_fwhm` to compute the FWHM.
 */
struct fwhm_params {
  double halfmax;     // half of the PDF maximum, f(mode)/2
  gsl_function *pdf;  // the PDF f whose FWHM is sought
};


/* This is the implicit equation that is solved
 * in `numeric_fwhm` by a numerical root-finding
 * method. This function takes a point `x`, the
 * parameters defined in `fwhm_params` and returns
 * the value of:
 *
 *    f(x) - f(max)/2
 *
 * where f is the PDF of interest.
 */
double fwhm_equation(double x, void* params) {
  struct fwhm_params p = *((struct fwhm_params*) params);
  return p.pdf->function(x, p.pdf->params) - p.halfmax;
}


/* Numerically computes the FWHM of a PDF
 * distribution using the definition.
 * The `min,max` parameters are the initial search
 * interval for the root search of equation
 * `fwhm_equation`. Two searches are performed in
 * [min, mode] and [mode, max] that will give the
 * two solutions x₋, x₊. The FWHM is then x₊-x₋.
 *
 * If `err` is true print the estimate error.
 */
double numeric_fwhm(double min, double max,
                    gsl_function *pdf,
                    int err) {
  /* Create the gls_function structure
   * for the equation to be solved.
   */
  double mode = numeric_mode(min, max, pdf, 0);
  struct fwhm_params p = {
    pdf->function(mode, pdf->params)/2,
    pdf
  };
  gsl_function equation;
  equation.function = &fwhm_equation;
  equation.params   = &p;

  const gsl_root_fsolver_type *T = gsl_root_fsolver_brent;
  gsl_root_fsolver *s = gsl_root_fsolver_alloc(T);

  // initialize minimization for x₋
  double x, fmin, fmax;
  int    max_iter = 100;
  double prec = 1e-7;
  int    status;

  // minimization
  gsl_root_fsolver_set(s, &equation, min, mode);
  status = GSL_CONTINUE;
  for (int iter = 0; status == GSL_CONTINUE && iter < max_iter; iter++)
  { status = gsl_root_fsolver_iterate(s);
    x      = gsl_root_fsolver_root(s);
    fmin   = gsl_root_fsolver_x_lower(s);
    fmax   = gsl_root_fsolver_x_upper(s);
    status = gsl_min_test_interval(fmin, fmax, 0, prec);
  }
  double x_low = x;
  double err_low = fmax - fmin;

  // initialize minimization for x₊
  gsl_root_fsolver_set(s, &equation, mode, max);

  // minimization
  status = GSL_CONTINUE;
  for (int iter = 0; status == GSL_CONTINUE && iter < max_iter; iter++)
  { status = gsl_root_fsolver_iterate(s);
    x      = gsl_root_fsolver_root(s);
    fmin   = gsl_root_fsolver_x_lower(s);
    fmax   = gsl_root_fsolver_x_upper(s);
    status = gsl_min_test_interval(fmin, fmax, 0, prec);
  }
  double x_upp = x;
  double err_upp = fmax - fmin;

  if (err)
    fprintf(stderr, "fhwm error: %g\n",
            sqrt(err_low*err_low + err_upp*err_upp));

  // free memory
  gsl_root_fsolver_free(s);
  return x_upp - x_low;
}