Classification with Fully Connected Neural Networks: A solution

MNIST classification with a fully connected model in PyTorch.

In this task, you should train a fully connected model with PyTorch to classify the MNIST dataset.

In [1]:
from typing import Tuple, Iterable
from tqdm import tqdm
from multiprocessing import cpu_count
from time import time

import torchvision
from torchvision import datasets, transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
Out[2]:
device(type='cuda')
In [3]:
SEED = 51
torch.random.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

Load the dataset

Load the MNIST dataset below. You can use torchvision.datasets.MNIST, sklearn.datasets.fetch_openml(), or any other method to load the dataset.

In [4]:
mnist_root = 'mnist'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = torchvision.datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
trainsize = int(len(trainset) * 0.9)
valsize = len(trainset) - trainsize
trainset, valset = torch.utils.data.random_split(trainset, [trainsize, valsize])
testset = torchvision.datasets.MNIST(mnist_root, train=False, download=True, transform=transform)
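As a sanity check, the 90/10 split arithmetic above can be reproduced without loading the dataset at all (a minimal sketch; 60,000 is the known MNIST training-set size):

```python
# MNIST ships 60,000 training images; 90% go to training, the rest to validation.
total = 60_000
trainsize = int(total * 0.9)   # 54,000 training samples
valsize = total - trainsize    # 6,000 validation samples
print(trainsize, valsize)
```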

Design your model

Write your fully-connected model below using torch.nn modules. Feel free to add extra cells.

In [5]:
class Net(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(    # B 784
            nn.Linear(784, 1024),   # B 1024
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.2),
            nn.Linear(1024, 128),   # B 128
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=0.2),
            nn.Linear(128, 10),     # B 10
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x:                          B 1   28  28
        x = x.flatten(1)            # B 784
        x = self.fc(x)              # B 10
        return x
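For reference, the model's learnable parameter count can be tallied by hand (a rough sketch; `linear_params` is a helper defined here, not part of the notebook):

```python
def linear_params(n_in, n_out):
    # A Linear layer stores an (n_out x n_in) weight matrix plus an n_out bias.
    return n_in * n_out + n_out

fc1 = linear_params(784, 1024)   # 803,840
fc2 = linear_params(1024, 128)   # 131,200
fc3 = linear_params(128, 10)     # 1,290
bn = 2 * 1024 + 2 * 128          # BatchNorm1d: one gamma and one beta per feature
total = fc1 + fc2 + fc3 + bn     # 938,634 learnable parameters
print(total)
```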

Train your model

Write the training process below. Instantiate your model, create an optimizer such as Adam or SGD, and write your train/validation loop. Then train your model until it converges. Feel free to add extra cells.

In [6]:
class Trainer:
    CHECKPOINT = 'mnist-fc.pth'

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
                 trainloader: DataLoader, valloader: DataLoader, testloader: DataLoader,
                 device: torch.device) -> None:
        self.__model = model
        self.__optimizer = optimizer
        self.__trainloader = trainloader
        self.__valloader = valloader
        self.__testloader = testloader
        self.__device = device
        self.__train_loss = []
        self.__train_accuracy = []
        self.__val_loss = []
        self.__val_accuracy = []
        self.__best_val_acc = 0

    def __forward_model(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, int]:
        x = x.to(self.__device)
        y = y.to(self.__device)
        s = self.__model(x)
        loss = F.cross_entropy(s, y)
        true = int((s.argmax(dim=-1) == y).sum())
        return loss, true

    def __train_epoch(self) -> Tuple[float, float]:
        epoch_loss = 0
        epoch_accuracy = 0
        n = 0

        self.__model.train()
        with tqdm(enumerate(self.__trainloader), total=len(self.__trainloader)) as pbar:
            for i, (x, y) in pbar:
                loss, true = self.__forward_model(x, y)
                epoch_loss += float(loss)
                epoch_accuracy += true
                n += len(x)

                loss.backward()
                self.__optimizer.step()
                self.__optimizer.zero_grad()

                pbar.set_description(f'train loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')

        epoch_loss /= len(self.__trainloader)
        epoch_accuracy *= 100.0 / n
        return epoch_loss, epoch_accuracy

    def __eval_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        epoch_loss = 0
        epoch_accuracy = 0
        n = 0

        self.__model.eval()
        with torch.no_grad(), tqdm(enumerate(dataloader), total=len(dataloader)) as pbar:
            for i, (x, y) in pbar:
                loss, true = self.__forward_model(x, y)
                epoch_loss += float(loss)
                epoch_accuracy += true
                n += len(x)
                pbar.set_description(f'val loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')

        epoch_loss /= len(dataloader)
        epoch_accuracy *= 100.0 / n
        return epoch_loss, epoch_accuracy

    def __val_epoch(self) -> Tuple[float, float]:
        return self.__eval_epoch(self.__valloader)

    def test(self) -> Tuple[float, float]:
        return self.__eval_epoch(self.__testloader)

    def save(self) -> None:
        torch.save(self.__model.state_dict(), self.CHECKPOINT)

    def load(self) -> None:
        self.__model.load_state_dict(torch.load(self.CHECKPOINT))

    def train(self, max_epochs: int) -> None:
        if len(self.__train_loss) >= max_epochs:
            return

        for e in range(len(self.__train_loss), max_epochs):
            start_time = time()
            train_loss, train_accuracy = self.__train_epoch()
            val_loss, val_accuracy = self.__val_epoch()
            end_time = time()

            if val_accuracy > self.__best_val_acc:
                self.__best_val_acc = val_accuracy
                print('Saving checkpoint.')
                self.save()
            
            print(f'Epoch {e+1} finished in {end_time - start_time:.2f}s. Train [loss: {train_loss:.3e}, acc: {train_accuracy:.2f}%] - Val [loss: {val_loss:.3e}, acc: {val_accuracy:.2f}%]')

            self.__train_loss.append(train_loss)
            self.__train_accuracy.append(train_accuracy)
            self.__val_loss.append(val_loss)
            self.__val_accuracy.append(val_accuracy)

    def plot_curves(self) -> None:
        fig = plt.figure(figsize=(15, 7))

        # loss
        ax = fig.add_subplot(121)
        ax.set_title('Loss / Epoch')
        ax.set_ylabel('Loss')
        ax.set_xlabel('Epoch')
        ax.set_aspect('auto')

        plt.plot(self.__train_loss, label='Train', color='green', linewidth=3)
        plt.plot(self.__val_loss, label='Validation', color='red', linewidth=3)

        plt.legend()

        # acc
        ax = fig.add_subplot(122)
        ax.set_title('Accuracy / Epoch')
        ax.set_ylabel('Accuracy')
        ax.set_xlabel('Epoch')
        ax.set_aspect('auto')

        plt.plot(self.__train_accuracy, label='Train', color='green', linewidth=3)
        plt.plot(self.__val_accuracy, label='Validation', color='red', linewidth=3)

        plt.legend()

    def predict(self) -> Iterable[int]:
        self.__model.eval()
        with torch.no_grad(), tqdm(self.__testloader, total=len(self.__testloader)) as pbar:
            for x, _ in pbar:
                # .tolist() converts the 0-dim tensors to plain ints, so the
                # predicted labels format cleanly downstream.
                yield from self.__model(x.to(self.__device)).argmax(dim=-1).tolist()

    def draw_misclassified(self, n: int) -> None:
        p = list(self.predict())
        missed = [i for i, (_, y) in enumerate(self.__testloader.dataset) if p[i] != y]
        missed_idx = np.random.choice(missed, size=n, replace=False)  # no repeated picks
        missed_samples = [self.__testloader.dataset[i] for i in missed_idx]
        wrong_labels = [p[i] for i in missed_idx]
        correct_labels = [s[1] for s in missed_samples]
        images = [s[0].squeeze() for s in missed_samples]

        texts = [f'{correct} -> {missed}' for missed, correct in zip(wrong_labels, correct_labels)]

        columns = 10
        rows = int(np.ceil(n / columns))
        fig = plt.figure(figsize=(2 * columns, 2 * rows))
        for i in range(columns * rows):
            ax = fig.add_subplot(rows, columns, i + 1)
            ax.set_title(texts[i])
            ax.set_aspect('equal')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.imshow(images[i], cmap='gray')

        cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
        cax.get_xaxis().set_visible(False)
        cax.get_yaxis().set_visible(False)
        cax.set_frame_on(False)

    def plot_cm(self) -> None:
        y_true = [label for _, label in self.__testloader.dataset]
        y_pred = [int(label) for label in self.predict()]
        cm = confusion_matrix(y_true, y_pred)
        sns.heatmap(cm, annot=True, fmt='d')
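The loss in `__forward_model` is `F.cross_entropy`, which fuses log-softmax and negative log-likelihood. For a single sample it reduces to the following plain-Python sketch (illustrative only; the logit values are made up):

```python
import math

def cross_entropy(logits, target):
    # Numerically stable log-sum-exp: subtract the max logit first.
    m = max(logits)
    logsumexp = m + math.log(sum(math.exp(s - m) for s in logits))
    # -log softmax[target] == logsumexp - logit[target]
    return logsumexp - logits[target]

loss = cross_entropy([2.0, 1.0, 0.1], 0)   # ≈ 0.417
print(loss)
```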
In [7]:
B = 512
EB = 1024
trainloader = DataLoader(trainset, batch_size=B, shuffle=True, num_workers=cpu_count())
valloader = DataLoader(valset, batch_size=EB, shuffle=False, num_workers=cpu_count())
testloader = DataLoader(testset, batch_size=EB, shuffle=False, num_workers=cpu_count())
net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
trainer = Trainer(net, optimizer, trainloader, valloader, testloader, device)
net
Out[7]:
Net(
  (fc): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=1024, out_features=128, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=128, out_features=10, bias=True)
  )
)
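The batch counts in the progress bars below follow directly from these loader sizes (a quick sketch: with drop_last left at its default of False, a DataLoader yields ceil(len(dataset) / batch_size) batches):

```python
import math

train_batches = math.ceil(54_000 / 512)   # 106 batches per train epoch
val_batches = math.ceil(6_000 / 1024)     # 6 batches per validation pass
print(train_batches, val_batches)
```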
In [8]:
trainer.train(40)
train loss: 6.740e-01 - accuracy: 83.36%: 100%|██████████| 106/106 [00:18<00:00,  5.82it/s]
val loss: 3.225e-01 - accuracy: 93.38%: 100%|██████████| 6/6 [00:01<00:00,  3.01it/s]
Saving checkpoint.
Epoch 1 finished in 20.49s. Train [loss: 6.740e-01, acc: 83.36%] - Val [loss: 3.225e-01, acc: 93.38%]
train loss: 2.902e-01 - accuracy: 93.72%: 100%|██████████| 106/106 [00:10<00:00,  9.97it/s]
val loss: 2.109e-01 - accuracy: 95.23%: 100%|██████████| 6/6 [00:01<00:00,  5.63it/s]
Saving checkpoint.
Epoch 2 finished in 11.94s. Train [loss: 2.902e-01, acc: 93.72%] - Val [loss: 2.109e-01, acc: 95.23%]
train loss: 2.015e-01 - accuracy: 95.55%: 100%|██████████| 106/106 [00:09<00:00, 11.01it/s]
val loss: 1.564e-01 - accuracy: 96.37%: 100%|██████████| 6/6 [00:01<00:00,  5.79it/s]
Saving checkpoint.
Epoch 3 finished in 10.84s. Train [loss: 2.015e-01, acc: 95.55%] - Val [loss: 1.564e-01, acc: 96.37%]
train loss: 1.514e-01 - accuracy: 96.64%: 100%|██████████| 106/106 [00:09<00:00, 11.00it/s]
val loss: 1.250e-01 - accuracy: 96.87%: 100%|██████████| 6/6 [00:01<00:00,  5.80it/s]
Saving checkpoint.
Epoch 4 finished in 10.85s. Train [loss: 1.514e-01, acc: 96.64%] - Val [loss: 1.250e-01, acc: 96.87%]
train loss: 1.181e-01 - accuracy: 97.31%: 100%|██████████| 106/106 [00:09<00:00, 11.09it/s]
val loss: 1.050e-01 - accuracy: 97.50%: 100%|██████████| 6/6 [00:01<00:00,  5.88it/s]
Saving checkpoint.
Epoch 5 finished in 10.76s. Train [loss: 1.181e-01, acc: 97.31%] - Val [loss: 1.050e-01, acc: 97.50%]
train loss: 9.496e-02 - accuracy: 97.84%: 100%|██████████| 106/106 [00:09<00:00, 11.02it/s]
val loss: 9.136e-02 - accuracy: 97.72%: 100%|██████████| 6/6 [00:01<00:00,  5.77it/s]
Saving checkpoint.
Epoch 6 finished in 10.83s. Train [loss: 9.496e-02, acc: 97.84%] - Val [loss: 9.136e-02, acc: 97.72%]
train loss: 7.644e-02 - accuracy: 98.36%: 100%|██████████| 106/106 [00:09<00:00, 10.91it/s]
val loss: 7.992e-02 - accuracy: 98.00%: 100%|██████████| 6/6 [00:01<00:00,  5.64it/s]
Saving checkpoint.
Epoch 7 finished in 10.96s. Train [loss: 7.644e-02, acc: 98.36%] - Val [loss: 7.992e-02, acc: 98.00%]
train loss: 6.261e-02 - accuracy: 98.64%: 100%|██████████| 106/106 [00:09<00:00, 10.98it/s]
val loss: 7.301e-02 - accuracy: 98.03%: 100%|██████████| 6/6 [00:01<00:00,  5.82it/s]
Saving checkpoint.
Epoch 8 finished in 10.86s. Train [loss: 6.261e-02, acc: 98.64%] - Val [loss: 7.301e-02, acc: 98.03%]
train loss: 5.233e-02 - accuracy: 98.90%: 100%|██████████| 106/106 [00:09<00:00, 11.12it/s]
val loss: 6.871e-02 - accuracy: 98.05%: 100%|██████████| 6/6 [00:01<00:00,  5.80it/s]
Saving checkpoint.
Epoch 9 finished in 10.74s. Train [loss: 5.233e-02, acc: 98.90%] - Val [loss: 6.871e-02, acc: 98.05%]
train loss: 4.411e-02 - accuracy: 99.09%: 100%|██████████| 106/106 [00:09<00:00, 10.96it/s]
val loss: 6.368e-02 - accuracy: 98.20%: 100%|██████████| 6/6 [00:01<00:00,  5.73it/s]
Saving checkpoint.
Epoch 10 finished in 10.89s. Train [loss: 4.411e-02, acc: 99.09%] - Val [loss: 6.368e-02, acc: 98.20%]
train loss: 3.635e-02 - accuracy: 99.33%: 100%|██████████| 106/106 [00:10<00:00,  9.64it/s]
val loss: 5.869e-02 - accuracy: 98.33%: 100%|██████████| 6/6 [00:01<00:00,  5.57it/s]
Saving checkpoint.
Epoch 11 finished in 12.24s. Train [loss: 3.635e-02, acc: 99.33%] - Val [loss: 5.869e-02, acc: 98.33%]
train loss: 3.092e-02 - accuracy: 99.46%: 100%|██████████| 106/106 [00:10<00:00, 10.33it/s]
val loss: 5.783e-02 - accuracy: 98.30%: 100%|██████████| 6/6 [00:01<00:00,  5.61it/s]
Epoch 12 finished in 11.51s. Train [loss: 3.092e-02, acc: 99.46%] - Val [loss: 5.783e-02, acc: 98.30%]
train loss: 2.618e-02 - accuracy: 99.57%: 100%|██████████| 106/106 [00:09<00:00, 10.83it/s]
val loss: 5.512e-02 - accuracy: 98.37%: 100%|██████████| 6/6 [00:01<00:00,  5.63it/s]
Saving checkpoint.
Epoch 13 finished in 11.05s. Train [loss: 2.618e-02, acc: 99.57%] - Val [loss: 5.512e-02, acc: 98.37%]
train loss: 2.288e-02 - accuracy: 99.63%: 100%|██████████| 106/106 [00:09<00:00, 11.00it/s]
val loss: 5.588e-02 - accuracy: 98.40%: 100%|██████████| 6/6 [00:01<00:00,  5.81it/s]
Saving checkpoint.
Epoch 14 finished in 10.84s. Train [loss: 2.288e-02, acc: 99.63%] - Val [loss: 5.588e-02, acc: 98.40%]
train loss: 1.893e-02 - accuracy: 99.77%: 100%|██████████| 106/106 [00:09<00:00, 11.08it/s]
val loss: 5.341e-02 - accuracy: 98.42%: 100%|██████████| 6/6 [00:01<00:00,  5.66it/s]
Saving checkpoint.
Epoch 15 finished in 10.82s. Train [loss: 1.893e-02, acc: 99.77%] - Val [loss: 5.341e-02, acc: 98.42%]
train loss: 1.733e-02 - accuracy: 99.76%: 100%|██████████| 106/106 [00:09<00:00, 11.03it/s]
val loss: 5.236e-02 - accuracy: 98.42%: 100%|██████████| 6/6 [00:01<00:00,  5.60it/s]
Epoch 16 finished in 10.85s. Train [loss: 1.733e-02, acc: 99.76%] - Val [loss: 5.236e-02, acc: 98.42%]
train loss: 1.494e-02 - accuracy: 99.82%: 100%|██████████| 106/106 [00:09<00:00, 11.09it/s]
val loss: 5.278e-02 - accuracy: 98.37%: 100%|██████████| 6/6 [00:01<00:00,  5.70it/s]
Epoch 17 finished in 10.80s. Train [loss: 1.494e-02, acc: 99.82%] - Val [loss: 5.278e-02, acc: 98.37%]
train loss: 1.347e-02 - accuracy: 99.83%: 100%|██████████| 106/106 [00:09<00:00, 11.16it/s]
val loss: 5.316e-02 - accuracy: 98.35%: 100%|██████████| 6/6 [00:01<00:00,  5.77it/s]
Epoch 18 finished in 10.72s. Train [loss: 1.347e-02, acc: 99.83%] - Val [loss: 5.316e-02, acc: 98.35%]
train loss: 1.216e-02 - accuracy: 99.84%: 100%|██████████| 106/106 [00:09<00:00, 11.27it/s]
val loss: 5.003e-02 - accuracy: 98.40%: 100%|██████████| 6/6 [00:01<00:00,  5.68it/s]
Epoch 19 finished in 10.63s. Train [loss: 1.216e-02, acc: 99.84%] - Val [loss: 5.003e-02, acc: 98.40%]
train loss: 1.050e-02 - accuracy: 99.89%: 100%|██████████| 106/106 [00:09<00:00, 11.08it/s]
val loss: 5.135e-02 - accuracy: 98.40%: 100%|██████████| 6/6 [00:01<00:00,  5.71it/s]
Epoch 20 finished in 10.80s. Train [loss: 1.050e-02, acc: 99.89%] - Val [loss: 5.135e-02, acc: 98.40%]
train loss: 9.082e-03 - accuracy: 99.90%: 100%|██████████| 106/106 [00:09<00:00, 10.94it/s]
val loss: 4.998e-02 - accuracy: 98.45%: 100%|██████████| 6/6 [00:01<00:00,  5.69it/s]
Saving checkpoint.
Epoch 21 finished in 10.91s. Train [loss: 9.082e-03, acc: 99.90%] - Val [loss: 4.998e-02, acc: 98.45%]
train loss: 8.208e-03 - accuracy: 99.91%: 100%|██████████| 106/106 [00:09<00:00, 11.05it/s]
val loss: 5.170e-02 - accuracy: 98.37%: 100%|██████████| 6/6 [00:01<00:00,  5.64it/s]
Epoch 22 finished in 10.84s. Train [loss: 8.208e-03, acc: 99.91%] - Val [loss: 5.170e-02, acc: 98.37%]
train loss: 8.226e-03 - accuracy: 99.89%: 100%|██████████| 106/106 [00:09<00:00, 10.95it/s]
val loss: 5.246e-02 - accuracy: 98.37%: 100%|██████████| 6/6 [00:01<00:00,  5.57it/s]
Epoch 23 finished in 10.93s. Train [loss: 8.226e-03, acc: 99.89%] - Val [loss: 5.246e-02, acc: 98.37%]
train loss: 6.986e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:09<00:00, 10.87it/s]
val loss: 5.178e-02 - accuracy: 98.50%: 100%|██████████| 6/6 [00:01<00:00,  5.71it/s]
Saving checkpoint.
Epoch 24 finished in 11.00s. Train [loss: 6.986e-03, acc: 99.95%] - Val [loss: 5.178e-02, acc: 98.50%]
train loss: 6.222e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:09<00:00, 10.80it/s]
val loss: 5.205e-02 - accuracy: 98.47%: 100%|██████████| 6/6 [00:01<00:00,  5.54it/s]
Epoch 25 finished in 11.07s. Train [loss: 6.222e-03, acc: 99.95%] - Val [loss: 5.205e-02, acc: 98.47%]
train loss: 5.960e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:09<00:00, 10.98it/s]
val loss: 5.184e-02 - accuracy: 98.48%: 100%|██████████| 6/6 [00:01<00:00,  5.52it/s]
Epoch 26 finished in 10.92s. Train [loss: 5.960e-03, acc: 99.95%] - Val [loss: 5.184e-02, acc: 98.48%]
train loss: 5.355e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:09<00:00, 10.95it/s]
val loss: 5.145e-02 - accuracy: 98.47%: 100%|██████████| 6/6 [00:01<00:00,  5.71it/s]
Epoch 27 finished in 10.91s. Train [loss: 5.355e-03, acc: 99.95%] - Val [loss: 5.145e-02, acc: 98.47%]
train loss: 4.620e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 10.81it/s]
val loss: 5.180e-02 - accuracy: 98.55%: 100%|██████████| 6/6 [00:01<00:00,  5.55it/s]
Saving checkpoint.
Epoch 28 finished in 11.06s. Train [loss: 4.620e-03, acc: 99.97%] - Val [loss: 5.180e-02, acc: 98.55%]
train loss: 4.518e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:10<00:00, 10.58it/s]
val loss: 5.055e-02 - accuracy: 98.52%: 100%|██████████| 6/6 [00:01<00:00,  5.43it/s]
Epoch 29 finished in 11.31s. Train [loss: 4.518e-03, acc: 99.95%] - Val [loss: 5.055e-02, acc: 98.52%]
train loss: 3.974e-03 - accuracy: 99.98%: 100%|██████████| 106/106 [00:09<00:00, 10.65it/s]
val loss: 5.067e-02 - accuracy: 98.55%: 100%|██████████| 6/6 [00:01<00:00,  5.52it/s]
Epoch 30 finished in 11.24s. Train [loss: 3.974e-03, acc: 99.98%] - Val [loss: 5.067e-02, acc: 98.55%]
train loss: 4.411e-03 - accuracy: 99.95%: 100%|██████████| 106/106 [00:09<00:00, 10.76it/s]
val loss: 5.098e-02 - accuracy: 98.53%: 100%|██████████| 6/6 [00:01<00:00,  5.59it/s]
Epoch 31 finished in 11.10s. Train [loss: 4.411e-03, acc: 99.95%] - Val [loss: 5.098e-02, acc: 98.53%]
train loss: 4.313e-03 - accuracy: 99.94%: 100%|██████████| 106/106 [00:10<00:00, 10.58it/s]
val loss: 5.203e-02 - accuracy: 98.53%: 100%|██████████| 6/6 [00:01<00:00,  5.43it/s]
Epoch 32 finished in 11.30s. Train [loss: 4.313e-03, acc: 99.94%] - Val [loss: 5.203e-02, acc: 98.53%]
train loss: 3.748e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 10.78it/s]
val loss: 5.314e-02 - accuracy: 98.42%: 100%|██████████| 6/6 [00:01<00:00,  5.72it/s]
Epoch 33 finished in 11.05s. Train [loss: 3.748e-03, acc: 99.97%] - Val [loss: 5.314e-02, acc: 98.42%]
train loss: 3.337e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 10.99it/s]
val loss: 5.140e-02 - accuracy: 98.55%: 100%|██████████| 6/6 [00:01<00:00,  5.72it/s]
Epoch 34 finished in 10.85s. Train [loss: 3.337e-03, acc: 99.97%] - Val [loss: 5.140e-02, acc: 98.55%]
train loss: 3.278e-03 - accuracy: 99.96%: 100%|██████████| 106/106 [00:09<00:00, 11.01it/s]
val loss: 5.121e-02 - accuracy: 98.57%: 100%|██████████| 6/6 [00:01<00:00,  5.58it/s]
Saving checkpoint.
Epoch 35 finished in 10.87s. Train [loss: 3.278e-03, acc: 99.96%] - Val [loss: 5.121e-02, acc: 98.57%]
train loss: 3.186e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 11.08it/s]
val loss: 5.251e-02 - accuracy: 98.52%: 100%|██████████| 6/6 [00:01<00:00,  5.64it/s]
Epoch 36 finished in 10.80s. Train [loss: 3.186e-03, acc: 99.97%] - Val [loss: 5.251e-02, acc: 98.52%]
train loss: 2.882e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 10.91it/s]
val loss: 5.167e-02 - accuracy: 98.63%: 100%|██████████| 6/6 [00:01<00:00,  5.50it/s]
Saving checkpoint.
Epoch 37 finished in 10.98s. Train [loss: 2.882e-03, acc: 99.97%] - Val [loss: 5.167e-02, acc: 98.63%]
train loss: 2.929e-03 - accuracy: 99.96%: 100%|██████████| 106/106 [00:09<00:00, 11.08it/s]
val loss: 5.785e-02 - accuracy: 98.40%: 100%|██████████| 6/6 [00:01<00:00,  5.61it/s]
Epoch 38 finished in 10.80s. Train [loss: 2.929e-03, acc: 99.96%] - Val [loss: 5.785e-02, acc: 98.40%]
train loss: 2.931e-03 - accuracy: 99.97%: 100%|██████████| 106/106 [00:09<00:00, 11.02it/s]
val loss: 5.954e-02 - accuracy: 98.37%: 100%|██████████| 6/6 [00:01<00:00,  5.89it/s]
Epoch 39 finished in 10.80s. Train [loss: 2.931e-03, acc: 99.97%] - Val [loss: 5.954e-02, acc: 98.37%]
train loss: 3.086e-03 - accuracy: 99.96%: 100%|██████████| 106/106 [00:09<00:00, 10.96it/s]
val loss: 5.616e-02 - accuracy: 98.47%: 100%|██████████| 6/6 [00:01<00:00,  5.59it/s]
Epoch 40 finished in 10.90s. Train [loss: 3.086e-03, acc: 99.96%] - Val [loss: 5.616e-02, acc: 98.47%]

Draw the training curves

Draw two diagrams for train and validation: one showing the loss per epoch, and another showing the accuracy per epoch.

In [9]:
trainer.plot_curves()
plt.show()

Evaluate your model

Evaluate the best epoch's model (according to the validation accuracy) on the test set, and report the accuracy. Is your model overfitted?

In [10]:
trainer.load()
_, accuracy = trainer.test()
accuracy
val loss: 5.879e-02 - accuracy: 98.32%: 100%|██████████| 10/10 [00:01<00:00,  6.08it/s]
Out[10]:
98.32000000000001
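On the overfitting question: comparing the final logged train accuracy with the test accuracy gives a rough generalization gap (a back-of-the-envelope sketch using the numbers from the logs above):

```python
train_acc = 99.96   # final train accuracy from the training log
test_acc = 98.32    # test accuracy of the best checkpoint
gap = train_acc - test_acc   # ~1.6 points: mild overfitting, tempered by
                             # dropout, batch norm, and best-checkpoint saving
print(gap)
```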

Draw misclassified images

Draw 20 misclassified images from the test set with their expected and predicted labels.

In [11]:
trainer.draw_misclassified(20)
plt.show()
100%|██████████| 10/10 [00:01<00:00,  5.82it/s]

Plot the confusion matrix

Plot the confusion matrix for the test set.

In [12]:
trainer.plot_cm()
plt.show()
100%|██████████| 10/10 [00:02<00:00,  4.70it/s]