118 lines
3.1 KiB
Python
118 lines
3.1 KiB
Python
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib.transforms as trans
|
||
|
import io
|
||
|
|
||
|
|
||
|
def line(ax, x, y, **args):
    """Draw a straight segment on *ax* from point *x* to point *y*.

    *x* and *y* are 2-D points; only their first two coordinates are used.
    Extra keyword arguments are forwarded unchanged to ``ax.plot``.
    """
    start, end = x, y
    xs = [start[0], end[0]]
    ys = [start[1], end[1]]
    ax.plot(xs, ys, **args)
|
||
|
|
||
|
|
||
|
def angle(v):
    """Return the signed angle, in radians, between vector *v* and the +x axis.

    Computed with ``np.arctan2``, so the result lies in [-pi, pi].
    """
    dx, dy = v[0], v[1]
    return np.arctan2(dy, dx)
|
||
|
|
||
|
|
||
|
def proj_hist(a, b, w):
    """Render histograms of two point sets projected onto direction *w*.

    Parameters
    ----------
    a, b : numpy.ndarray
        Point sets; presumably (n, d) arrays with one point per row, so that
        ``a @ w`` yields one scalar projection per point (verify with caller).
    w : numpy.ndarray
        Projection direction.

    Returns
    -------
    numpy.ndarray
        The rendered histogram figure as image pixel data (decoded with
        ``plt.imread`` from an in-memory PNG with a transparent background),
        suitable for passing to ``imshow``.
    """
    fig = plt.figure(figsize=(3, 3))
    try:
        ax = fig.add_subplot(111)
        ax.set_aspect(0.04)
        ax.axis('off')
        # Draw on this figure's own Axes rather than pyplot's implicit
        # "current" Axes. `a` gets the higher zorder so it is drawn over `b`.
        ax.hist(a @ w, color='xkcd:dark yellow', zorder=2)
        ax.hist(b @ w, color='xkcd:pale purple', zorder=1)
        with io.BytesIO() as buf:
            # Pin the format. Otherwise savefig would follow
            # rcParams['savefig.format'], while imread on a nameless buffer
            # falls back to PNG, and the two could disagree.
            fig.savefig(buf, format='png', transparent=True)
            buf.seek(0)
            return plt.imread(buf)
    finally:
        # Always release the off-screen figure, even if rendering fails.
        plt.close(fig)
|
||
|
|
||
|
|
||
|
def place_im(ax, im, transform, *, extent=(-2, 4, -3, 2), zorder=2):
    """Draw image *im* on *ax*, positioned by an extra data-space *transform*.

    The image is first laid out in data coordinates over *extent*. Then
    *transform* is composed in front of ``ax.transData`` (``A + B`` applies A
    first, then B), so *transform* acts in data space. For example, it can
    rotate the image about the data origin and then shift it.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target Axes.
    im : array_like
        Image data accepted by ``Axes.imshow``.
    transform : matplotlib.transforms.Transform
        Data-space transform applied before ``ax.transData``.
    extent : sequence of 4 numbers, keyword-only
        ``(left, right, bottom, top)`` of the untransformed image in data
        coordinates. The default is the layout this script was tuned for.
    zorder : float, keyword-only
        Draw order of the image artist.

    Returns
    -------
    matplotlib.image.AxesImage
        The created image artist.
    """
    artist = ax.imshow(im, interpolation='none',
                       extent=extent,
                       zorder=zorder)
    artist.set_transform(transform + ax.transData)
    return artist
