analistica/ex-7/main.c

133 lines
3.9 KiB
C
Raw Normal View History

2020-03-06 02:24:32 +01:00
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "fisher.h"
#include "percep.h"
2020-03-06 11:54:28 +01:00
/* Options for the program.
 *
 * Filled with defaults at the start of main() and then
 * overridden by the command-line arguments.
 */
struct options {
  const char *mode;   /* training mode name, set by -m (compared against
                       * "fisher" and "percep"); may point to a string
                       * literal, hence const */
  size_t nsig;        /* number of events in the signal class (-s) */
  size_t nnoise;      /* number of events in the noise class (-n) */
  int iter;           /* number of training iterations (-i) */
  int trained;        /* non-zero when weights were given with -w */
  double weights[3];  /* -w values: two weights followed by the bias;
                       * only meaningful when `trained` is set */
};
2020-03-06 11:54:28 +01:00
int main(int argc, char **argv) {
/* Set default options */
struct options opts;
2020-03-07 02:50:09 +01:00
opts.mode = "fisher";
opts.nsig = 800;
opts.nnoise = 1000;
opts.iter = 5;
opts.trained = 0;
2020-03-06 11:54:28 +01:00
/* Process CLI arguments */
for (size_t i = 1; i < argc; i++) {
if (!strcmp(argv[i], "-m")) opts.mode = argv[++i];
else if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
2020-03-07 18:47:55 +01:00
else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
2020-03-07 02:50:09 +01:00
else if (!strcmp(argv[i], "-w")) {
opts.trained = 1;
for (int j = 0; j < 3; j++)
opts.weights[j] = atof(argv[++i]);
}
2020-03-06 11:54:28 +01:00
else {
2020-03-07 18:47:55 +01:00
fprintf(stderr, "Usage: %s -[hmisnw]\n", argv[0]);
2020-03-07 02:50:09 +01:00
fprintf(stderr, " -h\t\tShow this message.\n");
fprintf(stderr, " -m MODE\tThe training mode to use: 'fisher' for \n\t\t"
"Fisher linear discriminant, 'percep' for perceptron. (default: 'fisher')\n");
2020-03-07 02:50:09 +01:00
fprintf(stderr, " -i N\t\tThe number of training iterations "
"(for perceptron only). (default: 5)\n");
fprintf(stderr, " -s N\t\tThe number of events in signal class. (default: 800)\n");
fprintf(stderr, " -n N\t\tThe number of events in noise class. (default: 1000)\n");
2020-03-07 02:50:09 +01:00
fprintf(stderr, " -w W₁ W₂ B\tSet weights and bias (if pre-trained).\n");
2020-03-06 11:54:28 +01:00
return EXIT_FAILURE;
}
2020-03-06 02:24:32 +01:00
}
// initialize RNG
gsl_rng_env_setup();
gsl_rng *r = gsl_rng_alloc(gsl_rng_default);
/* Generate two classes of normally
* distributed 2D points with different
* paramters: signal and noise.
*/
struct par par_sig = { 0, 0, 0.3, 0.3, 0.5 };
struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };
2020-03-06 11:54:28 +01:00
sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
2020-03-06 02:24:32 +01:00
2020-03-06 19:46:42 +01:00
gsl_vector *w;
2020-03-07 02:50:09 +01:00
double cut;
2020-03-06 19:46:42 +01:00
2020-03-07 02:50:09 +01:00
if (opts.trained) {
/* Pre-trained
*
* Use the provided weigths and bias.
*/
fputs("# Pre-trained \n\n", stderr);
gsl_vector_view v = gsl_vector_view_array(opts.weights, 2);
w = &v.vector;
cut = opts.weights[2];
}
else if (!strcmp(opts.mode, "fisher")) {
2020-03-06 11:54:28 +01:00
/* Fisher linear discriminant
*
* First calculate the direction w onto
* which project the data points. Then the
* cut which determines the class for each
* projected point.
*/
fputs("# Linear Fisher discriminant\n\n", stderr);
2020-03-06 19:46:42 +01:00
double ratio = opts.nsig / (double)opts.nnoise;
w = fisher_proj(signal, noise);
2020-03-07 02:50:09 +01:00
cut = fisher_cut(ratio, w, signal, noise);
2020-03-06 11:54:28 +01:00
}
2020-03-06 19:46:42 +01:00
else if (!strcmp(opts.mode, "percep")) {
/* Perceptron
*
* Train a single perceptron on the
* dataset to get an approximate
* solution in `iter` iterations.
*/
// fputs("# Perceptron \n\n", stderr);
2020-03-07 02:50:09 +01:00
w = percep_train(signal, noise, opts.iter, &cut);
2020-03-06 19:46:42 +01:00
}
else {
fputs("\n\nerror: invalid mode. select either"
" 'fisher' or 'percep'\n", stderr);
return EXIT_FAILURE;
}
/* Print the results of the method
* selected: weights and threshold.
*/
fprintf(stderr, "\n* i: %d\n", opts.iter);
2020-03-06 19:46:42 +01:00
fprintf(stderr, "* w: [%.3f, %.3f]\n",
gsl_vector_get(w, 0),
gsl_vector_get(w, 1));
2020-03-07 02:50:09 +01:00
fprintf(stderr, "* cut: %.3f\n", cut);
2020-05-28 14:59:11 +02:00
gsl_vector_fprintf(stdout, w, "%g");
printf("%f\n", cut);
2020-03-06 02:24:32 +01:00
/* Print data to stdout for plotting.
2020-03-06 11:54:28 +01:00
* Note: we print the sizes to be able
2020-03-06 02:24:32 +01:00
* to set apart the two matrices.
*/
2020-05-28 14:59:11 +02:00
printf("%ld %ld %d\n", opts.nsig, opts.nnoise, 2);
gsl_matrix_fprintf(stdout, signal->data, "%g");
gsl_matrix_fprintf(stdout, noise->data, "%g");
2020-03-06 02:24:32 +01:00
// free memory
gsl_rng_free(r);
sample_t_free(signal);
sample_t_free(noise);
return EXIT_SUCCESS;
}