#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_rstat.h>
#include "common.h"

/* Run-time configuration of the program; `main` fills it in
 * from its defaults and the command-line arguments.
 */
struct options {
  size_t nsig, nnoise; /* events generated per sample: signal, noise   */
  int iter;            /* how many independent samples to test         */
  int trained;         /* NOTE(review): never read in this file; remove? */
  double weights[3];   /* [0], [1]: projection vector; [2]: the cut    */
};


/* `classify(w, t, x)` classifies a single point
 * `x` using the weight vector `w` and cut `t`.
 * Returns 0 for noise and 1 for signal.
 */
int classify(gsl_vector *w, double cut, gsl_vector *x) {
  double proj; gsl_blas_ddot(w, x, &proj);
  return signbit(proj - cut);
}


/* `misclassified(weights, expected, data)` returns how many rows of
 * `data` are NOT given the label `expected` by `classify`.
 * `weights` holds three numbers: the two components of the
 * projection vector followed by the cut value. Each row of `data`
 * is presumably a point of the same length as the projection
 * vector (2) — gsl_blas_ddot requires matching lengths.
 */
int misclassified(double *weights, int expected, gsl_matrix *data) {
  gsl_vector direction = gsl_vector_view_array(weights, 2).vector;
  const double threshold = weights[2];
  int wrong = 0;

  size_t row = 0;
  while (row < data->size1) {
    /* Copy of the view struct only; it still points into `data`. */
    gsl_vector point = gsl_matrix_const_row(data, row).vector;
    wrong += (classify(&direction, threshold, &point) != expected);
    row++;
  }
  return wrong;
}


/* Prints the command-line usage to stderr.
 * Always returns EXIT_FAILURE, so callers can `return show_help(argv);`.
 * The program takes exactly three positional numbers: the two
 * components of the weight vector followed by the cut.
 */
int show_help(char **argv) {
  /* argv[0] may be NULL when the program is exec'd with argc == 0. */
  const char *prog = (argv != NULL && argv[0] != NULL) ? argv[0] : "PROGRAM";
  fprintf(stderr, "Usage: %s [-hisn] WEIGHT WEIGHT CUT\n", prog);
  fprintf(stderr, "  -h\t\tShow this message.\n");
  fprintf(stderr, "  -i N\t\tThe number of test iterations to run.\n");
  fprintf(stderr, "  -s N\t\tThe number of events in signal class.\n");
  fprintf(stderr, "  -n N\t\tThe number of events in noise class.\n");
  fprintf(stderr, "\nRun tests classifying randomly generated data, "
                  "using the given WEIGHTs and CUT.\n");
  return EXIT_FAILURE;
}


int main(int argc, char **argv) {
  /* Set default options */
  struct options opts;
  opts.iter    = 500;
  opts.nsig    =  800;
  opts.nnoise  = 1000;

  /* Process CLI arguments */
  if (argc < 2) return show_help(argv);
  for (size_t i = 1; i < argc; i++) {
         if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
    else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
    else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
    else if (!strcmp(argv[i], "-h")) return show_help(argv);
    else {
      for (int j = 0; j < 3; j++)
        opts.weights[j] = atof(argv[i++]);
    }
  }

  // initialise RNG
  gsl_rng_env_setup();
  gsl_rng *r = gsl_rng_alloc(gsl_rng_default);

  // initialise running stats
  gsl_rstat_workspace *false_pos = gsl_rstat_alloc();
  gsl_rstat_workspace *false_neg = gsl_rstat_alloc();

  /* Generate two classes of normally
   * distributed 2D points with different
   * paramters: signal and noise.
   */
  struct par par_sig   = { 0, 0, 0.3, 0.3, 0.5 };
  struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };

  /* Generate `iter` different samples, apply the
   * binary classification with the given weights
   * and measure purity and efficiency.
   */
  for (int i = 0; i < opts.iter; i++) {
    sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
    sample_t *noise  = generate_normal(r, opts.nnoise, &par_noise);

    /* Count false positive/negatives and add them to
     * the running stats workspaces to calculate mean/stddev.
     */
    gsl_rstat_add(misclassified(opts.weights, 0,  noise->data), false_neg);
    gsl_rstat_add(misclassified(opts.weights, 1, signal->data), false_pos);
  
    // free memory
    sample_t_free(signal);
    sample_t_free(noise);
  }

  puts("# results\n");
  puts("## false negatives");
  printf(
    "- total noise: %ld\n"
    "- mean: %.3g, stddev: %.2g\n"
    "- min: %g, max: %g\n",
    opts.nnoise,
    gsl_rstat_mean(false_neg),
    gsl_rstat_sd(false_neg),
    gsl_rstat_min(false_neg),
    gsl_rstat_max(false_neg));

  puts("\n## false positives");
  printf(
    "- total signal: %ld\n"
    "- mean: %.3g, stddev: %.2g\n"
    "- min: %g, max: %g\n",
    opts.nsig,
    gsl_rstat_mean(false_pos),
    gsl_rstat_sd(false_pos),
    gsl_rstat_min(false_pos),
    gsl_rstat_max(false_pos));

  puts("\n## averages");
  printf("- purity:\t%.3e\n", 1 - gsl_rstat_mean(false_neg)/opts.nsig);
  printf("- efficiency:\t%.3e\n", 1 - gsl_rstat_mean(false_pos)/opts.nsig);

  // free memory
  gsl_rng_free(r);
  gsl_rstat_free(false_neg);
  gsl_rstat_free(false_pos);

  return EXIT_SUCCESS;
}