|
||
|
|
||
|
|
||
|
def _draw_panel(ax, a, b, m1, m2, w, rot, title, *, normal_len, axis_offset,
                axis_span, hist_offset, hist_shift, hist_scale=1.0):
    """Draw one panel: both samples, the projection axis for *w*, and its histogram.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel to draw into.
    a, b : numpy.ndarray
        The two point samples (one point per row).
    m1, m2 : numpy.ndarray
        Class means; the segment between them is drawn.
    w : numpy.ndarray
        Unit projection direction.
    rot : numpy.ndarray
        2x2 matrix mapping *w* to the normal of the projection axis.
    title : str
        Panel title (right-aligned).
    normal_len : float
        Length of the segment drawn from the means' midpoint along the normal.
    axis_offset : float
        Distance along the normal at which the projection axis is drawn.
    axis_span : (float, float)
        How far the projection axis extends backwards / forwards along *w*.
    hist_offset, hist_shift : float
        Histogram image offset along the normal and along *w*, respectively.
    hist_scale : float
        Uniform scale applied to the histogram image.
    """
    ax.set_title(title, loc='right')
    ax.set_xlim(-5.0, 3.5)
    ax.set_ylim(-3.5, 2.0)

    ax.scatter(*a.T, edgecolor='xkcd:slate grey', c='xkcd:dark yellow')
    ax.scatter(*b.T, edgecolor='xkcd:slate grey', c='xkcd:pale purple')

    normal = rot @ w
    midpoint = (m1 + m2) / 2
    # Segment joining the two class means.
    line(ax, m1, m2, c='xkcd:charcoal')
    # From the midpoint of the means out along the normal.
    line(ax, midpoint, midpoint + normal_len * normal, c='xkcd:charcoal')
    # The projection axis: parallel to w, offset along the normal.
    back, forward = axis_span
    line(ax, normal * axis_offset - back * w,
         normal * axis_offset + forward * w, c='xkcd:charcoal')
    # Rotate the histogram so its x axis lies along w, then move it onto
    # the projection axis.
    place_im(
        ax,
        proj_hist(a, b, w),
        trans.Affine2D()
        .rotate(angle(w))
        .scale(hist_scale)
        .translate(*normal * hist_offset)
        .translate(*w * hist_shift))


def main(save_path=None):
    """Plot the naïve and Fisher projections of two synthetic Gaussian classes.

    Two samples of 100 points are drawn with a fixed seed. They come from
    Gaussians that share one covariance but have different means. Each panel
    shows the samples, a projection axis, and a histogram of the projected
    points.

    Parameters
    ----------
    save_path : str or path-like, optional
        If given, save the figure there with a tight bounding box instead of
        opening an interactive window.
    """
    # rotation by π/2 matrix
    rot = np.array([[0, -1], [1, 0]])
    # covariance shared by both classes
    cov = np.array([[0.33, 0.15], [0.15, 0.12]])
    # class means
    m1 = np.array([-1.2, 0.5])
    m2 = np.array([0.6, 0])

    np.random.seed(3)
    a = np.random.multivariate_normal(m1, cov, 100)
    b = np.random.multivariate_normal(m2, cov, 100)

    # w₁: naive projection onto the direction joining the means
    w1 = (m1 - m2) / np.linalg.norm(m1 - m2)

    # w₂: Fisher projection, cov⁻¹ (m1 - m2) normalised. solve() avoids
    # forming the explicit inverse.
    w2 = np.linalg.solve(cov, m1 - m2)
    w2 /= np.linalg.norm(w2)

    fig, axes = plt.subplots(
        1, 2, figsize=(6, 9),
        subplot_kw=dict(aspect='equal'))

    _draw_panel(axes[0], a, b, m1, m2, w1, rot, 'naïve projection',
                normal_len=2.2, axis_offset=2.1, axis_span=(2.0, 2.6),
                hist_offset=3.113, hist_shift=-0.7)
    _draw_panel(axes[1], a, b, m1, m2, w2, rot, 'Fisher projection',
                normal_len=2.75, axis_offset=2.9, axis_span=(1.2, 2.3),
                hist_offset=3.90, hist_shift=-0.3, hist_scale=0.75)

    plt.tight_layout()
    if save_path is None:
        plt.show()
    else:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
|
||
|
|
||
|
|
||
|
# Build and display the figure only when run as a script, not on import.
if __name__ == '__main__':
    main()
|