analistica/ex-7/fisher.c

187 lines
4.7 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "fisher.h"
#include <math.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_linalg.h>
/* normal_cov(p): allocate and fill the 2×2 covariance matrix Σ of
 * the bivariate gaussian whose parameters (σx, σy, ρ) are stored
 * in `p`:
 *
 *   Σ = ⎡ σx²      ρ⋅σx⋅σy ⎤
 *       ⎣ ρ⋅σx⋅σy  σy²     ⎦
 *
 * The matrix is handed over to the caller, who releases it with
 * gsl_matrix_free().
 */
gsl_matrix* normal_cov(struct par *p) {
  gsl_matrix *sigma = gsl_matrix_alloc(2, 2);

  /* Σ is symmetric: the same off-diagonal term fills both triangles. */
  const double off_diag = p->rho * p->sigma_x * p->sigma_y;

  gsl_matrix_set(sigma, 0, 0, pow(p->sigma_x, 2));
  gsl_matrix_set(sigma, 0, 1, off_diag);
  gsl_matrix_set(sigma, 1, 0, off_diag);
  gsl_matrix_set(sigma, 1, 1, pow(p->sigma_y, 2));

  return sigma;
}
/* normal_mean(p): allocate the 2-vector μ = (x0, y0) holding the
 * centre of the bivariate gaussian described by `p`.
 * The vector is handed over to the caller, who releases it with
 * gsl_vector_free().
 */
gsl_vector* normal_mean(struct par *p) {
  const double centre[2] = { p->x0, p->y0 };
  gsl_vector *mean = gsl_vector_alloc(2);

  for (size_t i = 0; i < 2; i++) {
    gsl_vector_set(mean, i, centre[i]);
  }
  return mean;
}
/* `fisher_proj(c1, c2)` computes the direction w of the optimal
 * (Fisher) projection, i.e. the one which maximises the separation
 * between the two classes. It is given by
 *
 *   w = Sw⁻¹ (μ₂ - μ₁)
 *
 * where Sw = Σ₁ + Σ₂ is the so-called within-class covariance
 * matrix. With this orientation (w, μ₂ - μ₁) = (Δμ, Sw⁻¹Δμ) ≥ 0,
 * so the projected mean of c₂ is not smaller than that of c₁.
 *
 * Returns a newly allocated 2-vector, to be released by the caller
 * with gsl_vector_free(). If Sw is not positive-definite the
 * Cholesky factorisation fails; GSL's error handler is invoked and,
 * should it return, every component of w is set to NaN.
 */
gsl_vector* fisher_proj(sample_t *c1, sample_t *c2) {
  /* Within-class covariance Sw = Σ₁ + Σ₂, accumulated into `sw`. */
  gsl_matrix *sw = normal_cov(&c1->p);
  gsl_matrix *cov2 = normal_cov(&c2->p);
  gsl_matrix_add(sw, cov2);
  gsl_matrix_free(cov2);

  /* Difference of the means Δμ = μ₂ - μ₁, computed in place. */
  gsl_vector *mu1 = normal_mean(&c1->p);
  gsl_vector *diff = normal_mean(&c2->p);
  gsl_vector_sub(diff, mu1);
  gsl_vector_free(mu1);

  /* Solve Sw⋅w = Δμ instead of forming Sw⁻¹ explicitly: it is cheaper
   * and numerically better behaved. A sum of covariance matrices is
   * symmetric and, for non-degenerate classes, positive-definite, so
   * the Cholesky factorisation applies. */
  gsl_vector *w = gsl_vector_alloc(2);
  int status = gsl_linalg_cholesky_decomp(sw);
  if (status) {
    /* Do not return garbage if the error handler did not abort. */
    gsl_vector_set_all(w, NAN);
  } else {
    gsl_linalg_cholesky_solve(sw, diff, w);
  }

  gsl_matrix_free(sw);
  gsl_vector_free(diff);
  return w;
}
/* Project the bivariate gaussian of class `c` onto the line spanned
 * by `w`. The projection is the 1D gaussian with
 *
 *   mean     = (w, μ)
 *   variance = (w, Σw)
 *
 * `work` is scratch space with the same size as `w`.
 */
