ex-7: add test program to measure purity/efficiency
This commit is contained in:
parent
1bcbf1cfe6
commit
6d4e263c34
149
ex-7/test.c
Normal file
149
ex-7/test.c
Normal file
@ -0,0 +1,149 @@
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include <gsl/gsl_blas.h>
|
||||
#include <gsl/gsl_vector.h>
|
||||
#include <gsl/gsl_rstat.h>
|
||||
#include "common.h"
|
||||
|
||||
/* Options for the program */
|
||||
struct options {
|
||||
size_t nsig;
|
||||
size_t nnoise;
|
||||
int iter;
|
||||
int trained;
|
||||
double weights[3];
|
||||
};
|
||||
|
||||
|
||||
/* `classify(w, t, x)` classifies a single point
|
||||
* `x` using the weight vector `w` and cut `t`.
|
||||
* Returns 0 for noise and 1 for signal.
|
||||
*/
|
||||
int classify(gsl_vector *w, double cut, gsl_vector *x) {
|
||||
double proj; gsl_blas_ddot(w, x, &proj);
|
||||
return signbit(proj - cut);
|
||||
}
|
||||
|
||||
|
||||
/* Counts the misclassified data given the weights+cut
|
||||
* values and the expected outcome using the `classify`
|
||||
* function, defined above.
|
||||
*/
|
||||
int misclassified(double *weights, int expected, gsl_matrix *data) {
|
||||
int count = 0;
|
||||
double cut = weights[2];
|
||||
gsl_vector w = gsl_vector_view_array(weights, 2).vector;
|
||||
|
||||
for (size_t i = 0; i < data->size1; i++) {
|
||||
/* Get a vector view of the
|
||||
* current row in the data matrix.
|
||||
*/
|
||||
gsl_vector x = gsl_matrix_const_row(data, i).vector;
|
||||
if (classify(&w, cut, &x) != expected)
|
||||
count++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
int show_help(char **argv) {
|
||||
fprintf(stderr, "Usage: %s -[hisnw] WEIGHT [WEIGHT..] CUT\n", argv[0]);
|
||||
fprintf(stderr, " -h\t\tShow this message.\n");
|
||||
fprintf(stderr, " -i N\t\tThe number of test iterations to run.\n");
|
||||
fprintf(stderr, " -s N\t\tThe number of events in signal class.\n");
|
||||
fprintf(stderr, " -n N\t\tThe number of events in noise class.\n");
|
||||
fprintf(stderr, "\nRun tests classifying randomly generated data, "
|
||||
"using the given WEIGHT and CUTs.\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
/* Set default options */
|
||||
struct options opts;
|
||||
opts.nsig = 800;
|
||||
opts.nnoise = 1000;
|
||||
opts.iter = 500;
|
||||
|
||||
/* Process CLI arguments */
|
||||
if (argc < 2) return show_help(argv);
|
||||
for (size_t i = 1; i < argc; i++) {
|
||||
if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
|
||||
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
|
||||
else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
|
||||
else if (!strcmp(argv[i], "-h")) return show_help(argv);
|
||||
else {
|
||||
for (int j = 0; j < 3; j++)
|
||||
opts.weights[j] = atof(argv[i++]);
|
||||
}
|
||||
}
|
||||
|
||||
// initialise RNG
|
||||
gsl_rng_env_setup();
|
||||
gsl_rng *r = gsl_rng_alloc(gsl_rng_default);
|
||||
|
||||
// initialise running stats
|
||||
gsl_rstat_workspace *false_pos = gsl_rstat_alloc();
|
||||
gsl_rstat_workspace *false_neg = gsl_rstat_alloc();
|
||||
|
||||
/* Generate two classes of normally
|
||||
* distributed 2D points with different
|
||||
* paramters: signal and noise.
|
||||
*/
|
||||
struct par par_sig = { 0, 0, 0.3, 0.3, 0.5 };
|
||||
struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };
|
||||
|
||||
/* Generate `iter` different samples, apply the
|
||||
* binary classification with the given weights
|
||||
* and measure purity and efficiency.
|
||||
*/
|
||||
for (int i = 0; i < opts.iter; i++) {
|
||||
sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
|
||||
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
|
||||
|
||||
/* Count false positive/negatives and add them to
|
||||
* the running stats workspaces to calculate mean/stddev.
|
||||
*/
|
||||
gsl_rstat_add(misclassified(opts.weights, 0, noise->data), false_neg);
|
||||
gsl_rstat_add(misclassified(opts.weights, 1, signal->data), false_pos);
|
||||
|
||||
// free memory
|
||||
sample_t_free(signal);
|
||||
sample_t_free(noise);
|
||||
}
|
||||
|
||||
puts("# results\n");
|
||||
puts("## false negatives");
|
||||
printf(
|
||||
"- total noise: %ld\n"
|
||||
"- mean: %.3g, stddev: %.2g\n"
|
||||
"- min: %g, max: %g\n",
|
||||
opts.nnoise,
|
||||
gsl_rstat_mean(false_neg),
|
||||
gsl_rstat_sd(false_neg),
|
||||
gsl_rstat_min(false_neg),
|
||||
gsl_rstat_max(false_neg));
|
||||
|
||||
puts("\n## false positives");
|
||||
printf(
|
||||
"- total signal: %ld\n"
|
||||
"- mean: %.3g, stddev: %.2g\n"
|
||||
"- min: %g, max: %g\n",
|
||||
opts.nsig,
|
||||
gsl_rstat_mean(false_pos),
|
||||
gsl_rstat_sd(false_pos),
|
||||
gsl_rstat_min(false_pos),
|
||||
gsl_rstat_max(false_pos));
|
||||
|
||||
puts("\n## averages");
|
||||
printf("- purity:\t%.3e\n", 1 - gsl_rstat_mean(false_neg)/opts.nsig);
|
||||
printf("- efficiency:\t%.3e\n", 1 - gsl_rstat_mean(false_pos)/opts.nsig);
|
||||
|
||||
// free memory
|
||||
gsl_rng_free(r);
|
||||
gsl_rstat_free(false_neg);
|
||||
gsl_rstat_free(false_pos);
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
Loading…
Reference in New Issue
Block a user