import io

import matplotlib.pyplot as plt
import matplotlib.transforms as trans
import numpy as np

# rotation by π/2 matrix: maps a direction onto its counter-clockwise normal
_ROT90 = np.array([[0, -1], [1, 0]])


def line(ax, x, y, **args):
    '''Draw a straight segment between the 2-D points x and y on ax.

    Any extra keyword arguments are forwarded unchanged to ``ax.plot``.
    '''
    ax.plot([x[0], y[0]], [x[1], y[1]], **args)


def angle(v):
    '''Return the angle, in radians, of the 2-D vector v wrt the x axis.'''
    return np.arctan2(v[1], v[0])


def proj_hist(a, b, w):
    '''Render the histograms of a and b projected onto w as an image.

    The two samples are projected with ``a @ w`` and ``b @ w``, drawn as
    overlapping histograms on a throw-away figure with the axes hidden, and
    rasterised to an in-memory PNG with a transparent background.

    Returns the image array obtained by reading that PNG back with
    ``plt.imread`` (not a file-like object).
    '''
    fig = plt.figure(figsize=(3, 3))
    try:
        ax = fig.add_subplot(111)
        ax.set_aspect(0.04)
        ax.axis('off')
        # Draw on the axes we own instead of relying on pyplot's notion of
        # the "current" axes.
        ax.hist(a @ w, color='xkcd:dark yellow', zorder=2)
        ax.hist(b @ w, color='xkcd:pale purple', zorder=1)
        buf = io.BytesIO()
        # Pin the format so a non-PNG ``savefig.format`` rcParam cannot make
        # the read-back below fail.
        fig.savefig(buf, format='png', transparent=True)
    finally:
        # Always release the scratch figure, even if rendering failed.
        plt.close(fig)
    buf.seek(0)
    return plt.imread(buf, format='png')


def place_im(ax, im, transform):
    '''Place an image into an Axes by applying a transformation.

    The image im is first laid out over the fixed extent [-2, 4, -3, 2] and
    then mapped through transform before the usual data-to-display transform
    of ax.
    '''
    artist = ax.imshow(
        im, interpolation='none', extent=[-2, 4, -3, 2], zorder=2)
    artist.set_transform(transform + ax.transData)


def _draw_panel(ax, title, a, b, m1, m2, w, *, bisector_len, axis_offset,
                axis_span, hist_offset, hist_shift, hist_scale=1.0):
    '''Draw one scatter panel together with its projection histogram.

    ax           -- Axes to draw into
    title        -- panel title (right aligned)
    a, b         -- the two samples, one 2-D point per row
    m1, m2       -- the means used to generate a and b
    w            -- unit projection direction
    bisector_len -- length of the segment drawn from the midpoint of the
                    means along the normal of w
    axis_offset  -- distance along the normal at which the projection axis
                    is drawn
    axis_span    -- (backward, forward) extent of the projection axis
                    along w
    hist_offset  -- translation of the histogram image along the normal
    hist_shift   -- translation of the histogram image along w
    hist_scale   -- uniform scale applied to the histogram image
    '''
    ax.set_title(title, loc='right')
    ax.set_xlim(-5.0, 3.5)
    ax.set_ylim(-3.5, 2.0)

    ax.scatter(*a.T, edgecolor='xkcd:slate grey', c='xkcd:dark yellow')
    ax.scatter(*b.T, edgecolor='xkcd:slate grey', c='xkcd:pale purple')

    normal = _ROT90 @ w
    midpoint = (m1 + m2)/2
    axis_origin = normal*axis_offset
    back, forward = axis_span

    # segment joining the two means
    line(ax, m1, m2, c='xkcd:charcoal')
    # segment from the midpoint of the means towards the projection axis
    line(ax, midpoint, midpoint + bisector_len*normal, c='xkcd:charcoal')
    # the projection axis itself, parallel to w
    line(ax, axis_origin - back*w, axis_origin + forward*w,
         c='xkcd:charcoal')

    # Rotate the histogram so that its baseline is parallel to w, then move
    # it next to the projection axis.
    place_im(
        ax, proj_hist(a, b, w),
        trans.Affine2D()
        .rotate(angle(w))
        .scale(hist_scale)
        .translate(*(normal*hist_offset))
        .translate(*(w*hist_shift)))


def main():
    '''Plot the naïve and the Fisher projection of two Gaussian samples.'''
    # covariance shared by both classes
    cov = np.array([[0.33, 0.15],
                    [0.15, 0.12]])

    # means
    m1 = np.array([-1.2, 0.5])
    m2 = np.array([0.6, 0])

    # fixed seed so that the figure is reproducible
    np.random.seed(3)
    a = np.random.multivariate_normal(m1, cov, 100)
    b = np.random.multivariate_normal(m2, cov, 100)

    # w₁: naive projection onto the direction joining the two means
    w1 = (m1 - m2) / np.linalg.norm(m1 - m2)

    # w₂: Fisher projection, w ∝ cov⁻¹ (m1 - m2); solving the linear system
    # avoids forming the explicit inverse
    w2 = np.linalg.solve(cov, m1 - m2)
    w2 /= np.linalg.norm(w2)

    #plt.rcParams['font.size'] = 8
    fig, axes = plt.subplots(
        1, 2, figsize=(6, 9), subplot_kw=dict(aspect='equal'))

    # naive projection
    _draw_panel(
        axes[0], 'naïve projection', a, b, m1, m2, w1,
        bisector_len=2.2, axis_offset=2.1, axis_span=(2.0, 2.6),
        hist_offset=3.113, hist_shift=-0.7)

    # fisher projection
    _draw_panel(
        axes[1], 'Fisher projection', a, b, m1, m2, w2,
        bisector_len=2.75, axis_offset=2.9, axis_span=(1.2, 2.3),
        hist_offset=3.90, hist_shift=-0.3, hist_scale=0.75)

    plt.tight_layout()
    #plt.savefig('notes/images/7-fisher.pdf', bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    main()