#include <math.h>

#include <gsl/gsl_errno.h>

#include "common.h"
#include "rl.h"


/* Performs the Richardson-Lucy deconvolution.
 *
 * Model: `data` is the *full* linear convolution of an unknown original
 * signal of m bins with `kernel` (n bins), so data->n == m + n - 1 and the
 * returned estimate has m = data->n - kernel->n + 1 bins.
 *
 * In pseudo-python:
 *
 *   def rl_deconvolve(data, kernel, rounds):
 *     est = np.full(len(data) - len(kernel) + 1, 0.5)
 *     for _ in range(rounds):
 *       est_conv = np.convolve(est, kernel, mode="full")
 *       est *= np.correlate(data / est_conv, kernel, mode="valid")
 *
 * The correction is the exact adjoint of the full convolution, so when the
 * kernel sums to one an iteration leaves a perfect estimate
 * (data / est_conv == 1 everywhere) unchanged, including at the edges.
 *
 * With rounds == 0 every bin of the result is 0.5.
 *
 * Returns a newly allocated histogram that the caller must release with
 * gsl_histogram_free(). Returns NULL when the inputs are invalid (NULL
 * pointers, empty kernel, kernel longer than data) or an allocation fails;
 * the failure is reported through GSL's error handler.
 */
gsl_histogram* rl_deconvolve(
  gsl_histogram *data,
  gsl_histogram *kernel,
  size_t rounds) {
  if (data == NULL || kernel == NULL) {
    GSL_ERROR_NULL("data and kernel must not be NULL", GSL_EFAULT);
  }

  /* Guard the size arithmetic below: with kernel->n > data->n the
   * unsigned subtraction would wrap around to a huge value. */
  if (kernel->n == 0 || kernel->n > data->n) {
    GSL_ERROR_NULL("kernel must be non-empty and no longer than data",
                   GSL_EBADLEN);
  }

  /* Notation of sizes:
   *  - original (the result): m
   *  - kernel: n
   *  - "full" convolution (data): m + n - 1
   */
  const size_t orig_size = data->n - kernel->n + 1;

  /* The result spans the same [min, max] range as `data`, split into
   * orig_size uniform bins. Its bins are therefore wider than those of
   * `data` by a factor data->n / orig_size.
   * NOTE(review): if callers expect the bin width of `data` to be kept,
   * the range should be cropped instead -- confirm. */
  gsl_histogram *hist = gsl_histogram_calloc(orig_size);
  if (hist == NULL) {
    return NULL;  /* already reported by GSL's error handler */
  }
  if (gsl_histogram_set_ranges_uniform(hist, gsl_histogram_min(data),
                                       gsl_histogram_max(data))
      != GSL_SUCCESS) {
    gsl_histogram_free(hist);
    return NULL;
  }

  /* Work buffers (sizes in the notation above). */
  gsl_vector *vkernel_flip = gsl_vector_alloc(kernel->n);        /* n */
  gsl_vector *est_conv = gsl_vector_alloc(data->n);              /* m+n-1 */
  gsl_vector *rel_blur = gsl_vector_alloc(data->n);              /* m+n-1 */
  /* Full convolution of rel_blur with the reversed kernel: m+2n-2 bins. */
  gsl_vector *correction = gsl_vector_alloc(data->n + kernel->n - 1);

  if (vkernel_flip == NULL || est_conv == NULL
      || rel_blur == NULL || correction == NULL) {
    /* The failing GSL allocator has already reported the error. */
    gsl_histogram_free(hist);
    hist = NULL;
    goto cleanup;
  }

  /* Vector views that share storage with the histograms' bin arrays, so
   * the estimate is updated directly inside `hist`. */
  gsl_vector_view est = gsl_vector_view_array(hist->bin, hist->n);
  gsl_vector_view vkernel = gsl_vector_view_array(kernel->bin, kernel->n);
  gsl_vector_view vdata = gsl_vector_view_array(data->bin, data->n);

  /* Convolving with the reversed kernel is a correlation with `kernel`. */
  gsl_vector_memcpy(vkernel_flip, &vkernel.vector);
  gsl_vector_reverse(vkernel_flip);

  /* Zero-order estimate: every bin at 0.5. */
  gsl_vector_set_all(&est.vector, 0.5);

  for (size_t iter = 0; iter < rounds; iter++) {
    /* Forward model: full convolution of the estimate with the kernel. */
    gsl_vector_convolve(&est.vector, &vkernel.vector, est_conv);

    /* Relative blur: observed data divided by the predicted data. */
    gsl_vector_memcpy(rel_blur, &vdata.vector);
    gsl_vector_div(rel_blur, est_conv);

    /* Replace non-finite ratios (0/0, or x/0 where the prediction is zero)
     * with the previous ratio, or 1 for the first bin. Otherwise they would
     * poison the estimate, because even a zero kernel tap times inf or NaN
     * gives NaN. */
    for (size_t i = 0; i < rel_blur->size; i++) {
      if (!isfinite(gsl_vector_get(rel_blur, i))) {
        double y = (i > 0) ? gsl_vector_get(rel_blur, i - 1) : 1.0;
        gsl_vector_set(rel_blur, i, y);
      }
    }

    /* Correction: "valid" correlation of the relative blur with the kernel.
     * This is the full convolution with the reversed kernel, restricted to
     * the orig_size bins where every kernel tap overlaps rel_blur; those
     * bins start at offset kernel->n - 1. */
    gsl_vector_convolve(rel_blur, vkernel_flip, correction);
    gsl_vector_view valid =
      gsl_vector_subvector(correction, kernel->n - 1, orig_size);
    gsl_vector_mul(&est.vector, &valid.vector);
  }

cleanup:
  if (correction != NULL) {
    gsl_vector_free(correction);
  }
  if (rel_blur != NULL) {
    gsl_vector_free(rel_blur);
  }
  if (est_conv != NULL) {
    gsl_vector_free(est_conv);
  }
  if (vkernel_flip != NULL) {
    gsl_vector_free(vkernel_flip);
  }

  return hist;
}