Visualizing the space of handwritten digits using masspcf#

This notebook serves as an introduction to using the ‘masspcf’ Python package.

We will create point cloud data from the MNIST dataset of handwritten digits. From these data we will extract topological invariants, represented as piecewise constant functions (PCFs), and store them in a matrix. Using this matrix, we compute pairwise distances between the instances in our dataset and finally visualize how they cluster using a t-SNE embedding.

Install and activate prerequisites#

[ ]:
# For better plotting support in notebook
!pip install ipympl
!pip install tqdm
# For creation of sample dataset
!pip install scikit-learn
!pip install pandas
# For barcode computation
!pip install ripser
[2]:
import numpy as np
import matplotlib.pyplot as plt
import ipympl
%matplotlib widget
from tqdm import trange, tqdm

Install masspcf#

For details, see https://masspcf.readthedocs.io

[ ]:
# Install masspcf-cpu or, alternatively, masspcf if one or more CUDA GPUs are available
!pip install masspcf-cpu

Build example dataset#

We will use the MNIST dataset of 28x28 grayscale images of handwritten digits:

[1] LeCun, Y., Cortes, C., & Burges, CJ. (1998). The MNIST database of handwritten digits

[3]:
from sklearn.datasets import fetch_openml
Xin, yin = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
Xin = Xin.reshape(Xin.shape[0], 28, 28)
print(Xin.shape)
(70000, 28, 28)
[4]:
use_digits=[1, 8]
mask = np.logical_or.reduce([yin == str(i) for i in use_digits])
X = Xin[mask]
y = yin[mask]
print(X.shape)
(14702, 28, 28)
[5]:
fig, axs = plt.subplots(1, 6, figsize=(10,3), tight_layout=True)
for i, ax in enumerate(axs):
    ax.imshow(X[i+2],cmap='gray', vmin=0, vmax=255)

Threshold and sample#

[6]:
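threshold = 150    # intensity above which a pixel counts as part of the digit
sample_size = 100  # number of points in each point cloud
noise = 0.05       # standard deviation of the Gaussian jitter

np.random.seed(0)

# Boolean mask of "ink" pixels for every image
Xhi = (X > threshold)

Xcloud = np.zeros((X.shape[0], sample_size, 2))

# Pixel coordinates; rows are flipped so that y increases upwards
sp = np.linspace(1.0, 28.0, 28)
Hsp, Vsp = np.meshgrid(sp, sp[::-1])

for idx in trange(X.shape[0]):
    # (x, y) coordinates of the ink pixels of image idx
    Hpts = Hsp[Xhi[idx,:,:]]
    Vpts = Vsp[Xhi[idx,:,:]]

    pts = (np.vstack((Hpts, Vpts))).T

    # Sample with replacement only when there are fewer ink pixels than sample_size
    replace = (pts.shape[0] < sample_size)
    sample_idxs = np.random.choice(range(pts.shape[0]), size=sample_size, replace=replace)

    # Jitter the sampled pixel positions
    pts = pts[sample_idxs, :] + noise * np.random.randn(sample_size, 2)

    Xcloud[idx, :, :] = pts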

fig, axs = plt.subplots(1, 6, figsize=(10,3), tight_layout=True)
for i, ax in enumerate(axs):
    ax.scatter(Xcloud[i,:,0], Xcloud[i,:,1], s=2.0)
100%|█████████████████████████████████████████████████████████████████████████| 14702/14702 [00:00<00:00, 23965.46it/s]
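
The thresholding and sampling step above can be tried on a single tiny image. The following is a minimal sketch and not part of the original notebook: the 4x4 array `img` and the helper `image_to_cloud` are hypothetical names introduced for illustration, and it uses `np.random.default_rng` instead of the global seed used above.

```python
import numpy as np

def image_to_cloud(img, threshold, sample_size, noise, rng):
    """Sample a 2D point cloud from the pixels of `img` brighter than `threshold`."""
    n_rows, n_cols = img.shape
    # Pixel coordinates; rows are flipped so that y increases upwards.
    xs = np.linspace(1.0, n_cols, n_cols)
    ys = np.linspace(1.0, n_rows, n_rows)[::-1]
    H, V = np.meshgrid(xs, ys)

    bright = img > threshold
    pts = np.column_stack((H[bright], V[bright]))

    # Sample with replacement only if there are too few bright pixels.
    replace = pts.shape[0] < sample_size
    idxs = rng.choice(pts.shape[0], size=sample_size, replace=replace)

    # Jitter the sampled pixel positions with Gaussian noise.
    return pts[idxs] + noise * rng.standard_normal((sample_size, 2))

# Hypothetical toy image: an "L"-shaped stroke of four bright pixels.
img = np.array([[0, 200,   0, 0],
                [0, 200,   0, 0],
                [0, 200, 200, 0],
                [0,   0,   0, 0]])

cloud = image_to_cloud(img, threshold=150, sample_size=10, noise=0.05,
                       rng=np.random.default_rng(0))
print(cloud.shape)
```

Because only four pixels exceed the threshold while ten points are requested, this toy case exercises the sampling-with-replacement branch.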

Compute invariants (stable rank)#

[7]:
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display
from persim import plot_diagrams
from ripser import Rips

def plotInteractive():
    fig, axs = plt.subplots(1, 3, figsize=(10,3.1), width_ratios=[1,1,1])

    def update(idx=1):
        rips = Rips(verbose=False)
        diagrams = rips.fit_transform(Xcloud[idx, :, :])

        axs[0].cla()
        axs[1].cla()
        axs[2].cla()

        axs[0].imshow(X[idx,:,:], cmap='gray', vmin=0, vmax=255)
        axs[1].scatter(Xcloud[idx, :, 0], Xcloud[idx, :, 1])
        axs[1].set_xlim([0, 28])
        axs[1].set_ylim([0, 28])
        plot_diagrams(diagrams, ax=axs[2])


        fig.canvas.draw()

    interact(update, idx=(0,10,1))

plotInteractive()

Stable rank#

We are computing a special case of the stable rank invariant. For more, see:

[2] Chachólski, W., & Riihimäki, H. (2020). Metrics and stabilization in one parameter persistence. SIAM J. Appl. Algebra Geom., 4(1), 69-98.

[3] Gäfvert, O., & Chachólski, W. (2017). Stable invariants for multiparameter persistence. arXiv:1703.03632.

[4] Scolamiero, M., Chachólski, W., Lundman, A., Ramanujam, R., & Öberg, S. (2017). Multidimensional persistence and noise. Found. Comput. Math., 17, 1367-1406.

In short, let \(K(b,d)\) denote an interval module (bar) starting at \(b \in [0, \infty)\) and ending at \(d\).

If \(\bigoplus_{i=1}^n K(b_i,d_i)\) is a barcode, then \(\widehat{\mathrm{srank}}(\bigoplus_{i=1}^n K(b_i,d_i)) \colon [0, \infty) \to \mathbb{R}\) can be explicitly computed as \(t \mapsto |\{i : d_i - b_i > t \}|\).

(Note that stable rank exists in more general settings.)
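
As a small worked example of the formula, the sketch below evaluates the stable rank of a toy barcode directly from the definition. The helper `stable_rank_at` and the barcode `bars` are hypothetical and used only for illustration; this is not how masspcf represents the function.

```python
import numpy as np

def stable_rank_at(bars, t):
    """Count the bars (b, d) whose lifetime d - b is strictly greater than t."""
    bars = np.asarray(bars, dtype=float)
    lifetimes = bars[:, 1] - bars[:, 0]
    return int(np.sum(lifetimes > t))

# Toy barcode with lifetimes 1.0, 2.0 and 0.5.
bars = [(0.0, 1.0), (0.0, 2.0), (0.5, 1.0)]
values = [stable_rank_at(bars, t) for t in (0.0, 0.5, 1.0, 2.0)]
print(values)  # [3, 2, 1, 0]
```

The value drops by one each time \(t\) reaches the lifetime of a bar, which is why the function is piecewise constant with breakpoints at the distinct lifetimes.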

[8]:
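def compute_stable_rank(diagram):
    """Return the stable rank of a persistence diagram as rows of (time, value)."""
    # Lifetime d - b of every bar
    lifetimes = diagram[:,1] - diagram[:,0]
    # Distinct lifetimes in increasing order, and how many bars share each one
    nodes, counts = np.unique(lifetimes, return_counts=True)

    # sums[k] = number of bars with lifetime >= nodes[k]
    sums = np.cumsum(counts[::-1])[::-1]

    # The function starts at t = 0 and drops at each distinct lifetime
    times = np.insert(nodes, 0, 0)
    values = np.append(sums, 0)

    # Bars with infinite lifetime are counted for all t, so drop the breakpoint at infinity
    if times[-1] == np.inf:
        times = times[:-1]
        values = values[:-1]

    return np.vstack((times, values)).T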
[9]:
rips = Rips(verbose=False)
idx = 1
diagrams = rips.fit_transform(Xcloud[idx, :, :])
h1 = compute_stable_rank(diagrams[1])
print(h1)
[[ 0.         23.        ]
 [ 0.22718811 22.        ]
 [ 0.22784448 21.        ]
 [ 0.23986614 20.        ]
 [ 0.24283421 19.        ]
 [ 0.27310181 18.        ]
 [ 0.28199542 17.        ]
 [ 0.28832221 16.        ]
 [ 0.29036283 15.        ]
 [ 0.29466891 14.        ]
 [ 0.30546284 13.        ]
 [ 0.30891371 12.        ]
 [ 0.32150459 11.        ]
 [ 0.3248378  10.        ]
 [ 0.33242011  9.        ]
 [ 0.33552694  8.        ]
 [ 0.34005415  7.        ]
 [ 0.3513993   6.        ]
 [ 0.35366344  5.        ]
 [ 0.35925007  4.        ]
 [ 0.36968929  3.        ]
 [ 0.36982512  2.        ]
 [ 0.40468955  1.        ]
 [ 0.55580544  0.        ]]
[10]:
import masspcf as mpcf

h1sr = mpcf.Pcf(h1)
[11]:
from masspcf.plotting import plot as pcfplot

fig, axs = plt.subplots(1, 2, figsize=(6,3), constrained_layout=True)

plot_diagrams(diagrams, ax=axs[0], lifetime=True)
pcfplot(h1sr, ax=axs[1], color='tab:orange')

PCF arrays#

[12]:
Z = mpcf.zeros((10, 3, 5))
print(Z.shape)
Shape(10, 3, 5)
[13]:
Xsmall = Xcloud[0:5000, :, :]
pcfs = mpcf.zeros((Xsmall.shape[0], 2)) # col 0 = H0, col 1 = H1

for idx, pts in enumerate(tqdm(Xsmall)):
    diagrams = rips.fit_transform(pts)

    srdata_h0 = compute_stable_rank(diagrams[0])
    srdata_h1 = compute_stable_rank(diagrams[1])

    pcfs[idx, 0] = mpcf.Pcf(srdata_h0)
    pcfs[idx, 1] = mpcf.Pcf(srdata_h1)
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 549.39it/s]
[14]:
fig, ax = plt.subplots()
pcfplot(pcfs[10:14,0], ax=ax)

Combining features#

[15]:
avgs = mpcf.mean(pcfs, dim=1)
print(avgs.shape)
Shape(5000)
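
Taking the mean along `dim=1` collapses the H0 and H1 columns, so each entry of `avgs` combines both stable ranks of one point cloud. The sketch below shows what a pointwise mean of two step functions looks like in plain NumPy. It is an illustration under stated assumptions, not masspcf's implementation: the helpers `evaluate` and `pointwise_mean` are hypothetical, and each function is assumed to be given as rows of (start time, value) with the first start time equal to 0.

```python
import numpy as np

def evaluate(pcf, t):
    """Value at time(s) t of the step function with rows (start_time, value)."""
    idx = np.searchsorted(pcf[:, 0], t, side="right") - 1
    return pcf[idx, 1]

def pointwise_mean(f, g):
    """Mean of two step functions, again as rows of (start_time, value)."""
    # The mean can only change where one of the inputs changes.
    times = np.union1d(f[:, 0], g[:, 0])
    values = 0.5 * (evaluate(f, times) + evaluate(g, times))
    return np.column_stack((times, values))

# Two toy step functions (hypothetical data).
f = np.array([[0.0, 2.0], [1.0, 1.0], [3.0, 0.0]])
g = np.array([[0.0, 4.0], [2.0, 0.0]])

print(pointwise_mean(f, g))
```

The result has a breakpoint wherever either input has one, so averaging generally produces a function with more pieces than its inputs.
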
[16]:
avfig, avax = plt.subplots()

pcfplot(pcfs[0,0], ax=avax, label='H0')
pcfplot(pcfs[0,1], ax=avax, label='H1')
pcfplot(avgs[0], ax=avax, label='Avg')

plt.legend()
[16]:
<matplotlib.legend.Legend at 0x1fcf7590590>

Clustering#

Compute distance matrix#

[17]:
D = mpcf.pdist(avgs)
print(D)
print(D.shape)
[[ 0.         4.0011654  9.072399  ... 27.778965   9.305885   4.64795  ]
 [ 4.0011654  0.        11.520903  ... 25.717815  11.917289   1.708406 ]
 [ 9.072399  11.520903   0.        ... 36.85134    1.0296581 12.579623 ]
 ...
 [27.778965  25.717815  36.85134   ...  0.        37.084835  24.385948 ]
 [ 9.305885  11.917289   1.0296581 ... 37.084835   0.        13.178361 ]
 [ 4.64795    1.708406  12.579623  ... 24.385948  13.178361   0.       ]]
(5000, 5000)
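
To give some intuition for the entries of `D`, the sketch below computes an \(L_1\) distance (the integral of the absolute difference) between two step functions in plain NumPy. That `mpcf.pdist` uses this particular distance by default is an assumption here and should be checked against the masspcf documentation. The helpers `evaluate` and `l1_distance` are hypothetical, and each function is assumed to be given as rows of (start time, value) with the first start time equal to 0.

```python
import numpy as np

def evaluate(pcf, t):
    """Value at time(s) t of the step function with rows (start_time, value)."""
    idx = np.searchsorted(pcf[:, 0], t, side="right") - 1
    return pcf[idx, 1]

def l1_distance(f, g):
    """Integral of |f - g| for two step functions that agree after their last breakpoint."""
    times = np.union1d(f[:, 0], g[:, 0])
    diffs = np.abs(evaluate(f, times) - evaluate(g, times))
    widths = np.diff(times)
    # The last piece extends to infinity; it is skipped, which is only
    # correct when the two functions agree there (diffs[-1] == 0).
    return float(np.sum(diffs[:-1] * widths))

# Two toy step functions (hypothetical data), both ending at 0.
f = np.array([[0.0, 2.0], [1.0, 1.0], [3.0, 0.0]])
g = np.array([[0.0, 4.0], [2.0, 0.0]])

print(l1_distance(f, g))  # 6.0
```

The distance is symmetric and zero for identical functions, which matches the symmetric matrix with a zero diagonal printed above.
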
[18]:
fig, ax = plt.subplots()
ax.matshow(D[0:500, 0:500])
[18]:
<matplotlib.image.AxesImage at 0x1fcf75add90>

t-SNE visualization#

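We will visualize our data, originally 784-dimensional images, using t-SNE [5]. Instead of raw pixel coordinates, we pass t-SNE the pairwise distances derived from persistent homology (computed above) as a precomputed metric.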

[5] Van der Maaten, L., & Hinton, G. (2008). Visualizing data using t-SNE. Journal of Machine Learning Research, 9(11).

[19]:
from sklearn.manifold import TSNE

X_embedded = TSNE(n_components=3, metric='precomputed', init='random').fit_transform(D)
[20]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

X_embedded_grouped = [ (i, X_embedded[y[:Xsmall.shape[0]] == str(i)]) for i in use_digits ]
colors = {use_digits[0]: 'red', use_digits[1]: 'blue'}

for i, grp in X_embedded_grouped:
    ax.scatter(grp[:,0], grp[:,1], grp[:,2], c=colors[i], label=str(i), s=0.6)

ax.legend()
[20]:
<matplotlib.legend.Legend at 0x1fcf79b2390>