#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "fisher.h"
#include "percep.h"

/* Run-time configuration of the program, filled from the command line. */
struct options {
  char *mode;          /* training method name: "fisher" or "percep" */
  size_t nsig, nnoise; /* number of points generated for signal / noise */
  int iter, trained;   /* perceptron iterations; nonzero if -w was given */
  double weights[3];   /* w1, w2 and the cut, read only when trained */
};


int main(int argc, char **argv) {
  /* Set default options */
  struct options opts;
  opts.mode    = "fisher";
  opts.nsig    =  800;
  opts.nnoise  = 1000;
  opts.iter    = 5;
  opts.trained = 0;

  /* Process CLI arguments */
  for (size_t i = 1; i < argc; i++) {
         if (!strcmp(argv[i], "-m")) opts.mode = argv[++i];
    else if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
    else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
    else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
    else if (!strcmp(argv[i], "-w")) {
      opts.trained = 1;
      for (int j = 0; j < 3; j++)
        opts.weights[j] = atof(argv[++i]);
    }
    else {
      fprintf(stderr, "Usage: %s -[hmisnw]\n", argv[0]);
      fprintf(stderr, "  -h\t\tShow this message.\n");
      fprintf(stderr, "  -m MODE\tThe training mode to use: 'fisher' for \n\t\t"
                      "Fisher linear discriminant, 'percep' for perceptron. (default: 'fisher')\n");
      fprintf(stderr, "  -i N\t\tThe number of training iterations "
                       "(for perceptron only). (default: 5)\n");
      fprintf(stderr, "  -s N\t\tThe number of events in signal class. (default: 800)\n");
      fprintf(stderr, "  -n N\t\tThe number of events in noise class. (default: 1000)\n");
      fprintf(stderr, "  -w W₁ W₂ B\tSet weights and bias (if pre-trained).\n");
      return EXIT_FAILURE;
    }
  }

  // initialize RNG
  gsl_rng_env_setup();
  gsl_rng *r = gsl_rng_alloc(gsl_rng_default);

  /* Generate two classes of normally
   * distributed 2D points with different
   * paramters: signal and noise.
   */
  struct par par_sig   = { 0, 0, 0.3, 0.3, 0.5 };
  struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };

  sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
  sample_t *noise  = generate_normal(r, opts.nnoise, &par_noise);

  gsl_vector *w;
  double cut;

  if (opts.trained) {
    /* Pre-trained
     *
     * Use the provided weigths and bias.
     */
    fputs("# Pre-trained \n\n", stderr);
    gsl_vector_view v = gsl_vector_view_array(opts.weights, 2);
    w = &v.vector;
    cut = opts.weights[2];
  }
  else if (!strcmp(opts.mode, "fisher")) {
    /* Fisher linear discriminant
     *
     * First calculate the direction w onto
     * which project the data points. Then the
     * cut which determines the class for each
     * projected point.
     */
    fputs("# Linear Fisher discriminant\n\n", stderr);
    double ratio = opts.nsig / (double)opts.nnoise;
    w = fisher_proj(signal, noise);
    cut = fisher_cut(ratio, w, signal, noise);
  }
  else if (!strcmp(opts.mode, "percep")) {
    /* Perceptron
     *
     * Train a single perceptron on the
     * dataset to get an approximate
     * solution in `iter` iterations.
     */
  //  fputs("# Perceptron \n\n", stderr);
    w = percep_train(signal, noise, opts.iter, &cut);
  }
  else {
    fputs("\n\nerror: invalid mode. select either"
          " 'fisher' or 'percep'\n", stderr);
    return EXIT_FAILURE;
  }

  /* Print the results of the method
   * selected: weights and threshold.
   */
  fprintf(stderr, "\n* i: %d\n", opts.iter);
  fprintf(stderr, "* w: [%.3f, %.3f]\n",
      gsl_vector_get(w, 0),
      gsl_vector_get(w, 1));
  fprintf(stderr, "* cut: %.3f\n", cut);
  gsl_vector_fprintf(stdout, w, "%g");
  printf("%f\n", cut);

  /* Print data to stdout for plotting.
   * Note: we print the sizes to be able
   * to set apart the two matrices.
   */
  printf("%ld %ld %d\n", opts.nsig, opts.nnoise, 2);
  gsl_matrix_fprintf(stdout, signal->data, "%g");
  gsl_matrix_fprintf(stdout, noise->data,  "%g");

  // free memory
  gsl_rng_free(r);
  sample_t_free(signal);
  sample_t_free(noise);

  return EXIT_SUCCESS;
}