Classification with Fully Connected Neural Networks: A Solution
MNIST classification with a fully-connected model in PyTorch.
In this task, you will train a fully-connected model with PyTorch to classify the MNIST dataset.
from typing import Tuple, Iterable
from tqdm import tqdm
from multiprocessing import cpu_count
from time import time
import torchvision
from torchvision import datasets, transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
SEED = 51
torch.random.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
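Note that these seeds alone do not guarantee bitwise reproducibility once cuDNN and multi-process data loading are involved. If strict determinism matters, the cuDNN flags below are the usual addition (optional, and they can slow training).

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False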
Load the MNIST dataset below. You can use torchvision.datasets.MNIST,
sklearn.datasets.fetch_openml(),
or any other way to load the dataset.
mnist_root = 'mnist'
transform = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 and 0.3081 are the mean and std of the MNIST training pixels.
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = torchvision.datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
trainsize = int(len(trainset) * 0.9)
valsize = len(trainset) - trainsize
trainset, valset = torch.utils.data.random_split(trainset, [trainsize, valsize])
testset = torchvision.datasets.MNIST(mnist_root, train=False, download=True, transform=transform)
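For reference, here is a minimal sketch of the sklearn.datasets.fetch_openml() route mentioned above; the rest of this notebook sticks with the torchvision datasets.

from sklearn.datasets import fetch_openml

# 'mnist_784' serves all 70,000 digits as flat 784-dimensional rows,
# with pixels in [0, 255] and labels as strings.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
X = X.astype(np.float32) / 255.0  # scale pixels to [0, 1]
y = y.astype(np.int64)            # convert labels from strings to integers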
Write your fully-connected model below using torch.nn
modules. Feel free to add extra cells.
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Sequential( # B 784
nn.Linear(784, 1024), # B 1024
nn.ReLU(),
nn.BatchNorm1d(1024),
nn.Dropout(p=0.2),
nn.Linear(1024, 128), # B 128
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Dropout(p=0.2),
nn.Linear(128, 10), # B 10
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: B 1 28 28
x = x.flatten(1) # B 784
x = self.fc(x) # B 10
return x
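Before training, a quick shape check catches wiring mistakes early. This is only a sanity sketch; the dummy batch has two samples because BatchNorm1d needs more than one example in training mode.

_net = Net()
_x = torch.randn(2, 1, 28, 28)    # two fake 1x28x28 grayscale images
assert _net(_x).shape == (2, 10)  # one logit per digit class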
Write the training process below. Instantiate your model, create an optimizer such as Adam or SGD, and write your train/validation loop. Then train your model until it converges. Feel free to add extra cells.
class Trainer:
CHECKPOINT = 'mnist-fc.pth'
def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
trainloader: DataLoader, valloader: DataLoader, testloader: DataLoader,
device: torch.device) -> None:
self.__model = model
self.__optimizer = optimizer
self.__trainloader = trainloader
self.__valloader = valloader
self.__testloader = testloader
self.__device = device
self.__train_loss = []
self.__train_accuracy = []
self.__val_loss = []
self.__val_accuracy = []
self.__best_val_acc = 0
def __forward_model(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, int]:
x = x.to(self.__device)
y = y.to(self.__device)
s = self.__model(x)
loss = F.cross_entropy(s, y)
true = int((s.argmax(dim=-1) == y).sum())
return loss, true
def __train_epoch(self) -> Tuple[float, float]:
epoch_loss = 0
epoch_accuracy = 0
n = 0
self.__model.train()
with tqdm(enumerate(self.__trainloader), total=len(self.__trainloader)) as pbar:
for i, (x, y) in pbar:
loss, true = self.__forward_model(x, y)
epoch_loss += float(loss)
epoch_accuracy += true
n += len(x)
                self.__optimizer.zero_grad()  # clear any gradients from the previous step
                loss.backward()
                self.__optimizer.step()
pbar.set_description(f'train loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
epoch_loss /= len(self.__trainloader)
epoch_accuracy *= 100.0 / n
return epoch_loss, epoch_accuracy
def __eval_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
epoch_loss = 0
epoch_accuracy = 0
n = 0
self.__model.eval()
with torch.no_grad(), tqdm(enumerate(dataloader), total=len(dataloader)) as pbar:
for i, (x, y) in pbar:
loss, true = self.__forward_model(x, y)
epoch_loss += float(loss)
epoch_accuracy += true
n += len(x)
pbar.set_description(f'val loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
epoch_loss /= len(dataloader)
epoch_accuracy *= 100.0 / n
return epoch_loss, epoch_accuracy
def __val_epoch(self) -> Tuple[float, float]:
return self.__eval_epoch(self.__valloader)
def test(self) -> Tuple[float, float]:
return self.__eval_epoch(self.__testloader)
def save(self) -> None:
torch.save(self.__model.state_dict(), self.CHECKPOINT)
def load(self) -> None:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa).
        self.__model.load_state_dict(torch.load(self.CHECKPOINT, map_location=self.__device))
def train(self, max_epochs: int) -> None:
if len(self.__train_loss) >= max_epochs:
return
for e in range(len(self.__train_loss), max_epochs):
start_time = time()
train_loss, train_accuracy = self.__train_epoch()
val_loss, val_accuracy = self.__val_epoch()
end_time = time()
if val_accuracy > self.__best_val_acc:
self.__best_val_acc = val_accuracy
                print('Saving checkpoint.')
self.save()
print(f'Epoch {e+1} finished in {end_time - start_time:.2f}s. Train [loss: {train_loss:.3e}, acc: {train_accuracy:.2f}%] - Val [loss: {val_loss:.3e}, acc: {val_accuracy:.2f}%]')
self.__train_loss.append(train_loss)
self.__train_accuracy.append(train_accuracy)
self.__val_loss.append(val_loss)
self.__val_accuracy.append(val_accuracy)
def plot_curves(self) -> None:
fig = plt.figure(figsize=(15, 7))
# loss
ax = fig.add_subplot(121)
ax.set_title('Loss / Epoch')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')
plt.plot(self.__train_loss, label='Train', color='green', linewidth=3)
plt.plot(self.__val_loss, label='Validation', color='red', linewidth=3)
plt.legend()
# acc
ax = fig.add_subplot(122)
ax.set_title('Accuracy / Epoch')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')
plt.plot(self.__train_accuracy, label='Train', color='green', linewidth=3)
plt.plot(self.__val_accuracy, label='Validation', color='red', linewidth=3)
plt.legend()
def predict(self) -> Iterable[int]:
self.__model.eval()
with torch.no_grad(), tqdm(self.__testloader, total=len(self.__testloader)) as pbar:
for x, _ in pbar:
yield from self.__model(x.to(self.__device)).argmax(dim=-1)
def draw_misclassified(self, n: int) -> None:
p = list(self.predict())
        # Sample without replacement so the same misclassified image is not drawn twice.
        missed = [i for i, (_, y) in enumerate(self.__testloader.dataset) if p[i] != y]
        missed_idx = np.random.choice(missed, size=n, replace=False)
missed_samples = [self.__testloader.dataset[i] for i in missed_idx]
wrong_labels = [p[i] for i in missed_idx]
correct_labels = [s[1] for s in missed_samples]
images = [s[0].squeeze() for s in missed_samples]
texts = [f'{correct} -> {missed}' for missed, correct in zip(wrong_labels, correct_labels)]
columns = 10
rows = int(np.ceil(n / columns))
fig = plt.figure(figsize=(2 * columns, 2 * rows))
        for i in range(n):  # iterate only over the n available images; trailing grid cells stay empty
ax = fig.add_subplot(rows, columns, i + 1)
ax.set_title(texts[i])
ax.set_aspect('equal')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.imshow(images[i], cmap='gray')
cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
cax.get_xaxis().set_visible(False)
cax.get_yaxis().set_visible(False)
cax.set_frame_on(False)
def plot_cm(self) -> None:
y_true = [label for _, label in self.__testloader.dataset]
y_pred = [int(label) for label in self.predict()]
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
B = 512    # training batch size
EB = 1024  # evaluation batch size (no gradients, so larger batches fit)
trainloader = DataLoader(trainset, batch_size=B, shuffle=True, num_workers=cpu_count())
valloader = DataLoader(valset, batch_size=EB, shuffle=False, num_workers=cpu_count())
testloader = DataLoader(testset, batch_size=EB, shuffle=False, num_workers=cpu_count())
net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
trainer = Trainer(net, optimizer, trainloader, valloader, testloader, device)
net
trainer.train(40)
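Note that train() resumes from the number of epochs already recorded, so re-running this cell is a no-op; raising the limit, e.g. trainer.train(50), would train for ten more epochs.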
Draw two diagrams for training and validation: one showing the loss per epoch, and another showing the accuracy per epoch.
trainer.plot_curves()
plt.show()
Evaluate the best epoch's model (selected by validation accuracy) on the test set, and report the accuracy. Is your model overfitted?
trainer.load()
_, accuracy = trainer.test()
accuracy
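With the best checkpoint restored, the test accuracy should land close to the best validation accuracy. To judge overfitting, compare the curves above: if the training accuracy keeps climbing while the validation accuracy plateaus or drops (and the losses diverge), the model is overfitting; a small, stable gap suggests the dropout and batch normalization are regularizing adequately.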
Draw 20 misclassified images from the test set with their expected and predicted labels.
trainer.draw_misclassified(20)
plt.show()
Plot the confusion matrix for the test set.
trainer.plot_cm()
plt.show()
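If per-class error rates are of interest, a row-normalized matrix can be clearer. A sketch using the public predict() method (normalize='true' requires scikit-learn >= 0.22):

y_true = [label for _, label in testset]
y_pred = [int(p) for p in trainer.predict()]
sns.heatmap(confusion_matrix(y_true, y_pred, normalize='true'), annot=True, fmt='.2f')
plt.show()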