ex-7: add pre-trained mode
This commit is contained in:
parent
313363c707
commit
7daefc590a
50
ex-7/main.c
50
ex-7/main.c
@ -9,16 +9,19 @@ struct options {
|
|||||||
size_t nsig;
|
size_t nsig;
|
||||||
size_t nnoise;
|
size_t nnoise;
|
||||||
int iter;
|
int iter;
|
||||||
|
int trained;
|
||||||
|
double weights[3];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
/* Set default options */
|
/* Set default options */
|
||||||
struct options opts;
|
struct options opts;
|
||||||
opts.mode = "fisher";
|
opts.mode = "fisher";
|
||||||
opts.nsig = 800;
|
opts.nsig = 800;
|
||||||
opts.nnoise = 1000;
|
opts.nnoise = 1000;
|
||||||
opts.iter = 5;
|
opts.iter = 5;
|
||||||
|
opts.trained = 0;
|
||||||
|
|
||||||
/* Process CLI arguments */
|
/* Process CLI arguments */
|
||||||
for (size_t i = 1; i < argc; i++) {
|
for (size_t i = 1; i < argc; i++) {
|
||||||
@ -26,13 +29,20 @@ int main(int argc, char **argv) {
|
|||||||
else if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
|
else if (!strcmp(argv[i], "-s")) opts.nsig = atol(argv[++i]);
|
||||||
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
|
else if (!strcmp(argv[i], "-n")) opts.nnoise = atol(argv[++i]);
|
||||||
else if (!strcmp(argv[i], "-i")) opts.nnoise = atoi(argv[++i]);
|
else if (!strcmp(argv[i], "-i")) opts.nnoise = atoi(argv[++i]);
|
||||||
|
else if (!strcmp(argv[i], "-w")) {
|
||||||
|
opts.trained = 1;
|
||||||
|
for (int j = 0; j < 3; j++)
|
||||||
|
opts.weights[j] = atof(argv[++i]);
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "Usage: %s -[hiIntp]\n", argv[0]);
|
fprintf(stderr, "Usage: %s -[hminw]\n", argv[0]);
|
||||||
fprintf(stderr, "\t-h\tShow this message.\n");
|
fprintf(stderr, " -h\t\tShow this message.\n");
|
||||||
fprintf(stderr, "\t-m MODE\tThe disciminant to use: 'fisher' for "
|
fprintf(stderr, " -m MODE\tThe training mode to use: 'fisher' for \n\t\t"
|
||||||
"Fisher linear discriminant, 'percep' for perceptron.\n");
|
"Fisher linear discriminant, 'percep' for perceptron.\n");
|
||||||
fprintf(stderr, "\t-i N\tThe number of training iterations (for perceptron).\n");
|
fprintf(stderr, " -i N\t\tThe number of training iterations "
|
||||||
fprintf(stderr, "\t-n N\tThe number of events in noise class.\n");
|
"(for perceptron only).\n");
|
||||||
|
fprintf(stderr, " -n N\t\tThe number of events in noise class.\n");
|
||||||
|
fprintf(stderr, " -w W₁ W₂ B\tSet weights and bias (if pre-trained).\n");
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -52,9 +62,19 @@ int main(int argc, char **argv) {
|
|||||||
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
|
sample_t *noise = generate_normal(r, opts.nnoise, &par_noise);
|
||||||
|
|
||||||
gsl_vector *w;
|
gsl_vector *w;
|
||||||
double t_cut;
|
double cut;
|
||||||
|
|
||||||
if (!strcmp(opts.mode, "fisher")) {
|
if (opts.trained) {
|
||||||
|
/* Pre-trained
|
||||||
|
*
|
||||||
|
* Use the provided weigths and bias.
|
||||||
|
*/
|
||||||
|
fputs("# Pre-trained \n\n", stderr);
|
||||||
|
gsl_vector_view v = gsl_vector_view_array(opts.weights, 2);
|
||||||
|
w = &v.vector;
|
||||||
|
cut = opts.weights[2];
|
||||||
|
}
|
||||||
|
else if (!strcmp(opts.mode, "fisher")) {
|
||||||
/* Fisher linear discriminant
|
/* Fisher linear discriminant
|
||||||
*
|
*
|
||||||
* First calculate the direction w onto
|
* First calculate the direction w onto
|
||||||
@ -65,7 +85,7 @@ int main(int argc, char **argv) {
|
|||||||
fputs("# Linear Fisher discriminant\n\n", stderr);
|
fputs("# Linear Fisher discriminant\n\n", stderr);
|
||||||
double ratio = opts.nsig / (double)opts.nnoise;
|
double ratio = opts.nsig / (double)opts.nnoise;
|
||||||
w = fisher_proj(signal, noise);
|
w = fisher_proj(signal, noise);
|
||||||
t_cut = fisher_cut(ratio, w, signal, noise);
|
cut = fisher_cut(ratio, w, signal, noise);
|
||||||
}
|
}
|
||||||
else if (!strcmp(opts.mode, "percep")) {
|
else if (!strcmp(opts.mode, "percep")) {
|
||||||
/* Perceptron
|
/* Perceptron
|
||||||
@ -75,7 +95,7 @@ int main(int argc, char **argv) {
|
|||||||
* solution in `iter` iterations.
|
* solution in `iter` iterations.
|
||||||
*/
|
*/
|
||||||
fputs("# Perceptron \n\n", stderr);
|
fputs("# Perceptron \n\n", stderr);
|
||||||
w = percep_train(signal, noise, opts.iter, &t_cut);
|
w = percep_train(signal, noise, opts.iter, &cut);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
fputs("\n\nerror: invalid mode. select either"
|
fputs("\n\nerror: invalid mode. select either"
|
||||||
@ -89,9 +109,9 @@ int main(int argc, char **argv) {
|
|||||||
fprintf(stderr, "* w: [%.3f, %.3f]\n",
|
fprintf(stderr, "* w: [%.3f, %.3f]\n",
|
||||||
gsl_vector_get(w, 0),
|
gsl_vector_get(w, 0),
|
||||||
gsl_vector_get(w, 1));
|
gsl_vector_get(w, 1));
|
||||||
fprintf(stderr, "* t_cut: %.3f\n", t_cut);
|
fprintf(stderr, "* cut: %.3f\n", cut);
|
||||||
gsl_vector_fprintf(stdout, w, "%g");
|
gsl_vector_fprintf(stdout, w, "%g");
|
||||||
printf("%f\n", t_cut);
|
printf("%f\n", cut);
|
||||||
|
|
||||||
/* Print data to stdout for plotting.
|
/* Print data to stdout for plotting.
|
||||||
* Note: we print the sizes to be able
|
* Note: we print the sizes to be able
|
||||||
|
2
makefile
2
makefile
@ -24,7 +24,7 @@ ex-5/bin/%: ex-5/%.c
|
|||||||
ex-6/bin/main: ex-6/rl.c ex-6/fft.c
|
ex-6/bin/main: ex-6/rl.c ex-6/fft.c
|
||||||
$(CCOMPILE)
|
$(CCOMPILE)
|
||||||
|
|
||||||
ex-7/bin/main: ex-7/main.c ex-7/common.c ex-7/fisher.c
|
ex-7/bin/main: ex-7/main.c ex-7/common.c ex-7/fisher.c ex-7/percep.c
|
||||||
$(CCOMPILE)
|
$(CCOMPILE)
|
||||||
|
|
||||||
misc/pdfs: misc/pdfs.c
|
misc/pdfs: misc/pdfs.c
|
||||||
|
Loading…
Reference in New Issue
Block a user