Clustering using pretrained deep models: A solution

MNIST clustering using pretrained ResNet 18 and KMeans.

In this task, you are going to cluster MNIST dataset using a pretrained ResNet 18 model and k-means.

In [1]:
from typing import Tuple, Iterable, Union, List
from tqdm import tqdm
from multiprocessing import cpu_count, Pool
from time import time

import torchvision
from torchvision import datasets, transforms, models
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans
In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
Out[2]:
device(type='cuda')
In [3]:
SEED = 51
torch.random.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

Load the dataset

Load the MNIST dataset bellow. You can use either torchvision.datasets.MNIST or sklearn.datasets.fetch_openml() or any other way to load the dataset. Note that you won't need a validation set.

In [4]:
class RepeatInterleave(nn.Module):

    def __init__(self, repeats: Union[torch.Tensor, int], dim: int = None):
        super().__init__()
        self.__repeats = repeats
        self.__dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.repeat_interleave(self.__repeats, dim=self.__dim)
In [5]:
mnist_root = 'mnist'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    RepeatInterleave(3, dim=-3),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
testset = datasets.MNIST(mnist_root, train=False, download=True, transform=transform)

Get ResNet 18

Instantiate ResNet 18 model (pretrained on imagenet) from torchvision's model zoo.

In [6]:
net = models.resnet18(pretrained=True).to(device)
net
Out[6]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

Remove the decider layer

Replace the fc layer of the resnet with an Identity layer so that we use the hidden features of the last layer.

In [7]:
net.fc = nn.Identity().to(device)

Cluster the data

Use MiniBatchKMeans to cluster the MNIST dataset in 64 clusters.

In [8]:
B = 1024
trainloader = DataLoader(trainset, batch_size=B, shuffle=True)
testloader = DataLoader(testset, batch_size=B, shuffle=False)
n_clusters = 64
kmeans = MiniBatchKMeans(
    n_clusters=n_clusters,
    batch_size=B,
    random_state=SEED,
    verbose=1,
)
In [9]:
with torch.no_grad():
    net.eval()
    for x, _ in tqdm(trainloader, total=len(trainloader)):
        x: torch.Tensor = x.to(device)
        z: np.ndarray = net(x).cpu().numpy()
        kmeans.partial_fit(z)
100%|██████████| 59/59 [04:28<00:00,  4.55s/it]

Assign clusters to test set.

Predict cluster for test set samples.

In [10]:
ground_truth = []
predictions = []
with torch.no_grad():
    net.eval()
    for x, y in tqdm(testloader, total=len(testloader)):
        x: torch.Tensor = x.to(device)
        z: np.ndarray = net(x).cpu().numpy()
        ground_truth.append(y.numpy())
        predictions.append(kmeans.predict(z))
ground_truth = np.concatenate(ground_truth)
predictions = np.concatenate(predictions)
100%|██████████| 10/10 [00:45<00:00,  4.53s/it]

Draw clusters

Draw 10 random samples per each cluster from the test set.

In [11]:
def select_images(cluster: int) -> Tuple[List[np.ndarray], np.ndarray]:
    indices = np.argwhere(predictions == cluster)[:, 0]
    chosen = np.random.choice(indices, size=image_per_cluster, replace=False)
    labels = ground_truth[chosen]
    images = [testset[i][0][0].numpy() for i in chosen]
    return images, labels
In [12]:
image_per_cluster = 10
with Pool(cpu_count()) as pool:
    selected = pool.map(select_images, range(n_clusters))
In [13]:
fig, axes = plt.subplots(nrows=10, ncols=n_clusters, figsize=(n_clusters * 2, 20))
for ir, row in enumerate(axes):
    for cluster, ax in enumerate(row):
        image = selected[cluster][0][ir]
        image = (image - image.min()) / (image.max() - image.min())
        label = selected[cluster][1][ir]
        ax.set_title(f'c[{cluster}] - l[{label}]')
        ax.set_aspect('equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.imshow(image, cmap='gray')
plt.show()
In [13]: