#include "fft.h"
#include <string.h>
#include <gsl/gsl_fft_halfcomplex.h>

/* Accessors for the real and imaginary parts of the
 * i-th element of a complex vector `a` stored as
 * interleaved (real, imaginary) pairs: the real part
 * sits at index 2*stride*i and the imaginary part
 * right after it. Both expand to a plain array
 * subscript, so they can be read from or assigned to.
 */
#define REAL(a,stride,i) ((a)[2*(stride)*(i)])
#define IMAG(a,stride,i) ((a)[2*(stride)*(i)+1])


/* Forward DFT of a real-valued vector.
 *
 * The input `data` is never modified: the transform
 * runs on a private copy, since the GSL routine
 * overwrites the array it is given. The spectrum is
 * returned as a newly allocated complex vector with
 * as many elements as `data`; it is not freed here.
 */
gsl_vector_complex* vector_fft(gsl_vector *data) {
  const size_t n = data->size;

  /* Scratch copy the in-place transform may overwrite. */
  gsl_vector *work = gsl_vector_alloc(n);
  gsl_vector_memcpy(work, data);

  /* Auxiliary tables required by the real FFT routine. */
  gsl_fft_real_wavetable *wavetable = gsl_fft_real_wavetable_alloc(n);
  gsl_fft_real_workspace *workspace = gsl_fft_real_workspace_alloc(n);

  /* Transform `work` in place. */
  gsl_fft_real_transform(work->data, work->stride, n, wavetable, workspace);

  /* `work` now holds the spectrum in the half-complex
   * packed format, which stores only half of the
   * coefficients by exploiting
   *
   *   z_i = conj(z_(n-i))
   *
   * Expand it into a full complex vector with one
   * (real, imaginary) pair per coefficient.
   */
  gsl_vector_complex *spectrum = gsl_vector_complex_calloc(n);
  gsl_fft_halfcomplex_unpack(work->data, spectrum->data, spectrum->stride, n);

  /* Release everything except the returned spectrum. */
  gsl_fft_real_wavetable_free(wavetable);
  gsl_fft_real_workspace_free(workspace);
  gsl_vector_free(work);

  return spectrum;
}


/* Inverse function of gsl_fft_halfcomplex_unpack.
 * `gsl_fft_halfcomplex_pack(c, hc, stride, n)`
 * writes into `hc` the half-complex packed version
 * of the complex packed array `c`, which holds `n`
 * complex elements. The same `stride` is applied to
 * both arrays.
 *
 * Layout written to `hc` (every index below is
 * additionally multiplied by `stride`):
 *
 *   hc[0]      = Re c[0]
 *   hc[2i - 1] = Re c[i]    for 1 <= i < n - i
 *   hc[2i]     = Im c[i]    for 1 <= i < n - i
 *   hc[n - 1]  = Re c[n/2]  only when n is even
 *
 * The largest index written is (n - 1) * stride.
 * Only the first half of `c` is read, and the
 * imaginary parts of c[0] and (for even n) c[n/2]
 * are dropped, so the packing is lossless only if
 * `c` is the spectrum of a real signal
 * (z_i = conj(z_(n-i))).
 *
 * Always returns 0. With n == 0 nothing is read or
 * written.
 */
int gsl_fft_halfcomplex_pack(
  const gsl_complex_packed_array c,
  double hc[],
  const size_t stride, const size_t n) {

  /* An empty array has nothing to pack. The guard is
   * mandatory, not cosmetic: without it element 0 of
   * both arrays would be accessed, and the unsigned
   * `n - i` in the loop condition would wrap around to
   * a huge value, running the loop far out of bounds.
   */
  if (n == 0) {
    return 0;
  }

  /* The zero-frequency term is purely real. */
  hc[0] = REAL(c, stride, 0);

  /* Positive frequencies below n/2: store the real and
   * imaginary parts next to each other.
   */
  size_t i;
  for (i = 1; i < n - i; i++) {
    hc[(2 * i - 1) * stride] = REAL(c, stride, i);
    hc[2 * i * stride]       = IMAG(c, stride, i);
  }

  /* For even n the loop stops at i == n/2: that term
   * is purely real and takes the last slot.
   */
  if (i == n - i) {
    hc[(n - 1) * stride] = REAL(c, stride, i);
  }

  return 0;
}


/* `fft_deconvolve(data, kernel)` tries to deconvolve
 * `kernel` from `data` by factoring out the `kernel`
 * in the Fourier space.
 *
 * `data` is treated as a "full" convolution, i.e.
 * with the notation
 *  - original: m bins
 *  - kernel: n bins
 *  - "full" convolution: m + n - 1 bins
 * it must hold at least as many bins as `kernel`.
 *
 * Returns a newly allocated histogram with
 * m = data->n - kernel->n + 1 bins spread uniformly
 * over the same overall range as `data`; it is not
 * freed here. Returns NULL if an argument is NULL,
 * if `kernel` has no bins, if `data` has fewer bins
 * than `kernel`, or if an allocation fails.
 *
 * Note: the spectrum of `data` is divided element by
 * element by the spectrum of the kernel, so a kernel
 * whose transform has (near-)zero coefficients gives
 * non-finite or very noisy results.
 */
