diff --git a/ex-7/test.c b/ex-7/test.c new file mode 100644 index 0000000..c325d3d --- /dev/null +++ b/ex-7/test.c @@ -0,0 +1,149 @@ +#include +#include +#include +#include +#include +#include +#include "common.h" + +/* Options for the program */ +struct options { + size_t nsig; + size_t nnoise; + int iter; + int trained; + double weights[3]; +}; + + +/* `classify(w, t, x)` classifies a single point + * `x` using the weight vector `w` and cut `t`. + * Returns 0 for noise and 1 for signal. + */ +int classify(gsl_vector *w, double cut, gsl_vector *x) { + double proj; gsl_blas_ddot(w, x, &proj); + return signbit(proj - cut); +} + + +/* Counts the misclassified data given the weights+cut + * values and the expected outcome using the `classify` + * function, defined above. + */ +int misclassified(double *weights, int expected, gsl_matrix *data) { + int count = 0; + double cut = weights[2]; + gsl_vector w = gsl_vector_view_array(weights, 2).vector; + + for (size_t i = 0; i < data->size1; i++) { + /* Get a vector view of the + * current row in the data matrix. + */ + gsl_vector x = gsl_matrix_const_row(data, i).vector; + if (classify(&w, cut, &x) != expected) + count++; + } + return count; +} + + +int show_help(char **argv) { + fprintf(stderr, "Usage: %s -[hisnw] WEIGHT [WEIGHT..] CUT\n", argv[0]); + fprintf(stderr, " -h\t\tShow this message.\n"); + fprintf(stderr, " -i N\t\tThe number of test iterations to run.\n"); + fprintf(stderr, " -s N\t\tThe number of events in signal class.\n"); + fprintf(stderr, " -n N\t\tThe number of events in noise class.\n"); + fprintf(stderr, "\nRun tests classifying randomly generated data, " + "using the given WEIGHT and CUTs.\n"); + return EXIT_FAILURE; +} + + +int main(int argc, char **argv) { + /* Set default options */ + struct options opts; + opts.nsig = 800; + opts.nnoise = 1000; + opts.iter = 500; + + /* Process CLI arguments */ + if (argc < 2) return show_help(argv); + for (size_t i = 1; i < argc; i++) { + if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]); + else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]); + else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]); + else if (!strcmp(argv[i], "-h")) return show_help(argv); + else { + for (int j = 0; j < 3; j++) + opts.weights[j] = atof(argv[i++]); + } + } + + // initialise RNG + gsl_rng_env_setup(); + gsl_rng *r = gsl_rng_alloc(gsl_rng_default); + + // initialise running stats + gsl_rstat_workspace *false_pos = gsl_rstat_alloc(); + gsl_rstat_workspace *false_neg = gsl_rstat_alloc(); + + /* Generate two classes of normally + * distributed 2D points with different + * paramters: signal and noise. + */ + struct par par_sig = { 0, 0, 0.3, 0.3, 0.5 }; + struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 }; + + /* Generate `iter` different samples, apply the + * binary classification with the given weights + * and measure purity and efficiency. + */ + for (int i = 0; i < opts.iter; i++) { + sample_t *signal = generate_normal(r, opts.nsig, &par_sig); + sample_t *noise = generate_normal(r, opts.nnoise, &par_noise); + + /* Count false positive/negatives and add them to + * the running stats workspaces to calculate mean/stddev. + */ + gsl_rstat_add(misclassified(opts.weights, 0, noise->data), false_neg); + gsl_rstat_add(misclassified(opts.weights, 1, signal->data), false_pos); + + // free memory + sample_t_free(signal); + sample_t_free(noise); + } + + puts("# results\n"); + puts("## false negatives"); + printf( + "- total noise: %ld\n" + "- mean: %.3g, stddev: %.2g\n" + "- min: %g, max: %g\n", + opts.nnoise, + gsl_rstat_mean(false_neg), + gsl_rstat_sd(false_neg), + gsl_rstat_min(false_neg), + gsl_rstat_max(false_neg)); + + puts("\n## false positives"); + printf( + "- total signal: %ld\n" + "- mean: %.3g, stddev: %.2g\n" + "- min: %g, max: %g\n", + opts.nsig, + gsl_rstat_mean(false_pos), + gsl_rstat_sd(false_pos), + gsl_rstat_min(false_pos), + gsl_rstat_max(false_pos)); + + puts("\n## averages"); + printf("- purity:\t%.3e\n", 1 - gsl_rstat_mean(false_neg)/opts.nsig); + printf("- efficiency:\t%.3e\n", 1 - gsl_rstat_mean(false_pos)/opts.nsig); + + // free memory + gsl_rng_free(r); + gsl_rstat_free(false_neg); + gsl_rstat_free(false_pos); + + return EXIT_SUCCESS; +} diff --git a/makefile b/makefile index c0e3626..0cbd4ef 100644 --- a/makefile +++ b/makefile @@ -26,6 +26,8 @@ ex-6/bin/main: ex-6/rl.c ex-6/fft.c ex-7/bin/main: ex-7/main.c ex-7/common.c ex-7/fisher.c ex-7/percep.c $(CCOMPILE) +ex-7/bin/test: ex-7/test.c ex-7/common.c + $(CCOMPILE) misc/pdfs: misc/pdfs.c