Classification with Transfer Learning: A solution
CIFAR10 classification with transfer learning on Inception V3.
In this task, you should train Inception V3 with PyTorch, using transfer learning, to classify the CIFAR10 dataset.
from typing import Tuple, Iterable
from tqdm import tqdm
from time import time
import re
import torchvision
from torchvision import datasets, transforms, models
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
SEED = 51
torch.random.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
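For stricter run-to-run reproducibility on GPU, cuDNN can additionally be pinned to deterministic kernels (an optional tweak, not part of the original task; it usually costs some speed):
# Optional: trade speed for determinism in cuDNN-backed ops.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False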
Load the CIFAR10 dataset below. You can use torchvision.datasets.CIFAR10, sklearn.datasets.fetch_openml(), or any other way to load the dataset.
cifar_root = 'cifar10'
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((299, 299)),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
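Normalizing with (0.5, 0.5, 0.5) maps pixels into [-1, 1]. Since the backbone is pretrained on ImageNet, an alternative worth trying (a sketch only; it is not used in the rest of this solution) is to normalize with the ImageNet channel statistics the weights were trained with:
# Alternative: normalize with ImageNet channel statistics, matching pretraining.
imagenet_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((299, 299)),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])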
trainset = torchvision.datasets.CIFAR10(cifar_root, train=True, download=True, transform=transform)
trainsize = int(len(trainset) * 0.9)
valsize = len(trainset) - trainsize
trainset, valset = torch.utils.data.random_split(trainset, [trainsize, valsize])
testset = torchvision.datasets.CIFAR10(cifar_root, train=False, download=True, transform=transform)
Instantiate the Inception V3 model (pretrained on ImageNet) from torchvision's model zoo.
# Note: newer torchvision (>= 0.13) deprecates `pretrained=True` in favor of
# `weights=models.Inception_V3_Weights.IMAGENET1K_V1`.
net = models.inception_v3(pretrained=True).to(device)
net
ImageNet has 1000 classes, but CIFAR10 has 10, so the decider (classifier) layers of the model must be replaced before it can be used for CIFAR10. Therefore, adapt net.fc and net.AuxLogits.fc to CIFAR10.
net.fc = nn.Linear(2048, 10).to(device)
net.AuxLogits.fc = nn.Linear(768, 10).to(device)
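As a quick sanity check (a minimal sketch), a dummy forward pass in eval mode should now produce 10 logits; note that in train mode Inception V3 instead returns an (output, aux) pair.
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 299, 299, device=device))
logits.shape  # expected: torch.Size([1, 10])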
In order to apply transfer learning, freeze all layers except the deciders. Freezing consists of disabling optimization by turning off gradient calculation. In batch normalization layers, updating the moving average and variance must also be disabled. Note that you must later filter the frozen parameters out of the optimizer.
class TrainDisable:
    """No-op replacement for a module's train() method, so the module
    permanently stays in eval mode (e.g. BN stops updating running stats)."""
    def __init__(self, train) -> None:
        self.old_train = train  # keep the original bound method in case we want to unfreeze
    def __call__(self, *args, **kwargs) -> None:
        pass
def freeze(m: nn.Module, pattern: str) -> None:
    """Disable gradients for every parameter whose name matches `pattern`,
    and put the corresponding BatchNorm modules permanently into eval mode."""
    regex = re.compile(pattern)
    frozen_names = set()
    for name, p in m.named_parameters():
        if regex.fullmatch(name):
            # print(f'grad off: {name}')
            p.requires_grad = False
            frozen_names.add(name[:name.rfind('.')])  # parent module name (strip '.weight'/'.bias')
    for name, module in m.named_modules():
        if (name in frozen_names) and isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # print(f'bn freeze: {name}')
            module.eval()
            module.train = TrainDisable(module.train)  # net.train() can no longer switch it back
pattern = r'(?!^.*fc\..*$)(^.*$)'  # any parameter name that does not contain 'fc.'
freeze(net, pattern)
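A quick check, sketched below, confirms that only the two decider heads remain trainable:
# Expect only fc.* and AuxLogits.fc.* entries here.
[name for name, p in net.named_parameters() if p.requires_grad]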
Create the optimizer, filtering out the frozen parameters.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-4)
Write your train/validation loop below. Then train the model until it converges. Feel free to add extra cells.
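In training mode, Inception V3 returns auxiliary logits alongside the main output, so the loop below minimizes the combined objective

loss = CE(s, y) + α · CE(s_aux, y)

with α = 0.2 in this solution; weights around 0.3-0.4 for the auxiliary term are also common in the Inception literature.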
class Trainer:
CHECKPOINT = 'cifar-inceptionv3.pth'
CLASSES = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
trainloader: DataLoader, valloader: DataLoader, testloader: DataLoader,
device: torch.device, alpha: float = 0.2) -> None:
self.__model = model
self.__optimizer = optimizer
self.__trainloader = trainloader
self.__valloader = valloader
self.__testloader = testloader
self.__device = device
self.__train_loss = []
self.__train_accuracy = []
self.__val_loss = []
self.__val_accuracy = []
self.__best_val_acc = 0
self.__alpha = alpha
def __train_epoch(self) -> Tuple[float, float]:
epoch_loss = 0
epoch_accuracy = 0
n = 0
self.__model.train()
with tqdm(enumerate(self.__trainloader), total=len(self.__trainloader)) as pbar:
for i, (x, y) in pbar:
x = x.to(self.__device)
y = y.to(self.__device)
s, s_a = self.__model(x)
loss = F.cross_entropy(s, y) + self.__alpha * F.cross_entropy(s_a, y)
true = int((s.argmax(dim=-1) == y).sum())
epoch_loss += float(loss)
epoch_accuracy += true
n += len(x)
loss.backward()
self.__optimizer.step()
self.__optimizer.zero_grad()
pbar.set_description(f'train loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
epoch_loss /= len(self.__trainloader)
epoch_accuracy *= 100.0 / n
return epoch_loss, epoch_accuracy
def __eval_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
epoch_loss = 0
epoch_accuracy = 0
n = 0
self.__model.eval()
with torch.no_grad(), tqdm(enumerate(dataloader), total=len(dataloader)) as pbar:
for i, (x, y) in pbar:
x = x.to(self.__device)
y = y.to(self.__device)
s = self.__model(x)
loss = F.cross_entropy(s, y)
true = int((s.argmax(dim=-1) == y).sum())
epoch_loss += float(loss)
epoch_accuracy += true
n += len(x)
pbar.set_description(f'val loss: {epoch_loss / (i + 1):.3e} - accuracy: {epoch_accuracy * 100.0 / n:.2f}%')
epoch_loss /= len(dataloader)
epoch_accuracy *= 100.0 / n
return epoch_loss, epoch_accuracy
def __val_epoch(self) -> Tuple[float, float]:
return self.__eval_epoch(self.__valloader)
def test(self) -> Tuple[float, float]:
return self.__eval_epoch(self.__testloader)
def save(self) -> None:
torch.save(self.__model.state_dict(), self.CHECKPOINT)
def load(self) -> None:
self.__model.load_state_dict(torch.load(self.CHECKPOINT))
def train(self, max_epochs: int) -> None:
if len(self.__train_loss) >= max_epochs:
return
for e in range(len(self.__train_loss), max_epochs):
start_time = time()
train_loss, train_accuracy = self.__train_epoch()
val_loss, val_accuracy = self.__val_epoch()
end_time = time()
if val_accuracy > self.__best_val_acc:
self.__best_val_acc = val_accuracy
print(f'Saving checkpoint.')
self.save()
print(f'Epoch {e+1} finished in {end_time - start_time:.2f}s. Train [loss: {train_loss:.3e}, acc: {train_accuracy:.2f}%] - Val [loss: {val_loss:.3e}, acc: {val_accuracy:.2f}%]')
self.__train_loss.append(train_loss)
self.__train_accuracy.append(train_accuracy)
self.__val_loss.append(val_loss)
self.__val_accuracy.append(val_accuracy)
def plot_curves(self) -> None:
fig = plt.figure(figsize=(15, 7))
# loss
ax = fig.add_subplot(121)
ax.set_title('Loss / Epoch')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')
plt.plot(self.__train_loss, label='Train', color='green', linewidth=3)
plt.plot(self.__val_loss, label='Validation', color='red', linewidth=3)
plt.legend()
# acc
ax = fig.add_subplot(122)
ax.set_title('Accuracy / Epoch')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')
plt.plot(self.__train_accuracy, label='Train', color='green', linewidth=3)
plt.plot(self.__val_accuracy, label='Validation', color='red', linewidth=3)
plt.legend()
    def predict(self) -> Iterable[int]:
        self.__model.eval()
        with torch.no_grad(), tqdm(self.__testloader, total=len(self.__testloader)) as pbar:
            for x, _ in pbar:
                # Yield plain Python ints so downstream comparisons and indexing are device-agnostic.
                yield from self.__model(x.to(self.__device)).argmax(dim=-1).tolist()
def draw_misclassified(self, n: int) -> None:
p = list(self.predict())
missed_idx = np.random.choice([i for i, (x, y) in enumerate(self.__testloader.dataset) if p[i] != y], size=n)
missed_samples = [self.__testloader.dataset[i] for i in missed_idx]
wrong_labels = [p[i] for i in missed_idx]
correct_labels = [s[1] for s in missed_samples]
images = [s[0].permute((1, 2, 0)) for s in missed_samples]
texts = [f'{self.CLASSES[correct]} -> {self.CLASSES[missed]}' for missed, correct in zip(wrong_labels, correct_labels)]
columns = 10
rows = int(np.ceil(n / columns))
fig = plt.figure(figsize=(2 * columns, 2 * rows))
        for i in range(n):  # only n samples were collected; don't index past them
ax = fig.add_subplot(rows, columns, i + 1)
ax.set_title(texts[i])
ax.set_aspect('equal')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.imshow((images[i] - images[i].min()) / (images[i].max() - images[i].min()))
cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
cax.get_xaxis().set_visible(False)
cax.get_yaxis().set_visible(False)
cax.set_frame_on(False)
    def plot_cm(self) -> None:
        y_true = [label for _, label in self.__testloader.dataset]
        y_pred = [int(label) for label in self.predict()]
        cm = confusion_matrix(y_true, y_pred)
        # Label the axes with class names instead of bare indices.
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=self.CLASSES, yticklabels=self.CLASSES)
B = 128
EB = 256
trainloader = DataLoader(trainset, batch_size=B, shuffle=True)
valloader = DataLoader(valset, batch_size=EB, shuffle=False)
testloader = DataLoader(testset, batch_size=EB, shuffle=False)
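Resizing 32×32 images to 299×299 makes data loading a likely bottleneck. On a multi-core machine, worker processes and pinned memory usually help (an optional tweak, not part of the original solution):
# Optional: parallel loading and pinned host memory for faster host-to-device copies.
trainloader = DataLoader(trainset, batch_size=B, shuffle=True, num_workers=4, pin_memory=True)
valloader = DataLoader(valset, batch_size=EB, shuffle=False, num_workers=4, pin_memory=True)
testloader = DataLoader(testset, batch_size=EB, shuffle=False, num_workers=4, pin_memory=True)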
trainer = Trainer(net, optimizer, trainloader, valloader, testloader, device)
trainer.train(20)
Draw two diagrams for train and validation: one showing the loss of each epoch, and another showing the accuracy of each epoch.
trainer.plot_curves()
plt.show()
Evaluate the best epoch's model (according to the validation accuracy) on the test set, and report the accuracy. Is your model overfitting?
trainer.load()
_, accuracy = trainer.test()
accuracy
Draw 20 misclassified images from the test set with their expected and predicted labels.
trainer.draw_misclassified(20)
Plot the confusion matrix for the test set.
trainer.plot_cm()
plt.show()
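The confusion matrix also yields per-class accuracy directly. A minimal sketch (it recomputes the matrix, since plot_cm does not return it):
# Per-class accuracy: correct predictions (diagonal) over row totals.
y_true = [label for _, label in testloader.dataset]
y_pred = [int(label) for label in trainer.predict()]
cm = confusion_matrix(y_true, y_pred)
for cls, acc in zip(Trainer.CLASSES, cm.diagonal() / cm.sum(axis=1)):
    print(f'{cls}: {acc * 100:.2f}%')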