static void project_class(
    const gsl_vector *w, sample_t *c, gsl_vector *work,
    double *mean, double *var) {
  gsl_matrix *cov = normal_cov(&c->p);
  gsl_vector *mu = normal_mean(&c->p);

  gsl_blas_ddot(w, mu, mean);
  /* work = Σ⋅w, then var = (w, work) */
  gsl_blas_dgemv(CblasNoTrans, 1, cov, w, 0, work);
  gsl_blas_ddot(w, work, var);

  gsl_matrix_free(cov);
  gsl_vector_free(mu);
}

/* Solve p(t|c₁)⋅p(c₁) = p(t|c₂)⋅p(c₂) for the two 1D gaussians
 * N(m1, var1) and N(m2, var2), where ratio = p(c₁)/p(c₂).
 * See fisher_cut() for the derivation of a, b, c.
 */
static double gaussian_threshold(
    double m1, double var1,
    double m2, double var2,
    double ratio) {
  double a = var1 - var2;
  double b = m2*var1 - m1*var2;
  double c = m2*m2*var1 - m1*m1*var2
           + 2*var1*var2 * log(ratio)
           - var1*var2 * log(var1/var2);

  /* Equal variances: the equation is linear, -2bt + c = 0. */
  if (a == 0) {
    return c / (2*b);
  }

  /* The two roots are symmetric about the vertex b/a of the parabola. */
  double vertex = b / a;
  double half_gap = sqrt(vertex*vertex - c/a);  /* NaN if no real root */

  /* The narrower gaussian dominates only on the bounded interval
   * between the roots. Its centre b/a lies beyond the narrower class'
   * own mean, on the side away from the other class. The root that
   * actually separates the classes is therefore the one facing the
   * mean of the wider class. */
  double wide_mean = (a < 0) ? m2 : m1;
  return vertex + copysign(half_gap, wide_mean - vertex);
}

/* `fisher_cut(ratio, w, c1, c2)` computes the threshold (cut) t,
 * on the line spanned by `w`, that discriminates the classes `c1`
 * and `c2`; `ratio` is the ratio of their prior probabilities.
 *
 * The cut is fixed by requiring the posterior probabilities of the
 * two classes to be equal:
 *
 *   P(c₁|t)   p(t|c₁)⋅p(c₁)
 *   ------- = ------------- = 1
 *   P(c₂|t)   p(t|c₂)⋅p(c₂)
 *
 * where p(t|c) is the density of class c along the Fisher
 * projection line. If the classes are bivariate gaussians, this is
 * the normal distribution
 *
 *   N(M, S²),   M = (w, μ),   S² = (w, Σw).
 *
 * Taking logarithms gives the quadratic equation
 *
 *   a t² - 2b t + c = 0
 *
 * where
 *
 *   1. a = S₁² - S₂²
 *   2. b = M₂S₁² - M₁S₂²
 *   3. c = M₂²S₁² - M₁²S₂² + 2S₁²S₂² log(α) - S₁²S₂² log(S₁²/S₂²)
 *   4. α = p(c₁)/p(c₂)
 *
 * Its solutions are t = b/a ± √((b/a)² - c/a). The root lying
 * between the two classes is returned; when a = 0 the single root
 * of the linear equation is returned. The result is NaN if the
 * equation has no real solution.
 */
double fisher_cut(
    double ratio,
    gsl_vector *w,
    sample_t *c1, sample_t *c2) {
  /* Scratch vector for the matrix-vector products. */
  gsl_vector *work = gsl_vector_alloc(w->size);

  /* Project both distributions onto w to get two 1D gaussians. */
  double m1, var1, m2, var2;
  project_class(w, c1, work, &m1, &var1);
  project_class(w, c2, work, &m2, &var2);

  gsl_vector_free(work);
  return gaussian_threshold(m1, var1, m2, var2, ratio);
}