Classification with Convolutional Neural Networks: A solution
CIFAR10 classification with a convolutional model with pytorch.
In this task, you should train a CNN model with PyTorch to classify the CIFAR10 dataset.
from typing import Tuple, Iterable
from tqdm import tqdm
from time import time
import torchvision
from torchvision import datasets, transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

# Seed every random number generator used below so runs are repeatable.
SEED = 51
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
Load the CIFAR10 dataset below. You can use either torchvision.datasets.CIFAR10
or sklearn.datasets.fetch_openml()
or any other way to load the dataset.
# Directory the CIFAR10 files are downloaded to / read from.
cifar_root = 'cifir10'

# Convert images to tensors, then shift/scale every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Official training split, divided 90% / 10% into training and validation subsets.
trainset = datasets.CIFAR10(root=cifar_root, train=True, download=True, transform=transform)
trainsize = int(len(trainset) * 0.9)
valsize = len(trainset) - trainsize
trainset, valset = torch.utils.data.random_split(dataset=trainset, lengths=[trainsize, valsize])

# Official test split, kept aside for the final evaluation.
testset = datasets.CIFAR10(root=cifar_root, train=False, download=True, transform=transform)
Write your CNN model below using torch.nn
modules. Feel free to add extra cells.
class Net(nn.Module):
    """Small convolutional classifier producing 10 class scores per image.

    The convolutional stage keeps 16 channels and halves the spatial size once
    (3x32x32 -> 16x16x16); the classifier stage flattens that to 4096 features
    and maps them to 10 logits with a single linear layer.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: B x 3 x 32 x 32 -> B x 16 x 16 x 16
        feature_layers = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.2),
            nn.BatchNorm2d(num_features=16),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Classifier head: B x 16 x 16 x 16 -> B x 4096 -> B x 10
        classifier_layers = [
            nn.Flatten(),
            nn.Linear(in_features=4096, out_features=10),
        ]
        self.fc = nn.Sequential(*classifier_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits computed from the image batch ``x``."""
        features = self.conv(x)
        return self.fc(features)
