/* analistica/ex-7/test.c
 *
 * Tests a linear binary classifier (two weights and a cut) on
 * randomly generated signal and noise samples.
 */

#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_rstat.h>
#include <gsl/gsl_vector.h>
#include "common.h"
/* Options for the program, filled in by `main` from the
 * command line (defaults are set there before parsing).
 */
struct options {
size_t nsig;        /* number of events generated for the signal class */
size_t nnoise;      /* number of events generated for the noise class */
int iter;           /* number of samples generated and tested */
int trained;        /* NOTE(review): never read or written in this file */
double weights[3];  /* [0], [1]: projection weights; [2]: the cut */
};
/* `classify(w, cut, x)` classifies a single point `x` using the
 * weight vector `w` and the cut `cut`: `x` is projected onto `w`
 * and the projection is compared with the cut.
 * Returns 1 (signal) when w·x - cut is negative, 0 (noise) otherwise.
 */
int classify(gsl_vector *w, double cut, gsl_vector *x) {
  double proj;
  gsl_blas_ddot(w, x, &proj);
  /* The C standard only promises that `signbit` returns *a* nonzero
   * value for negative arguments, not exactly 1; callers compare the
   * result against the integer labels 0 and 1, so normalise it. */
  return signbit(proj - cut) ? 1 : 0;
}
/* Returns how many rows of `data` are NOT assigned the label
 * `expected` (0 = noise, 1 = signal) by `classify`.
 * `weights` holds the two components of the projection vector
 * followed by the cut value.
 */
int misclassified(double *weights, int expected, gsl_matrix *data) {
  gsl_vector dir = gsl_vector_view_array(weights, 2).vector;
  const double threshold = weights[2];
  const size_t rows = data->size1;
  int wrong = 0;

  for (size_t row = 0; row != rows; ++row) {
    /* Vector view of the current row of the data matrix. */
    gsl_vector point = gsl_matrix_const_row(data, row).vector;
    wrong += classify(&dir, threshold, &point) != expected;
  }
  return wrong;
}
/* Prints the command-line usage to stderr.
 * Always returns EXIT_FAILURE so that `main` can simply
 * `return show_help(argv);`.
 */
int show_help(char **argv) {
  /* argv[0] may be NULL when the program is exec'd with argc == 0. */
  const char *prog = (argv != NULL && argv[0] != NULL) ? argv[0] : "test";

  fprintf(stderr, "Usage: %s [-hisn] WEIGHT1 WEIGHT2 CUT\n", prog);
  fprintf(stderr, " -h, --help\tShow this message.\n");
  fprintf(stderr, " -i N\t\tThe number of test iterations to run. (default: 500)\n");
  fprintf(stderr, " -s N\t\tThe number of events in signal class. (default: 800)\n");
  fprintf(stderr, " -n N\t\tThe number of events in noise class. (default: 1000)\n");
  fprintf(stderr, "\nRun tests classifying randomly generated data, "
                  "using the given WEIGHTs and CUT.\n");
  return EXIT_FAILURE;
}
int main(int argc, char **argv) {
/* Set default options */
struct options opts;
2020-04-07 23:38:24 +02:00
opts.iter = 500;
opts.nsig = 800;
opts.nnoise = 1000;
/* Process CLI arguments */
if (argc < 2) return show_help(argv);
for (size_t i = 1; i < argc; i++) {
if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
else if (!strcmp(argv[i], "-h")) return show_help(argv);
2020-05-28 14:59:11 +02:00
else if (!strcmp(argv[i], "--help")) return show_help(argv);
else {
for (int j = 0; j < 3; j++)
opts.weights[j] = atof(argv[i++]);
}
}
// initialise RNG
gsl_rng_env_setup();
gsl_rng *r = gsl_rng_alloc(gsl_rng_default);
// initialise running stats
gsl_rstat_workspace *false_pos = gsl_rstat_alloc();
gsl_rstat_workspace *false_neg = gsl_rstat_alloc();
/* Generate two classes of normally
* distributed 2D points with different
* paramters: signal and noise.
*/
struct par par_sig = { 0, 0, 0.3, 0.3, 0.5 };
struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };
/* Generate `iter` different samples, apply the
* binary classification with the given weights
* and measure purity and efficiency.
*/
for (int i = 0; i < opts.iter; i++) {
sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
/* Count false positive/negatives and add them to
* the running stats workspaces to calculate mean/stddev.
*/
2020-05-28 14:59:11 +02:00
gsl_rstat_add(misclassified(opts.weights, 0, noise->data), false_pos);
gsl_rstat_add(misclassified(opts.weights, 1, signal->data), false_neg);
// free memory
sample_t_free(signal);
sample_t_free(noise);
}
2020-05-28 14:59:11 +02:00
// print out results
puts("# results\n");
puts("## false negatives");
printf(
"- total noise: %ld\n"
"- mean: %.3g, stddev: %.2g\n"
"- min: %g, max: %g\n",
opts.nnoise,
gsl_rstat_mean(false_neg),
gsl_rstat_sd(false_neg),
gsl_rstat_min(false_neg),
gsl_rstat_max(false_neg));
puts("\n## false positives");
printf(
"- total signal: %ld\n"
"- mean: %.3g, stddev: %.2g\n"
"- min: %g, max: %g\n",
opts.nsig,
gsl_rstat_mean(false_pos),
gsl_rstat_sd(false_pos),
gsl_rstat_min(false_pos),
gsl_rstat_max(false_pos));
puts("\n## averages");
2020-05-28 14:59:11 +02:00
printf("- purity:\t%.3e\n", 1 - gsl_rstat_mean(false_pos)/opts.nnoise);
printf("- efficiency:\t%.3e\n", 1 - gsl_rstat_mean(false_neg)/opts.nsig);
// free memory
gsl_rng_free(r);
gsl_rstat_free(false_neg);
gsl_rstat_free(false_pos);
return EXIT_SUCCESS;
}