#include "likelihood.h"
#include <gsl/gsl_errno.h>
#include <gsl/gsl_multimin.h>
#include <gsl/gsl_linalg.h>

/* `f_logL(par, &sample)`
 * evaluates the negative log-likelihood of the sample
 * `sample` at the parameters `par`.
 *
 * The value is the sum over all the events of
 * log|distr(par, event)|, multiplied by 1e-5 and
 * with the sign flipped. `sample_` must point to
 * a `struct sample`.
 */
double f_logL(const gsl_vector *par, void *sample_) {
  struct sample *data = sample_;

  double total = 0;
  size_t k = 0;
  while (k < data->size) {
    double density = fabs(distr(par, &data->events[k]));
    total += log(density);
    k++;
  }

  /* The 1e-5 factor rescales the function to avoid
   * round-off errors during minimisation.
   */
  return -(total * 1e-5);
}


/* `fdf_logL(par, &sample, fun, grad)`
 * evaluates, in a single pass over the events, both the
 * negative log-likelihood of the sample `sample` and its
 * gradient with respect to the parameters `par`.
 *
 * The value is written to `*fun` and the gradient to `grad`;
 * whatever they held before the call is overwritten.
 * Both results carry the same 1e-5 rescaling factor.
 */
void fdf_logL(const gsl_vector *par, void *sample_,
              double *fun, gsl_vector *grad) {
  struct sample *data = sample_;

  /* Scratch vector holding the contribution of one event */
  gsl_vector *scratch = gsl_vector_alloc(3);

  /* The output gradient may contain garbage: start from 0 */
  gsl_vector_set_zero(grad);
  double log_sum = 0;

  for (size_t k = 0; k < data->size; k++) {
    struct event *ev = &data->events[k];
    double density = distr(par, ev);

    /* ∇(-log F) = -1/F ∇F, accumulated event by event */
    grad_distr(par, ev, scratch);
    gsl_vector_scale(scratch, -1.0/density);
    gsl_vector_add(grad, scratch);

    /* log-likelihood contribution of this event */
    log_sum += log(fabs(density));
  }

  /* Flip the sign of the sum and rescale both outputs
   * to avoid round-off errors during minimisation.
   */
  *fun = -log_sum*1e-5;
  gsl_vector_scale(grad, 1e-5);

  gsl_vector_free(scratch);
}


/* `df_logL(par, &sample, grad)`
 * computes only the gradient of the negative
 * log-likelihood for the sample `sample` and writes
 * it to the `grad` gsl_vector.
 *
 * It delegates to `fdf_logL` and throws away the
 * function value computed along the way.
 */
void df_logL(const gsl_vector *par, void *sample, gsl_vector *grad) {
  double discarded;
  fdf_logL(par, sample, &discarded, grad);
}


/* `hf_logL(par, &sample, hess)`
 * gives the hessian matrix of the negative log-likelihood
 * for the sample `sample`. The result is stored in
 * the `hess` gsl_matrix.
 */
gsl_matrix* hf_logL(const gsl_vector *par, void *sample_) {
  struct sample sample = *((struct sample*) sample_);

  gsl_vector *grad = gsl_vector_calloc(3);
  gsl_matrix *term = gsl_matrix_calloc(3, 3);
  gsl_matrix *hess = gsl_matrix_calloc(3, 3);

  double prob;
  struct event *event;
  for (size_t n = 0; n < sample.size; n++) {
    /* Compute gradient and value of F */
    event = &sample.events[n];
    prob = distr(par, event);
    grad_distr(par, event, grad);

    /* Compute the hessian matrix of -log(F):
     * H_ij = 1/F² ∇F_i ∇F_j
     */
    for (size_t i = 0; i < 3; i++)
    for (size_t j = 0; j < 3; j++) {
      gsl_matrix_set(
        term, i, j,
        gsl_vector_get(grad, i) * gsl_vector_get(grad, j));
    }
    gsl_matrix_scale(term, 1.0/pow(prob, 2));
    gsl_matrix_add(hess, term);
  }

  // free memory
  gsl_vector_free(grad);
  gsl_matrix_free(term);

  return hess;
}