Write the training process below. Instantiate your model, create an optimizer such as Adam or SGD, and write your train/validation loop. Then train your model until it converges. Feel free to add extra cells.
class Trainer:
    """Train, checkpoint, evaluate and visualise a classification model.

    The trainer owns the model, its optimizer and the three data loaders.
    Per-epoch loss and accuracy are recorded for both the training and the
    validation loader, and the model weights are written to ``CHECKPOINT``
    every time the validation accuracy improves.
    """

    # File the best model weights are saved to / restored from.
    CHECKPOINT = 'cifar-cnn.pth'
    # Human-readable class names, indexed by label id.
    CLASSES = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
                 trainloader: DataLoader, valloader: DataLoader, testloader: DataLoader,
                 device: torch.device) -> None:
        """Store the collaborators and start with an empty training history."""
        self.__model = model
        self.__optimizer = optimizer
        self.__trainloader = trainloader
        self.__valloader = valloader
        self.__testloader = testloader
        self.__device = device
        # One entry per finished epoch.
        self.__train_loss = []
        self.__train_accuracy = []
        self.__val_loss = []
        self.__val_accuracy = []
        self.__best_val_acc = 0

    def __forward_model(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Run the model on one batch.

        Returns:
            The cross-entropy loss tensor and the number of samples whose
            arg-max prediction equals the label.
        """
        x = x.to(self.__device)
        y = y.to(self.__device)
        s = self.__model(x)
        loss = F.cross_entropy(s, y)
        true = int((s.argmax(dim=-1) == y).sum())
        return loss, true

    def __train_epoch(self) -> Tuple[float, float]:
        """Run one optimisation pass over the training loader.

        Returns:
            The mean of the per-batch losses and the accuracy in percent.
        """
        epoch_loss = 0
        epoch_accuracy = 0
        n = 0
        self.__model.train()
        with tqdm(enumerate(self.__trainloader), total=len(self.__trainloader)) as pbar:
            for i, (x, y) in pbar:
                loss, true = self.__forward_model(x, y)
                epoch_loss += float(loss)
                epoch_accuracy += true
                n += len(x)
                loss.backward()
                self.__optimizer.step()
                # Clear the gradients so they do not accumulate into the next batch.
                self.__optimizer.zero_grad()
                pbar.set_description(f'train loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
        epoch_loss /= len(self.__trainloader)
        epoch_accuracy *= 100.0 / n
        return epoch_loss, epoch_accuracy

    def __eval_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Evaluate the model on ``dataloader`` without computing gradients.

        Returns:
            The mean of the per-batch losses and the accuracy in percent.
        """
        epoch_loss = 0
        epoch_accuracy = 0
        n = 0
        self.__model.eval()
        with torch.no_grad(), tqdm(enumerate(dataloader), total=len(dataloader)) as pbar:
            for i, (x, y) in pbar:
                loss, true = self.__forward_model(x, y)
                epoch_loss += float(loss)
                epoch_accuracy += true
                n += len(x)
                pbar.set_description(f'val loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
        epoch_loss /= len(dataloader)
        epoch_accuracy *= 100.0 / n
        return epoch_loss, epoch_accuracy

    def __val_epoch(self) -> Tuple[float, float]:
        """Evaluate on the validation loader."""
        return self.__eval_epoch(self.__valloader)

    def test(self) -> Tuple[float, float]:
        """Evaluate on the test loader and return ``(loss, accuracy in percent)``."""
        return self.__eval_epoch(self.__testloader)

    def save(self) -> None:
        """Write the model weights to ``CHECKPOINT``."""
        torch.save(self.__model.state_dict(), self.CHECKPOINT)

    def load(self) -> None:
        """Restore the model weights from ``CHECKPOINT``.

        The tensors are mapped onto the trainer's device, so a checkpoint
        written on a GPU machine can also be restored on a CPU-only one.
        """
        state = torch.load(self.CHECKPOINT, map_location=self.__device)
        self.__model.load_state_dict(state)

    def train(self, max_epochs: int) -> None:
        """Train until ``max_epochs`` epochs have been recorded in total.

        Epochs already present in the history count towards ``max_epochs``,
        so calling this again with a larger value resumes training. The
        weights are checkpointed whenever the validation accuracy improves.
        """
        if len(self.__train_loss) >= max_epochs:
            return
        for e in range(len(self.__train_loss), max_epochs):
            start_time = time()
            train_loss, train_accuracy = self.__train_epoch()
            val_loss, val_accuracy = self.__val_epoch()
            end_time = time()
            if val_accuracy > self.__best_val_acc:
                self.__best_val_acc = val_accuracy
                print('Saving checkpoint.')
                self.save()
            print(f'Epoch {e+1} finished in {end_time - start_time:.2f}s. Train [loss: {train_loss:.3e}, acc: {train_accuracy:.2f}%] - Val [loss: {val_loss:.3e}, acc: {val_accuracy:.2f}%]')
            self.__train_loss.append(train_loss)
            self.__train_accuracy.append(train_accuracy)
            self.__val_loss.append(val_loss)
            self.__val_accuracy.append(val_accuracy)

    def plot_curves(self) -> None:
        """Plot the recorded loss and accuracy histories side by side."""
        fig = plt.figure(figsize=(15, 7))
        # loss
        ax = fig.add_subplot(121)
        ax.set_title('Loss / Epoch')
        ax.set_ylabel('Loss')
        ax.set_xlabel('Epoch')
        ax.set_aspect('auto')
        # Draw on the axes object directly instead of relying on pyplot's current axes.
        ax.plot(self.__train_loss, label='Train', color='green', linewidth=3)
        ax.plot(self.__val_loss, label='Validation', color='red', linewidth=3)
        ax.legend()
        # acc
        ax = fig.add_subplot(122)
        ax.set_title('Accuracy / Epoch')
        ax.set_ylabel('Accuracy')
        ax.set_xlabel('Epoch')
        ax.set_aspect('auto')
        ax.plot(self.__train_accuracy, label='Train', color='green', linewidth=3)
        ax.plot(self.__val_accuracy, label='Validation', color='red', linewidth=3)
        ax.legend()

    def predict(self) -> Iterable[int]:
        """Yield the predicted class id of every test sample, in loader order."""
        self.__model.eval()
        with torch.no_grad(), tqdm(self.__testloader, total=len(self.__testloader)) as pbar:
            for x, _ in pbar:
                scores = self.__model(x.to(self.__device))
                # tolist() moves the whole batch off the device at once and
                # yields plain Python ints, matching the return annotation.
                yield from scores.argmax(dim=-1).tolist()

    @staticmethod
    def __pick_misclassified(predictions: Iterable[int], labels: Iterable[int], n: int) -> list:
        """Randomly pick up to ``n`` distinct positions where prediction != label.

        Returns an empty list when nothing is misclassified or ``n`` is not
        positive; returns fewer than ``n`` positions when fewer exist.
        """
        missed = [i for i, (p, y) in enumerate(zip(predictions, labels)) if p != y]
        count = min(n, len(missed))
        if count <= 0:
            return []
        # replace=False: never show the same sample twice.
        return np.random.choice(missed, size=count, replace=False).tolist()

    def draw_misclassified(self, n: int) -> None:
        """Show up to ``n`` distinct misclassified test images.

        Each image is titled ``expected -> predicted``. Images are laid out
        10 per row; the last row may be partially filled.
        """
        dataset = self.__testloader.dataset
        predictions = list(self.predict())
        labels = [y for _, y in dataset]
        missed_idx = self.__pick_misclassified(predictions, labels, n)
        if not missed_idx:
            print('No misclassified samples to draw.')
            return

        missed_samples = [dataset[i] for i in missed_idx]
        wrong_labels = [predictions[i] for i in missed_idx]
        correct_labels = [s[1] for s in missed_samples]
        # Channel-first tensors -> channel-last layout expected by imshow.
        images = [s[0].permute((1, 2, 0)) for s in missed_samples]
        texts = [f'{self.CLASSES[correct]} -> {self.CLASSES[missed]}' for missed, correct in zip(wrong_labels, correct_labels)]

        columns = 10
        rows = int(np.ceil(len(missed_idx) / columns))
        fig = plt.figure(figsize=(2 * columns, 2 * rows))
        # Iterate over the samples actually selected, not over every grid cell,
        # so a partially filled last row does not index past the end.
        for i, (image, text) in enumerate(zip(images, texts)):
            ax = fig.add_subplot(rows, columns, i + 1)
            ax.set_title(text)
            ax.set_aspect('equal')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            # Rescale to [0, 1] for display; the stored tensors are normalised.
            ax.imshow((image - image.min()) / (image.max() - image.min()))
        # NOTE(review): frameless axes with hidden ticks spanning the figure;
        # kept from the original, appears to have no visible effect.
        cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
        cax.get_xaxis().set_visible(False)
        cax.get_yaxis().set_visible(False)
        cax.set_frame_on(False)

    def plot_cm(self) -> None:
        """Draw the test-set confusion matrix as an annotated heatmap."""
        y_true = [label for _, label in self.__testloader.dataset]
        y_pred = [int(label) for label in self.predict()]
        cm = confusion_matrix(y_true, y_pred)
        sns.heatmap(cm, annot=True, fmt='d')
# Mini-batch sizes: B for the training loader, EB for the evaluation-only loaders.
B = 512
EB = 1024

# Only the training loader is shuffled.
trainloader = DataLoader(dataset=trainset, batch_size=B, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=EB, shuffle=False)
testloader = DataLoader(dataset=testset, batch_size=EB, shuffle=False)

# Model on the selected device, Adam optimizer, and the trainer tying them together.
net = Net().to(device)
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-4)
trainer = Trainer(
    model=net,
    optimizer=optimizer,
    trainloader=trainloader,
    valloader=valloader,
    testloader=testloader,
    device=device,
)
net
trainer.train(max_epochs=100)
Draw two diagrams for train and validation, one showing loss of each epoch, and another showing accuracy of each epoch.
# Plot the per-epoch train/validation loss and accuracy recorded by the trainer.
trainer.plot_curves()
plt.show()
Evaluate the best epoch's model (according to the validation accuracy) on the test set, and report the accuracy. Is your model overfitted?
# Restore the checkpointed weights (saved whenever validation accuracy improved)
# before scoring the test set; only the accuracy is kept, the loss is discarded.
trainer.load()
_, accuracy = trainer.test()
accuracy
Draw 20 misclassified images from test set with expected and predicted labels.
# Show 20 test images the model got wrong, each titled "expected -> predicted".
trainer.draw_misclassified(20)
Plot the confusion matrix for the test set.
# Heatmap of true vs. predicted classes on the test set.
trainer.plot_cm()
plt.show()