analistica/ex-7/plots/fisher.py
rnhmjoj 3b774fb747 ex-7: replace Fisher projection plot
It was really tricky to reproduce in matplotlib, but worth it.
2020-07-05 11:36:30 +02:00

118 lines
3.1 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as trans
import io
def line(ax, x, y, **args):
    '''Draw a straight segment on `ax` from point `x` to point `y`.

    Only the first two coordinates of each point are read; any extra
    keyword arguments are forwarded unchanged to `ax.plot`.
    '''
    start_x, start_y = x[0], x[1]
    end_x, end_y = y[0], y[1]
    ax.plot([start_x, end_x], [start_y, end_y], **args)
def angle(v):
    '''Return the signed angle of vector `v` with respect to the x axis.

    The result is in radians, in the range [-π, π] (as given by
    `np.arctan2`), and only the first two components of `v` are used.
    '''
    dx, dy = v[0], v[1]
    return np.arctan2(dy, dx)
def proj_hist(a, b, w):
    '''Render the histograms of two samples projected onto `w` as an image.

    a, b: arrays whose rows are points; each is projected with `@ w`.
    w:    the projection vector.

    Returns the rendered figure as an image array (the output of
    `plt.imread` on a PNG), with transparent background and no axes,
    meant to be placed into another Axes with `place_im`.
    '''
    fig = plt.figure(figsize=(3, 3))
    try:
        ax = fig.add_subplot(111)
        # one histogram count is drawn 0.04 times as long as one unit of
        # the projected coordinate, flattening the bars
        ax.set_aspect(0.04)
        ax.axis('off')
        # draw on this figure's own Axes (not pyplot's "current" one);
        # the zorders keep sample `a` on top of sample `b`
        ax.hist(a @ w, color='xkcd:dark yellow', zorder=2)
        ax.hist(b @ w, color='xkcd:pale purple', zorder=1)
        buf = io.BytesIO()
        # force PNG so the imread round-trip does not depend on the
        # user's rcParams['savefig.format']
        fig.savefig(buf, format='png', transparent=True)
    finally:
        # never leak the temporary figure, even if rendering fails
        plt.close(fig)
    buf.seek(0)
    return plt.imread(buf)
def place_im(ax, im, transform, extent=(-2, 4, -3, 2), zorder=2):
    '''Place an image into an Axes, moved by an extra transformation.

    ax:        the target Axes.
    im:        the image array to draw.
    transform: a transform (e.g. `matplotlib.transforms.Affine2D`) that
               is applied to the image before `ax.transData`, i.e. it
               rotates/scales/translates the image in data space.
    extent:    (left, right, bottom, top) of the image in the frame
               *before* `transform` is applied.
    zorder:    drawing order of the image.

    Returns the image artist created by `ax.imshow`.
    '''
    artist = ax.imshow(im, interpolation='none',
                       extent=list(extent),
                       zorder=zorder)
    # `A + B` composes transforms: first A, then B
    artist.set_transform(transform + ax.transData)
    return artist
def _draw_panel(ax, title, a, b, m1, m2, w, *, boundary_len, axis_dist,
                axis_span, hist_dist, hist_shift, hist_scale=None):
    '''Draw one panel: the two samples plus the projection onto `w`.

    Besides the scattered samples `a` and `b`, it draws
    - the segment from m1 to m2;
    - a segment of length `boundary_len` starting at the midpoint of
      m1-m2, along the direction perpendicular to `w`;
    - a projection axis parallel to `w`, offset by `axis_dist` along the
      perpendicular, going from -axis_span[0] to +axis_span[1] along `w`;
    - the histogram image of the projections, rotated to align with `w`,
      optionally scaled by `hist_scale`, then moved by `hist_dist` along
      the perpendicular and `hist_shift` along `w`.

    All the numeric keyword arguments are hand-tuned layout constants.
    '''
    ax.set_title(title, loc='right')
    ax.set_xlim(-5.0, 3.5)
    ax.set_ylim(-3.5, 2.0)
    ax.scatter(*a.T, edgecolor='xkcd:slate grey', c='xkcd:dark yellow')
    ax.scatter(*b.T, edgecolor='xkcd:slate grey', c='xkcd:pale purple')

    # rotation by π/2: perp is w turned counter-clockwise
    rot = np.array([[0, -1], [1, 0]])
    perp = rot @ w
    mid = (m1 + m2)/2
    lo, hi = axis_span

    line(ax, m1, m2, c='xkcd:charcoal')
    line(ax, mid, mid + boundary_len*perp, c='xkcd:charcoal')
    line(ax, perp*axis_dist - lo*w, perp*axis_dist + hi*w,
         c='xkcd:charcoal')

    placement = trans.Affine2D().rotate(angle(w))
    if hist_scale is not None:
        placement = placement.scale(hist_scale)
    placement = (placement
                 .translate(*(perp*hist_dist))
                 .translate(*(w*hist_shift)))
    place_im(ax, proj_hist(a, b, w), placement)


def main(save=None):
    '''Plot the naïve and the Fisher projection of two Gaussian samples.

    Two samples of 100 points each are drawn from 2D Gaussians with means
    m1, m2 and the same covariance `cov`. The left panel projects them
    onto the m1-m2 direction, the right one onto the Fisher direction
    cov⁻¹(m1 - m2); each panel shows the resulting histograms.

    save: optional path; if given, the figure is saved there (with a
          tight bounding box) before being shown.
    '''
    # covariance, shared by both samples
    cov = np.array([[0.33, 0.15], [0.15, 0.12]])
    # means
    m1 = np.array([-1.2, 0.5])
    m2 = np.array([0.6, 0])

    # private generator: same draws as np.random.seed(3) followed by
    # np.random.multivariate_normal, without touching global RNG state
    rng = np.random.RandomState(3)
    a = rng.multivariate_normal(m1, cov, 100)
    b = rng.multivariate_normal(m2, cov, 100)

    # w₁: naïve projection onto the m1-m2 direction
    w1 = (m1 - m2) / np.linalg.norm(m1 - m2)
    # w₂: Fisher projection, w ∝ cov⁻¹(m1 - m2); solving the linear
    # system is preferable to forming the explicit inverse
    w2 = np.linalg.solve(cov, m1 - m2)
    w2 /= np.linalg.norm(w2)

    fig, axes = plt.subplots(
        1, 2, figsize=(6, 9),
        subplot_kw=dict(aspect='equal'))

    _draw_panel(
        axes[0], 'naïve projection', a, b, m1, m2, w1,
        boundary_len=2.2, axis_dist=2.1, axis_span=(2.0, 2.6),
        hist_dist=3.113, hist_shift=-0.7)
    _draw_panel(
        axes[1], 'Fisher projection', a, b, m1, m2, w2,
        boundary_len=2.75, axis_dist=2.9, axis_span=(1.2, 2.3),
        hist_dist=3.90, hist_shift=-0.3, hist_scale=0.75)

    # use `fig` explicitly: proj_hist opens and closes other figures, so
    # pyplot's "current figure" is not a reliable target here
    fig.tight_layout()
    if save is not None:
        fig.savefig(save, bbox_inches='tight')
    plt.show()
if __name__ == '__main__':
    # script entry point: build the figure and show it
    main()