/* Maximum likelihood estimation of the parameters
 * α,β,γ.
 */
min_result maxLikelihood(struct sample *sample) {
  /* Starting point: α,β,γ = "something close
   * to the solution", because the minimisation
   * seems pretty unstable.
   */
  gsl_vector *par = gsl_vector_alloc(3);
  gsl_vector_set(par, 0, 0.79);
  gsl_vector_set(par, 1, 0.02);
  gsl_vector_set(par, 2, -0.17);

  /* Initialise the minimisation */
  gsl_multimin_function_fdf likelihood;
  likelihood.n      = 3;
  likelihood.f      = f_logL;
  likelihood.df     = df_logL;
  likelihood.fdf    = fdf_logL;
  likelihood.params = sample;

  /* The minimisation technique used
   * is the conjugate gradient method. */
  const gsl_multimin_fdfminimizer_type *type =
    gsl_multimin_fdfminimizer_conjugate_fr;

  gsl_multimin_fdfminimizer *m = 
    gsl_multimin_fdfminimizer_alloc(type, 3);

  gsl_multimin_fdfminimizer_set(
    m,            // minimisation technique
    &likelihood,  // function to minimise
    par,          // initial parameters/starting point
    1e-5,         // first step length
    0.1);         // accuracy

  size_t iter;
  int status = GSL_CONTINUE;

  /* Iterate the minimisation until the gradient
   * is smaller that 10^-4 or fail.
   */
  fputs("\n# maximum likelihood\n\n", stderr);
  for (iter = 0; status == GSL_CONTINUE && iter < 100; iter++) {
    status = gsl_multimin_fdfminimizer_iterate(m);

    if (status)
      break;

    status = gsl_multimin_test_gradient(m->gradient, 1e-5);

    /* Print the iteration, current parameters
     * and gradient of the log-likelihood.
     */
    if (iter % 5 == 0 || status == GSL_SUCCESS) {
      fprintf(stderr, "%2ld\t", iter);
      vector_fprint(stderr, m->x);
      fputs("\t", stderr);
      vector_fprint(stderr, m->gradient);
      putc('\n', stderr);
    }
  }
  fprintf(stderr, "status: %s\n", gsl_strerror(status));
  fprintf(stderr, "iterations: %ld\n", iter);

  /* Store results in the min_result type.
   * Note: We allocate a new vector for the
   * parameters because `m->x` will be freed
   * along with m.
   */
  min_result res;
  res.par = gsl_vector_alloc(3);
  res.err = gsl_vector_alloc(3);
  res.cov = gsl_matrix_alloc(3, 3);
  gsl_vector_memcpy(res.par, m->x);

  /* Compute the standard errors of the estimates.
   * The Cramér–Rao bound states that the covariance
   * matrix of the estimates is
   * 
   *   Σ ≥ I⁻¹ = (-H)⁻¹
   *
   * where:
   *   - I is called the Fisher information,
   *   - H is the hessian of log(L) at
   *     the maximum likelihood.
   */
  res.cov = hf_logL(m->x, sample);

  /* Invert H by Cholesky decomposition:
   *
   * H = LL^T ⇒ H⁻¹ = L^T⁻¹L⁻¹
   *
   * Note: H is positive defined (because -logL is
   * minimised) and symmetric so this is valid.
   *
   */
  gsl_linalg_cholesky_decomp(res.cov);
  gsl_linalg_cholesky_invert(res.cov);

  /* Compute the standard errors
   * from the covariance matrix.
   */
  gsl_vector_const_view diagonal = gsl_matrix_const_diagonal(res.cov);
  gsl_vector_memcpy(res.err, &diagonal.vector);
  vector_map(&sqrt, res.err);

  fputs("\n## results\n", stderr);

  fputs("\n* parameters:\n    ", stderr);
  vector_fprint(stderr, res.par);

  fputs("\n* covariance (Cramér–Rao lower bound):\n", stderr);
  matrix_fprint(stderr, res.cov);

  // free memory
  gsl_multimin_fdfminimizer_free(m);

  return res;
}