Clustering using pretrained deep models: A solution
MNIST clustering using a pretrained ResNet-18 and k-means.
In this task, you are going to cluster the MNIST dataset using a pretrained ResNet-18 model and k-means.
from typing import Tuple, List, Union
from multiprocessing import cpu_count, Pool
from tqdm import tqdm
from torchvision import datasets, transforms, models
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
SEED = 51
torch.random.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
Load the MNIST dataset below. You can use either torchvision.datasets.MNIST or sklearn.datasets.fetch_openml(), or any other way to load the dataset. Note that you won't need a validation set.
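The notebook proceeds with torchvision, but for reference, here is a minimal sketch of the fetch_openml route, assuming the standard 'mnist_784' OpenML dataset (not used below):
from sklearn.datasets import fetch_openml

X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
X = X.reshape(-1, 28, 28).astype(np.float32) / 255.0  # flat 784-vectors -> 28x28 images in [0, 1]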
class RepeatInterleave(nn.Module):
    """Repeats tensor entries along a dimension (wraps torch.repeat_interleave).

    Used here to turn 1-channel MNIST images into the 3-channel inputs ResNet expects.
    """

    def __init__(self, repeats: Union[torch.Tensor, int], dim: int = None):
        super().__init__()
        self.__repeats = repeats
        self.__dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.repeat_interleave(self.__repeats, dim=self.__dim)
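A quick sanity check of the module: repeating the channel dimension of a 1x28x28 image yields a 3x28x28 tensor.
x = torch.rand(1, 28, 28)
RepeatInterleave(3, dim=-3)(x).shape  # torch.Size([3, 28, 28])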
mnist_root = 'mnist'
transform = transforms.Compose([
    transforms.ToTensor(),                       # [0, 1] float tensor of shape 1x28x28
    transforms.Resize((224, 224)),               # upscale to the ResNet input size
    RepeatInterleave(3, dim=-3),                 # 1 channel -> 3 channels
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std, broadcast over channels
])
trainset = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
testset = datasets.MNIST(mnist_root, train=False, download=True, transform=transform)
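A quick check that the pipeline produces tensors of the shape ResNet-18 expects:
x0, _ = trainset[0]
x0.shape, len(trainset), len(testset)  # (torch.Size([3, 224, 224]), 60000, 10000)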
Instantiate a ResNet-18 model (pretrained on ImageNet) from torchvision's model zoo.
net = models.resnet18(pretrained=True).to(device)  # torchvision >= 0.13 deprecates pretrained= in favor of weights=models.ResNet18_Weights.IMAGENET1K_V1
net
Replace the fc layer of the ResNet with an Identity layer, so that the network outputs the 512-dimensional features of the penultimate layer instead of class logits.
net.fc = nn.Identity().to(device)
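With fc replaced, the network maps each image to a 512-dimensional feature vector; a quick check:
with torch.no_grad():
    print(net(torch.zeros(1, 3, 224, 224, device=device)).shape)  # torch.Size([1, 512])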
Use MiniBatchKMeans to cluster the MNIST dataset into 64 clusters.
B = 1024
trainloader = DataLoader(trainset, batch_size=B, shuffle=True)
testloader = DataLoader(testset, batch_size=B, shuffle=False)
n_clusters = 64
kmeans = MiniBatchKMeans(
    n_clusters=n_clusters,
    batch_size=B,
    random_state=SEED,
    verbose=1,
)
with torch.no_grad():
    net.eval()
    for x, _ in tqdm(trainloader, total=len(trainloader)):
        x: torch.Tensor = x.to(device)
        z: np.ndarray = net(x).cpu().numpy()  # 512-d ResNet features per image
        kmeans.partial_fit(z)                 # update the centroids on this batch
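After the pass over the training set, the learned centroids live in the same 512-dimensional feature space:
kmeans.cluster_centers_.shape  # (64, 512)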
Predict the cluster of each test set sample.
ground_truth = []
predictions = []
with torch.no_grad():
    net.eval()
    for x, y in tqdm(testloader, total=len(testloader)):
        x: torch.Tensor = x.to(device)
        z: np.ndarray = net(x).cpu().numpy()
        ground_truth.append(y.numpy())
        predictions.append(kmeans.predict(z))
ground_truth = np.concatenate(ground_truth)
predictions = np.concatenate(predictions)
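Not required by the task, but normalized mutual information is a quick, permutation-invariant way to gauge how well the 64 clusters align with the 10 digit classes:
from sklearn.metrics import normalized_mutual_info_score

normalized_mutual_info_score(ground_truth, predictions)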
Draw 10 random samples from each cluster of the test set.
def select_images(cluster: int) -> Tuple[List[np.ndarray], np.ndarray]:
    indices = np.argwhere(predictions == cluster)[:, 0]
    # Sample with replacement when a cluster holds fewer than 10 test images,
    # so that small clusters do not raise a ValueError.
    replace = len(indices) < image_per_cluster
    chosen = np.random.choice(indices, size=image_per_cluster, replace=replace)
    labels = ground_truth[chosen]
    images = [testset[i][0][0].numpy() for i in chosen]
    return images, labels

image_per_cluster = 10
# Pool relies on fork-style workers inheriting the globals above; on platforms
# that spawn processes, call select_images in a plain loop instead.
with Pool(cpu_count()) as pool:
    selected = pool.map(select_images, range(n_clusters))
fig, axes = plt.subplots(nrows=image_per_cluster, ncols=n_clusters, figsize=(n_clusters * 2, 20))
for ir, row in enumerate(axes):
    for cluster, ax in enumerate(row):
        image = selected[cluster][0][ir]
        # Min-max rescale to [0, 1] for display (normalized tensors are not in [0, 1]).
        image = (image - image.min()) / (image.max() - image.min())
        label = selected[cluster][1][ir]
        ax.set_title(f'c[{cluster}] - l[{label}]')
        ax.set_aspect('equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.imshow(image, cmap='gray')
plt.show()