ex-7: replace Fisher projection plot
It was really tricky reproduce in matplotlib but worth it.
This commit is contained in:
parent
98938fd55b
commit
3b774fb747
117
ex-7/plots/fisher.py
Normal file
117
ex-7/plots/fisher.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.transforms as trans
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
def line(ax, x, y, **args):
|
||||||
|
'''line between two points x,y'''
|
||||||
|
ax.plot([x[0], y[0]], [x[1], y[1]], **args)
|
||||||
|
|
||||||
|
|
||||||
|
def angle(v):
|
||||||
|
'''angle of a vector wrt x axis'''
|
||||||
|
return np.arctan2(v[1], v[0])
|
||||||
|
|
||||||
|
|
||||||
|
def proj_hist(a, b, w):
|
||||||
|
'''returns a file-like object of the w-projection histogram'''
|
||||||
|
fig = plt.figure(figsize=(3, 3))
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
ax.set_aspect(0.04)
|
||||||
|
ax.axis('off')
|
||||||
|
plt.hist(a @ w, color='xkcd:dark yellow', zorder=2)
|
||||||
|
plt.hist(b @ w, color='xkcd:pale purple', zorder=1)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
buf.seek(0)
|
||||||
|
return plt.imread(buf)
|
||||||
|
|
||||||
|
|
||||||
|
def place_im(ax, im, transform):
|
||||||
|
'''place an image into an Axes by applying a transformation'''
|
||||||
|
im = ax.imshow(im, interpolation='none',
|
||||||
|
extent=[-2, 4, -3, 2],
|
||||||
|
zorder=2)
|
||||||
|
trans_data = transform + ax.transData
|
||||||
|
im.set_transform(trans_data)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# rotation by π/2 matrix
|
||||||
|
rot = np.array([[0, -1], [1, 0]])
|
||||||
|
|
||||||
|
# covariance
|
||||||
|
cov = np.array([[0.33, 0.15], [0.15, 0.12]])
|
||||||
|
|
||||||
|
# means
|
||||||
|
m1 = np.array([-1.2, 0.5])
|
||||||
|
m2 = np.array([0.6, 0])
|
||||||
|
|
||||||
|
np.random.seed(3)
|
||||||
|
a = np.random.multivariate_normal(m1, cov, 100)
|
||||||
|
b = np.random.multivariate_normal(m2, cov, 100)
|
||||||
|
|
||||||
|
# w₁: naive projection onto m1-m2 plane
|
||||||
|
w1 = (m1 - m2) / np.linalg.norm(m1 - m2)
|
||||||
|
|
||||||
|
# w₂: Fisher projection
|
||||||
|
w2 = np.linalg.inv(cov) @ (m1 - m2)
|
||||||
|
w2 /= np.linalg.norm(w2)
|
||||||
|
|
||||||
|
#plt.rcParams['font.size'] = 8
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
1, 2, figsize=(6, 9),
|
||||||
|
subplot_kw=dict(aspect='equal'))
|
||||||
|
|
||||||
|
# naive projection
|
||||||
|
ax = axes[0]
|
||||||
|
ax.set_title('naïve projection', loc='right')
|
||||||
|
ax.set_xlim(-5.0, 3.5)
|
||||||
|
ax.set_ylim(-3.5, 2.0)
|
||||||
|
|
||||||
|
ax.scatter(*a.T, edgecolor='xkcd:slate grey', c='xkcd:dark yellow')
|
||||||
|
ax.scatter(*b.T, edgecolor='xkcd:slate grey', c='xkcd:pale purple')
|
||||||
|
|
||||||
|
dir = rot @ w1
|
||||||
|
line(ax, m1, m2, c='xkcd:charcoal')
|
||||||
|
line(ax, (m1 + m2)/2, (m1 + m2)/2 + 2.2*dir, c='xkcd:charcoal')
|
||||||
|
line(ax, dir*2.1 - 2.0*w1, dir*2.1 + 2.6*w1, c='xkcd:charcoal')
|
||||||
|
place_im(
|
||||||
|
ax,
|
||||||
|
proj_hist(a, b, w1),
|
||||||
|
trans.Affine2D()
|
||||||
|
.rotate(angle(w1))
|
||||||
|
.translate(*dir*3.113)
|
||||||
|
.translate(*w1*(-0.7)))
|
||||||
|
|
||||||
|
# fisher projection
|
||||||
|
ax = axes[1]
|
||||||
|
ax.set_title('Fisher projection', loc='right')
|
||||||
|
ax.set_xlim(-5.0, 3.5)
|
||||||
|
ax.set_ylim(-3.5, 2.0)
|
||||||
|
|
||||||
|
ax.scatter(*a.T, edgecolor='xkcd:slate grey', c='xkcd:dark yellow')
|
||||||
|
ax.scatter(*b.T, edgecolor='xkcd:slate grey', c='xkcd:pale purple')
|
||||||
|
|
||||||
|
dir = rot @ w2
|
||||||
|
line(ax, m1, m2, c='xkcd:charcoal')
|
||||||
|
line(ax, (m1 + m2)/2, (m1 + m2)/2 + 2.75*dir, c='xkcd:charcoal')
|
||||||
|
line(ax, dir*2.9 - 1.2*w2, dir*2.9 + 2.3*w2, c='xkcd:charcoal')
|
||||||
|
place_im(
|
||||||
|
ax,
|
||||||
|
proj_hist(a, b, w2),
|
||||||
|
trans.Affine2D()
|
||||||
|
.rotate(angle(w2))
|
||||||
|
.scale(0.75)
|
||||||
|
.translate(*dir*3.90)
|
||||||
|
.translate(*w2*(-0.3)))
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
#plt.savefig('notes/images/7-fisher.pdf', bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
BIN
notes/images/7-fisher.pdf
Normal file
BIN
notes/images/7-fisher.pdf
Normal file
Binary file not shown.
Binary file not shown.
Before Width: | Height: | Size: 36 KiB |
@ -104,7 +104,7 @@ histograms resulting from the projection onto the line joining the
|
|||||||
class means: note the considerable overlap in the projected
|
class means: note the considerable overlap in the projected
|
||||||
space. The right plot shows the corresponding projection based on the
|
space. The right plot shows the corresponding projection based on the
|
||||||
Fisher linear discriminant, showing the greatly improved classes
|
Fisher linear discriminant, showing the greatly improved classes
|
||||||
separation. Figure taken from [@bishop06]](images/7-fisher.png){#fig:overlap}
|
separation.](images/7-fisher.pdf){#fig:overlap}
|
||||||
|
|
||||||
The overlap of the projections can be reduced by maximising a function that
|
The overlap of the projections can be reduced by maximising a function that
|
||||||
gives, besides a large separation, small variance within each class. The
|
gives, besides a large separation, small variance within each class. The
|
||||||
|
Loading…
Reference in New Issue
Block a user