gsl_histogram* fft_deconvolve(
  gsl_histogram *data,
  gsl_histogram *kernel) {
  /* Everything released in `cleanup` starts as NULL so
   * the cleanup code is valid from any failure point.
   */
  gsl_histogram *hist = NULL;
  gsl_vector *vpadded = NULL;
  gsl_vector_complex *fdata = NULL;
  gsl_vector_complex *fkernel = NULL;
  gsl_fft_halfcomplex_wavetable *htable = NULL;
  gsl_fft_real_workspace *wspace = NULL;
  double *res = NULL;

  /* Reject inputs for which the size arithmetic below
   * is meaningless: `data->n - kernel->n + 1` is
   * unsigned and would wrap around to a huge value.
   */
  if (data == NULL || kernel == NULL)
    return NULL;
  if (kernel->n == 0 || data->n < kernel->n)
    return NULL;

  /* Size of the original data
   * before being convoluted (m).
   */
  size_t orig_size = data->n - kernel->n + 1;

  /* Index of the kernel bin taken as the origin.
   * The padded kernel is rotated by this amount, which
   * shifts the deconvolved signal by the same amount.
   * For an odd kernel this equals (n-1)/2.
   */
  size_t center = kernel->n / 2;

  /* Create vector views */
  gsl_vector vdata =
    gsl_vector_view_array(data->bin, data->n).vector;
  gsl_vector vkernel =
    gsl_vector_view_array(kernel->bin, kernel->n).vector;

  /* 0-pad the kernel to make it the
   * same size as the convoluted data,
   * which is m+n-1, with its origin
   * moved to index 0 (wrapping around).
   */
  vpadded = gsl_vector_calloc(data->n);
  if (vpadded == NULL)
    goto cleanup;

  /* Copy the origin and everything after it
   * (n - center bins) into the leftmost side
   * of the padded kernel vector.
   */
  {
    gsl_vector pad_left = gsl_vector_subvector(
      vpadded,                   // source
      0,                         // offset
      kernel->n - center).vector; // size
    gsl_vector ker_left = gsl_vector_subvector(
      &vkernel,
      center,
      kernel->n - center).vector;
    gsl_vector_memcpy(&pad_left, &ker_left);
  }

  /* Copy the bins before the origin (center bins)
   * into the rightmost side of the padded kernel
   * vector. A 1-bin kernel has none: skip it, since
   * an empty view at offset data->n is rejected by GSL.
   */
  if (center > 0) {
    gsl_vector pad_right = gsl_vector_subvector(
      vpadded,                   // source
      data->n - center,          // offset
      center).vector;            // size
    gsl_vector ker_right = gsl_vector_subvector(
      &vkernel,
      0,
      center).vector;
    gsl_vector_memcpy(&pad_right, &ker_right);
  }

  /* Compute the DFT of the data and
   * divide it by the DFT of the kernel.
   */
  fdata   = vector_fft(&vdata);
  fkernel = vector_fft(vpadded);
  if (fdata == NULL || fkernel == NULL)
    goto cleanup;
  gsl_vector_complex_div(fdata, fkernel);

  /* Pack the result in the half-complex
   * format before computing the inverse DFT
   */
  res = calloc(fdata->size, sizeof *res);
  if (res == NULL)
    goto cleanup;
  gsl_fft_halfcomplex_pack(fdata->data, res,
                           fdata->stride, fdata->size);

  /* Compute the inverse DFT of the result */
  htable = gsl_fft_halfcomplex_wavetable_alloc(data->n);
  wspace = gsl_fft_real_workspace_alloc(data->n);
  if (htable == NULL || wspace == NULL)
    goto cleanup;
  gsl_fft_halfcomplex_inverse(
    res,           // array of data to transform
    fdata->stride, // stride of the array
    fdata->size,   // number of elements
    htable,        // wavetable (complex)
    wspace);       // workspace

  /* Create a histogram spanning the same range
   * as `data`, but with the original size,
   * to return the cleaned result
   */
  hist = gsl_histogram_calloc(orig_size);
  if (hist == NULL)
    goto cleanup;

  /* Set uniform bin edges over the range of `data` */
  double max = gsl_histogram_max(data);
  double min = gsl_histogram_min(data);
  gsl_histogram_set_ranges_uniform(hist, min, max);

  /* Remove the extra size of the convolution
   * (where the overlap is partial) when copying
   * over the result to the histogram: the original
   * signal starts `center` bins into the result.
   */
  memcpy(
    hist->bin,
    res + center,
    orig_size * sizeof(double));

cleanup:
  /* free memory; the GSL objects are only released
   * when they were actually allocated.
   */
  if (vpadded != NULL)
    gsl_vector_free(vpadded);
  if (fdata != NULL)
    gsl_vector_complex_free(fdata);
  if (fkernel != NULL)
    gsl_vector_complex_free(fkernel);
  if (htable != NULL)
    gsl_fft_halfcomplex_wavetable_free(htable);
  if (wspace != NULL)
    gsl_fft_real_workspace_free(wspace);
  free(res);

  return hist;
}