152 lines
4.3 KiB
C
152 lines
4.3 KiB
C
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <gsl/gsl_blas.h>
#include <gsl/gsl_rstat.h>
#include <gsl/gsl_vector.h>

#include "common.h"
|
|
|
|
/* Run-time configuration of the test program,
 * filled in from defaults and the command line.
 */
struct options {
  size_t nsig;       /* events in the signal class (`-s`)          */
  size_t nnoise;     /* events in the noise class (`-n`)           */
  int iter;          /* number of test iterations to run (`-i`)    */
  int trained;       /* NOTE(review): not referenced in this file;
                      * purpose unclear -- confirm before removing */
  double weights[3]; /* [0],[1]: projection weights; [2]: the cut  */
};
|
|
|
|
|
|
/* `classify(w, t, x)` classifies a single point
|
|
* `x` using the weight vector `w` and cut `t`.
|
|
* Returns 0 for noise and 1 for signal.
|
|
*/
|
|
int classify(gsl_vector *w, double cut, gsl_vector *x) {
|
|
double proj; gsl_blas_ddot(w, x, &proj);
|
|
return signbit(proj - cut);
|
|
}
|
|
|
|
|
|
/* Counts the misclassified data given the weights+cut
|
|
* values and the expected outcome using the `classify`
|
|
* function, defined above.
|
|
*/
|
|
int misclassified(double *weights, int expected, gsl_matrix *data) {
|
|
int count = 0;
|
|
double cut = weights[2];
|
|
gsl_vector w = gsl_vector_view_array(weights, 2).vector;
|
|
|
|
for (size_t i = 0; i < data->size1; i++) {
|
|
/* Get a vector view of the
|
|
* current row in the data matrix.
|
|
*/
|
|
gsl_vector x = gsl_matrix_const_row(data, i).vector;
|
|
if (classify(&w, cut, &x) != expected)
|
|
count++;
|
|
}
|
|
return count;
|
|
}
|
|
|
|
|
|
/* Prints the usage message to stderr.
 *
 * Only `argv[0]` is read (as the program name); a placeholder
 * is used when it is missing, which can happen if the program
 * is started with an empty argument vector.
 *
 * Always returns EXIT_FAILURE, so callers can simply
 * `return show_help(argv);`.
 */
int show_help(char **argv) {
  const char *prog = (argv != NULL && argv[0] != NULL) ? argv[0] : "<program>";

  /* Only the flags that are actually parsed are advertised. */
  fprintf(stderr, "Usage: %s -[hisn] WEIGHT [WEIGHT..] CUT\n", prog);
  fprintf(stderr, " -h\t\tShow this message.\n");
  fprintf(stderr, " -i N\t\tThe number of test iterations to run. (default: 500)\n");
  fprintf(stderr, " -s N\t\tThe number of events in signal class. (default: 800)\n");
  fprintf(stderr, " -n N\t\tThe number of events in noise class. (default: 1000)\n");
  fprintf(stderr, "\nRun tests classifying randomly generated data, "
                  "using the given WEIGHTs and CUTs.\n");
  return EXIT_FAILURE;
}
|
|
|
|
|
|
int main(int argc, char **argv) {
|
|
/* Set default options */
|
|
struct options opts;
|
|
opts.iter = 500;
|
|
opts.nsig = 800;
|
|
opts.nnoise = 1000;
|
|
|
|
/* Process CLI arguments */
|
|
if (argc < 2) return show_help(argv);
|
|
for (size_t i = 1; i < argc; i++) {
|
|
if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
|
|
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
|
|
else if (!strcmp(argv[i], "-i")) opts.iter = atoi(argv[++i]);
|
|
else if (!strcmp(argv[i], "-h")) return show_help(argv);
|
|
else if (!strcmp(argv[i], "--help")) return show_help(argv);
|
|
else {
|
|
for (int j = 0; j < 3; j++)
|
|
opts.weights[j] = atof(argv[i++]);
|
|
}
|
|
}
|
|
|
|
// initialise RNG
|
|
gsl_rng_env_setup();
|
|
gsl_rng *r = gsl_rng_alloc(gsl_rng_default);
|
|
|
|
// initialise running stats
|
|
gsl_rstat_workspace *false_pos = gsl_rstat_alloc();
|
|
gsl_rstat_workspace *false_neg = gsl_rstat_alloc();
|
|
|
|
/* Generate two classes of normally
|
|
* distributed 2D points with different
|
|
* paramters: signal and noise.
|
|
*/
|
|
struct par par_sig = { 0, 0, 0.3, 0.3, 0.5 };
|
|
struct par par_noise = { 4, 4, 1.0, 1.0, 0.4 };
|
|
|
|
/* Generate `iter` different samples, apply the
|
|
* binary classification with the given weights
|
|
* and measure purity and efficiency.
|
|
*/
|
|
for (int i = 0; i < opts.iter; i++) {
|
|
sample_t *signal = generate_normal(r, opts.nsig, &par_sig);
|
|
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
|
|
|
|
/* Count false positive/negatives and add them to
|
|
* the running stats workspaces to calculate mean/stddev.
|
|
*/
|
|
gsl_rstat_add(misclassified(opts.weights, 0, noise->data), false_pos);
|
|
gsl_rstat_add(misclassified(opts.weights, 1, signal->data), false_neg);
|
|
|
|
// free memory
|
|
sample_t_free(signal);
|
|
sample_t_free(noise);
|
|
}
|
|
|
|
// print out results
|
|
puts("# results\n");
|
|
puts("## false negatives");
|
|
printf(
|
|
"- total noise: %ld\n"
|
|
"- mean: %.3g, stddev: %.2g\n"
|
|
"- min: %g, max: %g\n",
|
|
opts.nnoise,
|
|
gsl_rstat_mean(false_neg),
|
|
gsl_rstat_sd(false_neg),
|
|
gsl_rstat_min(false_neg),
|
|
gsl_rstat_max(false_neg));
|
|
|
|
puts("\n## false positives");
|
|
printf(
|
|
"- total signal: %ld\n"
|
|
"- mean: %.3g, stddev: %.2g\n"
|
|
"- min: %g, max: %g\n",
|
|
opts.nsig,
|
|
gsl_rstat_mean(false_pos),
|
|
gsl_rstat_sd(false_pos),
|
|
gsl_rstat_min(false_pos),
|
|
gsl_rstat_max(false_pos));
|
|
|
|
puts("\n## averages");
|
|
printf("- purity:\t%.3e\n", 1 - gsl_rstat_mean(false_pos)/opts.nnoise);
|
|
printf("- efficiency:\t%.3e\n", 1 - gsl_rstat_mean(false_neg)/opts.nsig);
|
|
|
|
// free memory
|
|
gsl_rng_free(r);
|
|
gsl_rstat_free(false_neg);
|
|
gsl_rstat_free(false_pos);
|
|
|
|
return EXIT_SUCCESS;
|
|
}
|