analistica/ex-7/fisher.h

67 lines
1.6 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "common.h"
#include <gsl/gsl_matrix.h>
/* Builds the covariance matrix Σ
 * from the standard parameters (σ, ρ)
 * of a bivariate gaussian.
 *
 * Parameters:
 *   p  parameters of the distribution; `struct par` is
 *      not declared in this header (presumably it comes
 *      from "common.h" -- confirm).
 *
 * Returns a pointer to the covariance matrix.
 * NOTE(review): presumably the matrix is newly allocated
 * and the caller must release it with gsl_matrix_free --
 * confirm against the implementation.
 */
gsl_matrix* normal_cov(struct par *p);
/* Builds the mean vector of
 * a bivariate gaussian.
 *
 * Parameters:
 *   p  parameters of the distribution; `struct par` is
 *      not declared in this header (presumably it comes
 *      from "common.h" -- confirm).
 *
 * Returns a pointer to the mean vector.
 * NOTE(review): presumably the vector is newly allocated
 * and the caller must release it with gsl_vector_free --
 * confirm against the implementation.
 */
gsl_vector* normal_mean(struct par *p);
/* `fisher_proj(c1, c2)` computes the optimal
 * projection map, which maximises the separation
 * between the two classes.
 * The projection vector w is given by
 *
 *   w = Sw⁻¹ (μ₂ - μ₁)
 *
 * where Sw = Σ₁ + Σ₂ is the so-called within-class
 * covariance matrix, and μᵢ, Σᵢ are the mean and the
 * covariance of class i.
 *
 * Parameters:
 *   c1, c2  the samples of the two classes; `sample_t`
 *           is not declared in this header (presumably
 *           it comes from "common.h" -- confirm).
 *
 * Returns a pointer to the projection vector w.
 * NOTE(review): whether w is normalised, and whether the
 * caller owns (and must gsl_vector_free) the result, is
 * not visible from this header -- confirm against the
 * implementation.
 */
gsl_vector* fisher_proj(sample_t *c1, sample_t *c2);
/* `fisher_cut(ratio, w, c1, c2)` computes
 * the threshold (cut), on the line given by
 * `w`, to discriminate the classes `c1`, `c2`;
 * with `ratio` being the ratio of their prior
 * probabilities.
 *
 * The cut is fixed by the condition of
 * conditional probability being the
 * same for each class:
 *
 *   P(c₁|x)   p(x|c₁)⋅p(c₁)
 *   ------- = --------------- = 1;
 *   P(c₂|x)   p(x|c₂)⋅p(c₂)
 *
 * where p(x|c) is the probability for point x
 * along the fisher projection line. If the classes
 * are bivariate gaussian then p(x|c) is simply
 * given by a normal distribution:
 *
 *   Φ(μ=(w,μ), σ=(w,Σw))
 *
 * NOTE(review): (w,Σw) is the variance of the projected
 * distribution, so the `σ` above presumably stands for
 * σ² -- confirm against the implementation.
 *
 * The solution is then
 *
 *   t = (b/a) + √((b/a)² - c/a);
 *
 * where
 *
 *   1. a = S₁² - S₂²
 *   2. b = M₂S₁² - M₁S₂²
 *   3. c = M₂²S₁² - M₁²S₂² - 2S₁²S₂² log(α)
 *   4. α = p(c₁)/p(c₂)
 *
 * Mᵢ and Sᵢ are not defined in this comment; presumably
 * they are the mean and the standard deviation of class i
 * projected on `w` -- confirm against the implementation.
 *
 * NOTE(review): the expression for t divides by a, which
 * vanishes when S₁² = S₂²; how the implementation behaves
 * in that case is not visible from this header.
 *
 * Parameters:
 *   ratio   α, the ratio of the prior probabilities
 *           p(c₁)/p(c₂).
 *   w       the projection vector defining the line.
 *   c1, c2  the samples of the two classes.
 *
 * Returns the threshold t.
 */
double fisher_cut(
double ratio,
gsl_vector *w,
sample_t *c1, sample_t *c2);