Conditional GAN

Generating the right half of an MNIST digit given the left half.

In this task, you should train a conditional GAN to generate the right half of MNIST images, given the left half.
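Concretely, each 28×28 digit is split into a 28×14 condition and a 28×14 target. A minimal sketch of the split used throughout this notebook (`image` here is a hypothetical single 28×28 array):

import numpy as np
image = np.zeros((28, 28), dtype='float32')   # hypothetical placeholder digit
left = image[:, :14]     # 28 x 14 condition fed to the generator
right = image[:, 14:]    # 28 x 14 target the generator must produce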

In [ ]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from time import time
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

Load the dataset

Load the MNIST dataset below. You can use torchvision.datasets.MNIST, sklearn.datasets.fetch_openml(), or any other method to load the dataset.
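For reference, a sketch of the torchvision route (not executed in this notebook; note it loads only the 60,000-image train split, whereas fetch_openml returns all 70,000 images):

from torchvision import datasets

mnist = datasets.MNIST(root='./data', train=True, download=True)
X = mnist.data.numpy().astype('float32')     # (60000, 28, 28)
y = mnist.targets.numpy().astype('int32')    # (60000,)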

In [ ]:
X, y = fetch_openml(name='mnist_784', return_X_y=True, as_frame=False)
In [ ]:
y = y.astype('int32')
In [ ]:
X = X.reshape((X.shape[0], 28, 28)).astype('float32')
X.shape
Out[ ]:
(70000, 28, 28)
In [ ]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1)
(X_train.shape, X_test.shape, X_val.shape), (y_train.shape, y_test.shape, y_val.shape)
Out[ ]:
(((50400, 28, 28), (14000, 28, 28), (5600, 28, 28)),
 ((50400,), (14000,), (5600,)))
In [ ]:
images_train = X_train / 255
images_val = X_val / 255
images_test = X_test / 255

Design your model

Implement your GAN below using torch.nn modules. Feel free to add extra cells. The generator should construct the right half of an image given its left half and a noise vector. The general structure of this GAN is implemented below.
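As a minimal sketch of the intended data flow (shapes only; G, D, left, and B stand in for the generator, discriminator, a batch of left halves, and the batch size defined in the cells below):

z = torch.randn(B, 100)                        # noise vector
fake_right = G(z, left)                        # B x 28 x 14, generated right half
full = torch.cat((left, fake_right), dim=2)    # B x 28 x 28, condition + sample
validity = D(full)                             # B, probability the image is real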

In [ ]:
class ImageConv(nn.Sequential):
    def __init__(self):
        super().__init__(                                   # B 1   28  28
            nn.BatchNorm2d(num_features=1),
            nn.Conv2d(1, 64, 3, stride=1, padding=3),       # B 64  32  32  
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),      # B 64  32  32
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, stride=2, padding=0),           # B 64  16  16
            nn.Dropout2d(0.2, inplace=True),
            nn.BatchNorm2d(num_features=64),
            nn.Conv2d(64, 128, 3, stride=1, padding=1),     # B 128 16  16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),    # B 128 16  16
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, stride=2, padding=0),           # B 128 8   8
            nn.Dropout2d(0.2, inplace=True),
            nn.BatchNorm2d(num_features=128),
            nn.Conv2d(128, 256, 3, stride=1, padding=1),    # B 256 8   8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),    # B 256 8   8
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, stride=2, padding=0),           # B 256 4   4
            nn.Dropout2d(0.2, inplace=True),
            nn.BatchNorm2d(num_features=256),
            nn.Conv2d(256, 512, 3, stride=1, padding=1),    # B 512 4   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, 3, stride=1, padding=1),    # B 512 4   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, stride=2, padding=0),           # B 512 2   2
            nn.Dropout2d(0.2, inplace=True)
        )
In [ ]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        # Architecturally identical to ImageConv, so reuse it directly; this
        # also lets the pretrained autoencoder weights be loaded into it later.
        self.conv = ImageConv()                             # B 1 28 28 -> B 512 2 2

        self.fc = nn.Sequential(
            nn.Linear(2048, 512),              # B 512
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(512, 256),               # B 256
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(256, 1),                 # B 1
            nn.Sigmoid()
        )
    
    def forward(self, full_image):
        # full_image        B   28  28
        
        full_image = full_image.unsqueeze(1)    # B 1   28  28
        out = self.conv(full_image)             # B 512 2   2
        out = out.flatten(1)                    # B 2048
        out = self.fc(out)                      # B 1
        out = out.squeeze(1)                    # B  (squeeze dim 1 only, so a batch of size 1 keeps its batch axis)
        return out
In [ ]:
class Generator(nn.Module):
    def __init__(self, conv: nn.Module):
        super().__init__()

        self.conv = conv

        self.noise_fc = nn.Sequential(                                          # B 100
            nn.Linear(100, 1024),                                               # B 1024
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.dc = nn.Sequential(                                                # B 1024    2   1
            nn.BatchNorm2d(1024),
            nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1, bias=False),  # B 512     4   2
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, 3, stride=1, padding=1),                        # B 512     4   2
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),   # B 256     8   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),                        # B 256     8   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),   # B 128     16  8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),                        # B 128     16  8
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),    # B 64      32  16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, 3, stride=1, padding=0),                          # B 64      30  14
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, stride=1, padding=(0, 1)),                     # B 64      28  14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),                          # B 64      28  14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),                           # B 1       28  14
            nn.ReLU(inplace=True),
        )

    def forward(self, z, left):
        # z     B   100
        # left  B   28  14

        left = left.unsqueeze(1)                                                            # B 1   28  14

        # The condition encoder is pretrained (via the autoencoder below) and
        # kept frozen here: no gradients flow into self.conv.
        with torch.no_grad():
            condition_features = self.conv(left)                                            # B 512 2   1
        condition_features = condition_features.flatten(1).unsqueeze(2)                     # B 1024    1

        noise_features = self.noise_fc(z).unsqueeze(2)                                      # B 1024    1

        dc_input = torch.cat((condition_features, noise_features), dim=2).unsqueeze(3)      # B 1024    2   1
        constructed = self.dc(dc_input)                                                     # B 1   28  14
        constructed = constructed.squeeze(1)                                                # B 28  14

        return constructed.clamp(0, 1)

Train your GAN

Write the training process below. Feel free to add extra cells.

Pretraining the convolutional encoder with an autoencoder

First, we train an autoencoder so that its trained encoder can be reused as the convolutional encoder of both the generator and the discriminator.
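The transfer itself is a plain state-dict copy. A sketch of the idea, assuming the checkpoint saved during autoencoder training below exists (the same calls appear with the real objects further down):

encoder = ImageConv()
encoder.load_state_dict(torch.load('drive/MyDrive/mnist-autoencoder-conv-model.pt'))
generator = Generator(encoder).cuda()                       # condition encoder, kept frozen
discriminator = Discriminator().cuda()
discriminator.conv.load_state_dict(encoder.state_dict())    # warm-start D's encoder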

In [ ]:
def get_batches(X, y=None, batch_size=128, shuffle=True):
    if y is not None:
        assert X.shape[0] == y.shape[0]

    num_batches = int(np.ceil(X.shape[0] * 1.0 / batch_size))

    if shuffle:
        indices = np.random.permutation(X.shape[0])
        X = X[indices]
        if y is not None:
            y = y[indices]

    for batch in range(num_batches):
        start = batch * batch_size
        end = min((batch + 1) * batch_size, X.shape[0])
        yield (batch, X[start:end], y[start:end]) if y is not None else (batch, X[start:end])
In [ ]:
def draw(images, texts, columns=1, rows=1):
    fig = plt.figure(figsize=(2 * columns, 2 * rows))

    for i in range(columns * rows):
        ax = fig.add_subplot(rows, columns, i + 1)
        ax.set_title(texts[i])
        ax.set_aspect('equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        plt.imshow(images[i].reshape((28, 28)))

    cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
    cax.get_xaxis().set_visible(False)
    cax.get_yaxis().set_visible(False)
    cax.set_frame_on(False)
    plt.show()
In [ ]:
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ImageConv()

        self.decoder = nn.Sequential(                                           # B 512 2   2
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),   # B 256 4   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),                        # B 256 4   4
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),   # B 128 8   8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),                        # B 128 8   8
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),    # B 64  16  16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),                          # B 64  16  16
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1, bias=False),     # B 32  32  32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 32, 3, stride=1, padding=0),                          # B 32  30  30
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, stride=1, padding=0),                            # B 1   28  28
            nn.ReLU(inplace=True)
        )
    
    def forward(self, b_x):
        # b_x   B   28  28

        b_x = b_x.unsqueeze(1)              # B 1   28  28
        encoded = self.encoder(b_x)         # B 512 2   2
        decoded = self.decoder(encoded)     # B 1   28  28
        decoded = decoded.squeeze(1)

        return decoded
In [ ]:
def train_autoencoder(model, optimizer, X, batch_size):
    epoch_loss = 0
    iter = 0

    model.train()
    for iter, b_X in get_batches(X, batch_size=batch_size):
        images = torch.tensor(b_X, device='cuda')
        prediction = model(images)

        loss = F.mse_loss(prediction, images)

        epoch_loss += float(loss)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if (iter + 1) % 50 == 0:
            print(f'[Train] Iteration {iter + 1:3d} - loss: {epoch_loss / (iter + 1):.2e}')
    
    epoch_loss /= (iter + 1)

    return epoch_loss
In [ ]:
def evaluate_autoencoder(model, X, batch_size, return_predictions=False):
    epoch_loss = 0
    iter = 0
    predictions = []

    with torch.no_grad():
        model.eval()
        for iter, b_X in get_batches(X, batch_size=batch_size, shuffle=False):
            images = torch.tensor(b_X, device='cuda')
            prediction = model(images)
            predictions.append(prediction.cpu().numpy())

            loss = F.mse_loss(prediction, images)

            epoch_loss += float(loss)

            if (iter + 1) % 50 == 0:
                print(f'[Valid] Iteration {iter + 1:3d} - loss: {epoch_loss / (iter + 1):.2e}')
        
    epoch_loss /= (iter + 1)
    if return_predictions:
        return np.concatenate(predictions, axis=0), epoch_loss

    return epoch_loss
In [ ]:
autoencoder = Autoencoder().cuda()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)
In [ ]:
batch_size = 256
epochs = 200

train_loss = []
val_loss = []

best_val_loss = float('inf')

for e in range(epochs):

    start_time = time()

    epoch_train_loss = train_autoencoder(autoencoder, optimizer, images_train, batch_size)
    epoch_val_loss = evaluate_autoencoder(autoencoder, images_val, batch_size)

    end_time = time()

    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(autoencoder.encoder.state_dict(), 'drive/MyDrive/mnist-autoencoder-conv-model.pt')
    
    print(f'Epoch {e+1:3} finished in {end_time - start_time:.2f}s - Train loss: {epoch_train_loss:.2e} - Val loss: {epoch_val_loss:.2e}')

    train_loss.append(epoch_train_loss)
    val_loss.append(epoch_val_loss)
[Train] Iteration  50 - loss: 7.92e-02
[Train] Iteration 100 - loss: 6.16e-02
[Train] Iteration 150 - loss: 5.12e-02
Epoch   1 finished in 31.71s - Train loss: 4.46e-02 - Val loss: 1.91e-02
[Train] Iteration  50 - loss: 1.93e-02
[Train] Iteration 100 - loss: 1.81e-02
[Train] Iteration 150 - loss: 1.71e-02
Epoch   2 finished in 31.97s - Train loss: 1.63e-02 - Val loss: 1.19e-02
[Train] Iteration  50 - loss: 1.29e-02
[Train] Iteration 100 - loss: 1.24e-02
[Train] Iteration 150 - loss: 1.21e-02
Epoch   3 finished in 32.34s - Train loss: 1.17e-02 - Val loss: 8.86e-03
[Train] Iteration  50 - loss: 1.05e-02
[Train] Iteration 100 - loss: 1.02e-02
[Train] Iteration 150 - loss: 9.94e-03
Epoch   4 finished in 32.81s - Train loss: 9.77e-03 - Val loss: 8.34e-03
[Train] Iteration  50 - loss: 8.84e-03
[Train] Iteration 100 - loss: 8.76e-03
[Train] Iteration 150 - loss: 8.64e-03
Epoch   5 finished in 32.99s - Train loss: 8.53e-03 - Val loss: 7.10e-03
[Train] Iteration  50 - loss: 8.03e-03
[Train] Iteration 100 - loss: 7.90e-03
[Train] Iteration 150 - loss: 7.78e-03
Epoch   6 finished in 32.77s - Train loss: 7.72e-03 - Val loss: 6.28e-03
[Train] Iteration  50 - loss: 7.21e-03
[Train] Iteration 100 - loss: 7.15e-03
[Train] Iteration 150 - loss: 7.09e-03
Epoch   7 finished in 32.77s - Train loss: 7.03e-03 - Val loss: 6.08e-03
[Train] Iteration  50 - loss: 6.80e-03
[Train] Iteration 100 - loss: 6.74e-03
[Train] Iteration 150 - loss: 6.63e-03
Epoch   8 finished in 32.91s - Train loss: 6.55e-03 - Val loss: 5.41e-03
[Train] Iteration  50 - loss: 6.21e-03
[Train] Iteration 100 - loss: 6.22e-03
[Train] Iteration 150 - loss: 6.20e-03
Epoch   9 finished in 32.82s - Train loss: 6.14e-03 - Val loss: 4.89e-03
[Train] Iteration  50 - loss: 5.91e-03
[Train] Iteration 100 - loss: 5.94e-03
[Train] Iteration 150 - loss: 5.88e-03
Epoch  10 finished in 32.74s - Train loss: 5.84e-03 - Val loss: 4.61e-03
[Train] Iteration  50 - loss: 5.60e-03
[Train] Iteration 100 - loss: 5.56e-03
[Train] Iteration 150 - loss: 5.53e-03
Epoch  11 finished in 32.76s - Train loss: 5.52e-03 - Val loss: 4.45e-03
[Train] Iteration  50 - loss: 5.26e-03
[Train] Iteration 100 - loss: 5.24e-03
[Train] Iteration 150 - loss: 5.23e-03
Epoch  12 finished in 32.82s - Train loss: 5.21e-03 - Val loss: 4.37e-03
[Train] Iteration  50 - loss: 5.12e-03
[Train] Iteration 100 - loss: 5.06e-03
[Train] Iteration 150 - loss: 5.04e-03
Epoch  13 finished in 32.77s - Train loss: 5.01e-03 - Val loss: 4.05e-03
[Train] Iteration  50 - loss: 4.94e-03
[Train] Iteration 100 - loss: 4.87e-03
[Train] Iteration 150 - loss: 4.85e-03
Epoch  14 finished in 32.80s - Train loss: 4.86e-03 - Val loss: 4.19e-03
[Train] Iteration  50 - loss: 4.68e-03
[Train] Iteration 100 - loss: 4.63e-03
[Train] Iteration 150 - loss: 4.62e-03
Epoch  15 finished in 32.79s - Train loss: 4.64e-03 - Val loss: 4.13e-03
[Train] Iteration  50 - loss: 4.55e-03
[Train] Iteration 100 - loss: 4.51e-03
[Train] Iteration 150 - loss: 4.54e-03
Epoch  16 finished in 32.88s - Train loss: 4.52e-03 - Val loss: 3.59e-03
[Train] Iteration  50 - loss: 4.38e-03
[Train] Iteration 100 - loss: 4.39e-03
[Train] Iteration 150 - loss: 4.38e-03
Epoch  17 finished in 32.81s - Train loss: 4.37e-03 - Val loss: 3.70e-03
[Train] Iteration  50 - loss: 4.30e-03
[Train] Iteration 100 - loss: 4.23e-03
[Train] Iteration 150 - loss: 4.23e-03
Epoch  18 finished in 32.84s - Train loss: 4.21e-03 - Val loss: 3.50e-03
[Train] Iteration  50 - loss: 4.12e-03
[Train] Iteration 100 - loss: 4.10e-03
[Train] Iteration 150 - loss: 4.09e-03
Epoch  19 finished in 32.81s - Train loss: 4.09e-03 - Val loss: 3.68e-03
[Train] Iteration  50 - loss: 3.99e-03
[Train] Iteration 100 - loss: 4.01e-03
[Train] Iteration 150 - loss: 4.01e-03
Epoch  20 finished in 32.78s - Train loss: 3.99e-03 - Val loss: 3.18e-03
[Train] Iteration  50 - loss: 3.96e-03
[Train] Iteration 100 - loss: 3.94e-03
[Train] Iteration 150 - loss: 3.91e-03
Epoch  21 finished in 32.76s - Train loss: 3.92e-03 - Val loss: 3.38e-03
[Train] Iteration  50 - loss: 3.85e-03
[Train] Iteration 100 - loss: 3.81e-03
[Train] Iteration 150 - loss: 3.81e-03
Epoch  22 finished in 32.78s - Train loss: 3.81e-03 - Val loss: 3.03e-03
[Train] Iteration  50 - loss: 3.78e-03
[Train] Iteration 100 - loss: 3.72e-03
[Train] Iteration 150 - loss: 3.72e-03
Epoch  23 finished in 32.77s - Train loss: 3.71e-03 - Val loss: 3.46e-03
[Train] Iteration  50 - loss: 3.64e-03
[Train] Iteration 100 - loss: 3.66e-03
[Train] Iteration 150 - loss: 3.64e-03
Epoch  24 finished in 32.77s - Train loss: 3.64e-03 - Val loss: 3.04e-03
[Train] Iteration  50 - loss: 3.51e-03
[Train] Iteration 100 - loss: 3.56e-03
[Train] Iteration 150 - loss: 3.54e-03
Epoch  25 finished in 32.75s - Train loss: 3.52e-03 - Val loss: 3.02e-03
[Train] Iteration  50 - loss: 3.46e-03
[Train] Iteration 100 - loss: 3.45e-03
[Train] Iteration 150 - loss: 3.45e-03
Epoch  26 finished in 32.92s - Train loss: 3.44e-03 - Val loss: 2.96e-03
[Train] Iteration  50 - loss: 3.36e-03
[Train] Iteration 100 - loss: 3.36e-03
[Train] Iteration 150 - loss: 3.39e-03
Epoch  27 finished in 32.94s - Train loss: 3.39e-03 - Val loss: 2.85e-03
[Train] Iteration  50 - loss: 3.28e-03
[Train] Iteration 100 - loss: 3.28e-03
[Train] Iteration 150 - loss: 3.29e-03
Epoch  28 finished in 32.84s - Train loss: 3.28e-03 - Val loss: 2.67e-03
[Train] Iteration  50 - loss: 3.26e-03
[Train] Iteration 100 - loss: 3.25e-03
[Train] Iteration 150 - loss: 3.24e-03
Epoch  29 finished in 32.81s - Train loss: 3.25e-03 - Val loss: 2.65e-03
[Train] Iteration  50 - loss: 3.13e-03
[Train] Iteration 100 - loss: 3.15e-03
[Train] Iteration 150 - loss: 3.16e-03
Epoch  30 finished in 32.78s - Train loss: 3.15e-03 - Val loss: 2.58e-03
[Train] Iteration  50 - loss: 3.09e-03
[Train] Iteration 100 - loss: 3.09e-03
[Train] Iteration 150 - loss: 3.11e-03
Epoch  31 finished in 32.82s - Train loss: 3.11e-03 - Val loss: 2.52e-03
[Train] Iteration  50 - loss: 3.06e-03
[Train] Iteration 100 - loss: 3.06e-03
[Train] Iteration 150 - loss: 3.06e-03
Epoch  32 finished in 32.77s - Train loss: 3.06e-03 - Val loss: 2.61e-03
[Train] Iteration  50 - loss: 3.01e-03
[Train] Iteration 100 - loss: 3.02e-03
[Train] Iteration 150 - loss: 3.02e-03
Epoch  33 finished in 32.81s - Train loss: 3.00e-03 - Val loss: 2.41e-03
[Train] Iteration  50 - loss: 3.00e-03
[Train] Iteration 100 - loss: 3.01e-03
[Train] Iteration 150 - loss: 2.99e-03
Epoch  34 finished in 32.77s - Train loss: 2.99e-03 - Val loss: 2.74e-03
[Train] Iteration  50 - loss: 2.90e-03
[Train] Iteration 100 - loss: 2.91e-03
[Train] Iteration 150 - loss: 2.90e-03
Epoch  35 finished in 32.74s - Train loss: 2.90e-03 - Val loss: 2.38e-03
[Train] Iteration  50 - loss: 2.91e-03
[Train] Iteration 100 - loss: 2.89e-03
[Train] Iteration 150 - loss: 2.89e-03
Epoch  36 finished in 32.74s - Train loss: 2.90e-03 - Val loss: 2.32e-03
[Train] Iteration  50 - loss: 2.84e-03
[Train] Iteration 100 - loss: 2.83e-03
[Train] Iteration 150 - loss: 2.83e-03
Epoch  37 finished in 32.76s - Train loss: 2.83e-03 - Val loss: 2.25e-03
[Train] Iteration  50 - loss: 2.76e-03
[Train] Iteration 100 - loss: 2.75e-03
[Train] Iteration 150 - loss: 2.76e-03
Epoch  38 finished in 32.89s - Train loss: 2.77e-03 - Val loss: 2.33e-03
[Train] Iteration  50 - loss: 2.75e-03
[Train] Iteration 100 - loss: 2.75e-03
[Train] Iteration 150 - loss: 2.74e-03
Epoch  39 finished in 32.84s - Train loss: 2.74e-03 - Val loss: 2.28e-03
[Train] Iteration  50 - loss: 2.73e-03
[Train] Iteration 100 - loss: 2.75e-03
[Train] Iteration 150 - loss: 2.76e-03
Epoch  40 finished in 32.78s - Train loss: 2.74e-03 - Val loss: 2.13e-03
[Train] Iteration  50 - loss: 2.60e-03
[Train] Iteration 100 - loss: 2.64e-03
[Train] Iteration 150 - loss: 2.64e-03
Epoch  41 finished in 32.82s - Train loss: 2.65e-03 - Val loss: 2.31e-03
[Train] Iteration  50 - loss: 2.57e-03
[Train] Iteration 100 - loss: 2.59e-03
[Train] Iteration 150 - loss: 2.60e-03
Epoch  42 finished in 32.84s - Train loss: 2.60e-03 - Val loss: 2.20e-03
[Train] Iteration  50 - loss: 2.60e-03
[Train] Iteration 100 - loss: 2.61e-03
[Train] Iteration 150 - loss: 2.61e-03
Epoch  43 finished in 32.81s - Train loss: 2.60e-03 - Val loss: 2.16e-03
[Train] Iteration  50 - loss: 2.52e-03
[Train] Iteration 100 - loss: 2.51e-03
[Train] Iteration 150 - loss: 2.53e-03
Epoch  44 finished in 32.77s - Train loss: 2.53e-03 - Val loss: 2.12e-03
[Train] Iteration  50 - loss: 2.52e-03
[Train] Iteration 100 - loss: 2.51e-03
[Train] Iteration 150 - loss: 2.51e-03
Epoch  45 finished in 32.77s - Train loss: 2.50e-03 - Val loss: 2.09e-03
[Train] Iteration  50 - loss: 2.47e-03
[Train] Iteration 100 - loss: 2.46e-03
[Train] Iteration 150 - loss: 2.46e-03
Epoch  46 finished in 32.75s - Train loss: 2.46e-03 - Val loss: 2.06e-03
[Train] Iteration  50 - loss: 2.52e-03
[Train] Iteration 100 - loss: 2.49e-03
[Train] Iteration 150 - loss: 2.48e-03
Epoch  47 finished in 32.81s - Train loss: 2.47e-03 - Val loss: 1.96e-03
[Train] Iteration  50 - loss: 2.43e-03
[Train] Iteration 100 - loss: 2.45e-03
[Train] Iteration 150 - loss: 2.45e-03
Epoch  48 finished in 32.75s - Train loss: 2.44e-03 - Val loss: 1.89e-03
[Train] Iteration  50 - loss: 2.37e-03
[Train] Iteration 100 - loss: 2.35e-03
[Train] Iteration 150 - loss: 2.37e-03
Epoch  49 finished in 32.88s - Train loss: 2.37e-03 - Val loss: 1.96e-03
[Train] Iteration  50 - loss: 2.34e-03
[Train] Iteration 100 - loss: 2.34e-03
[Train] Iteration 150 - loss: 2.36e-03
Epoch  50 finished in 32.82s - Train loss: 2.34e-03 - Val loss: 1.91e-03
[Train] Iteration  50 - loss: 2.36e-03
[Train] Iteration 100 - loss: 2.33e-03
[Train] Iteration 150 - loss: 2.33e-03
Epoch  51 finished in 32.89s - Train loss: 2.33e-03 - Val loss: 2.04e-03
[Train] Iteration  50 - loss: 2.32e-03
[Train] Iteration 100 - loss: 2.31e-03
[Train] Iteration 150 - loss: 2.31e-03
Epoch  52 finished in 32.91s - Train loss: 2.32e-03 - Val loss: 2.01e-03
[Train] Iteration  50 - loss: 2.27e-03
[Train] Iteration 100 - loss: 2.26e-03
[Train] Iteration 150 - loss: 2.27e-03
Epoch  53 finished in 32.90s - Train loss: 2.27e-03 - Val loss: 1.80e-03
[Train] Iteration  50 - loss: 2.28e-03
[Train] Iteration 100 - loss: 2.24e-03
[Train] Iteration 150 - loss: 2.25e-03
Epoch  54 finished in 32.91s - Train loss: 2.25e-03 - Val loss: 2.06e-03
[Train] Iteration  50 - loss: 2.23e-03
[Train] Iteration 100 - loss: 2.23e-03
[Train] Iteration 150 - loss: 2.25e-03
Epoch  55 finished in 32.89s - Train loss: 2.24e-03 - Val loss: 1.79e-03
[Train] Iteration  50 - loss: 2.20e-03
[Train] Iteration 100 - loss: 2.22e-03
[Train] Iteration 150 - loss: 2.21e-03
Epoch  56 finished in 32.90s - Train loss: 2.22e-03 - Val loss: 1.73e-03
[Train] Iteration  50 - loss: 2.21e-03
[Train] Iteration 100 - loss: 2.18e-03
[Train] Iteration 150 - loss: 2.18e-03
Epoch  57 finished in 32.81s - Train loss: 2.18e-03 - Val loss: 1.80e-03
[Train] Iteration  50 - loss: 2.15e-03
[Train] Iteration 100 - loss: 2.17e-03
[Train] Iteration 150 - loss: 2.17e-03
Epoch  58 finished in 32.84s - Train loss: 2.17e-03 - Val loss: 1.66e-03
[Train] Iteration  50 - loss: 2.12e-03
[Train] Iteration 100 - loss: 2.11e-03
[Train] Iteration 150 - loss: 2.12e-03
Epoch  59 finished in 32.83s - Train loss: 2.13e-03 - Val loss: 1.74e-03
[Train] Iteration  50 - loss: 2.08e-03
[Train] Iteration 100 - loss: 2.10e-03
[Train] Iteration 150 - loss: 2.12e-03
Epoch  60 finished in 32.86s - Train loss: 2.12e-03 - Val loss: 1.64e-03
[Train] Iteration  50 - loss: 2.08e-03
[Train] Iteration 100 - loss: 2.07e-03
[Train] Iteration 150 - loss: 2.08e-03
Epoch  61 finished in 32.90s - Train loss: 2.08e-03 - Val loss: 1.72e-03
[Train] Iteration  50 - loss: 2.05e-03
[Train] Iteration 100 - loss: 2.06e-03
[Train] Iteration 150 - loss: 2.05e-03
Epoch  62 finished in 32.97s - Train loss: 2.05e-03 - Val loss: 1.68e-03
[Train] Iteration  50 - loss: 2.03e-03
[Train] Iteration 100 - loss: 2.05e-03
[Train] Iteration 150 - loss: 2.04e-03
Epoch  63 finished in 32.83s - Train loss: 2.04e-03 - Val loss: 1.68e-03
[Train] Iteration  50 - loss: 1.99e-03
[Train] Iteration 100 - loss: 1.99e-03
[Train] Iteration 150 - loss: 2.01e-03
Epoch  64 finished in 32.93s - Train loss: 2.03e-03 - Val loss: 1.84e-03
[Train] Iteration  50 - loss: 2.05e-03
[Train] Iteration 100 - loss: 2.05e-03
[Train] Iteration 150 - loss: 2.05e-03
Epoch  65 finished in 32.91s - Train loss: 2.04e-03 - Val loss: 1.64e-03
[Train] Iteration  50 - loss: 2.00e-03
[Train] Iteration 100 - loss: 1.99e-03
[Train] Iteration 150 - loss: 2.00e-03
Epoch  66 finished in 32.87s - Train loss: 1.99e-03 - Val loss: 1.55e-03
[Train] Iteration  50 - loss: 1.99e-03
[Train] Iteration 100 - loss: 1.99e-03
[Train] Iteration 150 - loss: 1.99e-03
Epoch  67 finished in 32.86s - Train loss: 1.98e-03 - Val loss: 1.55e-03
[Train] Iteration  50 - loss: 1.92e-03
[Train] Iteration 100 - loss: 1.93e-03
[Train] Iteration 150 - loss: 1.94e-03
Epoch  68 finished in 32.84s - Train loss: 1.95e-03 - Val loss: 1.57e-03
[Train] Iteration  50 - loss: 1.97e-03
[Train] Iteration 100 - loss: 1.96e-03
[Train] Iteration 150 - loss: 1.95e-03
Epoch  69 finished in 32.88s - Train loss: 1.96e-03 - Val loss: 1.50e-03
[Train] Iteration  50 - loss: 1.88e-03
[Train] Iteration 100 - loss: 1.89e-03
[Train] Iteration 150 - loss: 1.90e-03
Epoch  70 finished in 32.80s - Train loss: 1.91e-03 - Val loss: 1.58e-03
[Train] Iteration  50 - loss: 1.89e-03
[Train] Iteration 100 - loss: 1.89e-03
[Train] Iteration 150 - loss: 1.88e-03
Epoch  71 finished in 32.88s - Train loss: 1.89e-03 - Val loss: 1.54e-03
[Train] Iteration  50 - loss: 1.91e-03
[Train] Iteration 100 - loss: 1.88e-03
[Train] Iteration 150 - loss: 1.88e-03
Epoch  72 finished in 32.86s - Train loss: 1.87e-03 - Val loss: 1.49e-03
[Train] Iteration  50 - loss: 1.88e-03
[Train] Iteration 100 - loss: 1.89e-03
[Train] Iteration 150 - loss: 1.90e-03
Epoch  73 finished in 32.86s - Train loss: 1.90e-03 - Val loss: 1.48e-03
[Train] Iteration  50 - loss: 1.83e-03
[Train] Iteration 100 - loss: 1.84e-03
[Train] Iteration 150 - loss: 1.86e-03
Epoch  74 finished in 32.79s - Train loss: 1.85e-03 - Val loss: 1.44e-03
[Train] Iteration  50 - loss: 1.83e-03
[Train] Iteration 100 - loss: 1.84e-03
[Train] Iteration 150 - loss: 1.83e-03
Epoch  75 finished in 32.88s - Train loss: 1.83e-03 - Val loss: 1.44e-03
[Train] Iteration  50 - loss: 1.85e-03
[Train] Iteration 100 - loss: 1.83e-03
[Train] Iteration 150 - loss: 1.83e-03
Epoch  76 finished in 32.99s - Train loss: 1.83e-03 - Val loss: 1.55e-03
[Train] Iteration  50 - loss: 1.88e-03
[Train] Iteration 100 - loss: 1.86e-03
[Train] Iteration 150 - loss: 1.84e-03
Epoch  77 finished in 32.85s - Train loss: 1.84e-03 - Val loss: 1.44e-03
[Train] Iteration  50 - loss: 1.81e-03
[Train] Iteration 100 - loss: 1.81e-03
[Train] Iteration 150 - loss: 1.80e-03
Epoch  78 finished in 32.80s - Train loss: 1.80e-03 - Val loss: 1.42e-03
[Train] Iteration  50 - loss: 1.83e-03
[Train] Iteration 100 - loss: 1.81e-03
[Train] Iteration 150 - loss: 1.81e-03
Epoch  79 finished in 32.85s - Train loss: 1.81e-03 - Val loss: 1.44e-03
[Train] Iteration  50 - loss: 1.78e-03
[Train] Iteration 100 - loss: 1.78e-03
[Train] Iteration 150 - loss: 1.78e-03
Epoch  80 finished in 32.87s - Train loss: 1.78e-03 - Val loss: 1.48e-03
[Train] Iteration  50 - loss: 1.79e-03
[Train] Iteration 100 - loss: 1.78e-03
[Train] Iteration 150 - loss: 1.78e-03
Epoch  81 finished in 32.85s - Train loss: 1.77e-03 - Val loss: 1.40e-03
[Train] Iteration  50 - loss: 1.73e-03
[Train] Iteration 100 - loss: 1.75e-03
[Train] Iteration 150 - loss: 1.75e-03
Epoch  82 finished in 32.89s - Train loss: 1.74e-03 - Val loss: 1.38e-03
[Train] Iteration  50 - loss: 1.74e-03
[Train] Iteration 100 - loss: 1.73e-03
[Train] Iteration 150 - loss: 1.74e-03
Epoch  83 finished in 32.81s - Train loss: 1.74e-03 - Val loss: 1.38e-03
[Train] Iteration  50 - loss: 1.71e-03
[Train] Iteration 100 - loss: 1.74e-03
[Train] Iteration 150 - loss: 1.74e-03
Epoch  84 finished in 32.83s - Train loss: 1.72e-03 - Val loss: 1.40e-03
[Train] Iteration  50 - loss: 1.71e-03
[Train] Iteration 100 - loss: 1.71e-03
[Train] Iteration 150 - loss: 1.71e-03
Epoch  85 finished in 32.82s - Train loss: 1.70e-03 - Val loss: 1.45e-03
[Train] Iteration  50 - loss: 1.71e-03
[Train] Iteration 100 - loss: 1.73e-03
[Train] Iteration 150 - loss: 1.72e-03
Epoch  86 finished in 32.80s - Train loss: 1.72e-03 - Val loss: 1.35e-03
[Train] Iteration  50 - loss: 1.69e-03
[Train] Iteration 100 - loss: 1.68e-03
[Train] Iteration 150 - loss: 1.67e-03
Epoch  87 finished in 32.91s - Train loss: 1.68e-03 - Val loss: 1.38e-03
[Train] Iteration  50 - loss: 1.67e-03
[Train] Iteration 100 - loss: 1.70e-03
[Train] Iteration 150 - loss: 1.69e-03
Epoch  88 finished in 32.83s - Train loss: 1.69e-03 - Val loss: 1.32e-03
[Train] Iteration  50 - loss: 1.66e-03
[Train] Iteration 100 - loss: 1.66e-03
[Train] Iteration 150 - loss: 1.66e-03
Epoch  89 finished in 32.81s - Train loss: 1.67e-03 - Val loss: 1.33e-03
[Train] Iteration  50 - loss: 1.64e-03
[Train] Iteration 100 - loss: 1.66e-03
[Train] Iteration 150 - loss: 1.67e-03
Epoch  90 finished in 32.79s - Train loss: 1.66e-03 - Val loss: 1.36e-03
[Train] Iteration  50 - loss: 1.66e-03
[Train] Iteration 100 - loss: 1.65e-03
[Train] Iteration 150 - loss: 1.64e-03
Epoch  91 finished in 32.87s - Train loss: 1.64e-03 - Val loss: 1.43e-03
[Train] Iteration  50 - loss: 1.63e-03
[Train] Iteration 100 - loss: 1.63e-03
[Train] Iteration 150 - loss: 1.63e-03
Epoch  92 finished in 32.90s - Train loss: 1.63e-03 - Val loss: 1.33e-03
[Train] Iteration  50 - loss: 1.64e-03
[Train] Iteration 100 - loss: 1.63e-03
[Train] Iteration 150 - loss: 1.63e-03
Epoch  93 finished in 32.88s - Train loss: 1.63e-03 - Val loss: 1.33e-03
[Train] Iteration  50 - loss: 1.64e-03
[Train] Iteration 100 - loss: 1.62e-03
[Train] Iteration 150 - loss: 1.62e-03
Epoch  94 finished in 32.89s - Train loss: 1.62e-03 - Val loss: 1.26e-03
[Train] Iteration  50 - loss: 1.58e-03
[Train] Iteration 100 - loss: 1.60e-03
[Train] Iteration 150 - loss: 1.60e-03
Epoch  95 finished in 32.86s - Train loss: 1.60e-03 - Val loss: 1.31e-03
[Train] Iteration  50 - loss: 1.61e-03
[Train] Iteration 100 - loss: 1.61e-03
[Train] Iteration 150 - loss: 1.61e-03
Epoch  96 finished in 32.88s - Train loss: 1.60e-03 - Val loss: 1.29e-03
[Train] Iteration  50 - loss: 1.59e-03
[Train] Iteration 100 - loss: 1.59e-03
[Train] Iteration 150 - loss: 1.59e-03
Epoch  97 finished in 32.88s - Train loss: 1.60e-03 - Val loss: 1.21e-03
[Train] Iteration  50 - loss: 1.56e-03
[Train] Iteration 100 - loss: 1.55e-03
[Train] Iteration 150 - loss: 1.56e-03
Epoch  98 finished in 32.77s - Train loss: 1.57e-03 - Val loss: 1.25e-03
[Train] Iteration  50 - loss: 1.56e-03
[Train] Iteration 100 - loss: 1.56e-03
[Train] Iteration 150 - loss: 1.57e-03
Epoch  99 finished in 32.85s - Train loss: 1.57e-03 - Val loss: 1.24e-03
[Train] Iteration  50 - loss: 1.55e-03
[Train] Iteration 100 - loss: 1.56e-03
[Train] Iteration 150 - loss: 1.55e-03
Epoch 100 finished in 32.95s - Train loss: 1.56e-03 - Val loss: 1.39e-03
[Train] Iteration  50 - loss: 1.59e-03
[Train] Iteration 100 - loss: 1.58e-03
[Train] Iteration 150 - loss: 1.58e-03
Epoch 101 finished in 32.86s - Train loss: 1.58e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.55e-03
[Train] Iteration 100 - loss: 1.56e-03
[Train] Iteration 150 - loss: 1.56e-03
Epoch 102 finished in 32.90s - Train loss: 1.56e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.51e-03
[Train] Iteration 100 - loss: 1.52e-03
[Train] Iteration 150 - loss: 1.53e-03
Epoch 103 finished in 32.93s - Train loss: 1.52e-03 - Val loss: 1.21e-03
[Train] Iteration  50 - loss: 1.52e-03
[Train] Iteration 100 - loss: 1.52e-03
[Train] Iteration 150 - loss: 1.52e-03
Epoch 104 finished in 32.85s - Train loss: 1.51e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.53e-03
[Train] Iteration 100 - loss: 1.53e-03
[Train] Iteration 150 - loss: 1.53e-03
Epoch 105 finished in 32.87s - Train loss: 1.52e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.48e-03
[Train] Iteration 100 - loss: 1.50e-03
[Train] Iteration 150 - loss: 1.50e-03
Epoch 106 finished in 32.78s - Train loss: 1.50e-03 - Val loss: 1.22e-03
[Train] Iteration  50 - loss: 1.51e-03
[Train] Iteration 100 - loss: 1.53e-03
[Train] Iteration 150 - loss: 1.52e-03
Epoch 107 finished in 32.85s - Train loss: 1.52e-03 - Val loss: 1.21e-03
[Train] Iteration  50 - loss: 1.48e-03
[Train] Iteration 100 - loss: 1.49e-03
[Train] Iteration 150 - loss: 1.50e-03
Epoch 108 finished in 32.86s - Train loss: 1.50e-03 - Val loss: 1.17e-03
[Train] Iteration  50 - loss: 1.50e-03
[Train] Iteration 100 - loss: 1.49e-03
[Train] Iteration 150 - loss: 1.50e-03
Epoch 109 finished in 32.79s - Train loss: 1.50e-03 - Val loss: 1.25e-03
[Train] Iteration  50 - loss: 1.49e-03
[Train] Iteration 100 - loss: 1.48e-03
[Train] Iteration 150 - loss: 1.50e-03
Epoch 110 finished in 32.86s - Train loss: 1.50e-03 - Val loss: 1.16e-03
[Train] Iteration  50 - loss: 1.47e-03
[Train] Iteration 100 - loss: 1.47e-03
[Train] Iteration 150 - loss: 1.47e-03
Epoch 111 finished in 32.85s - Train loss: 1.47e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.49e-03
[Train] Iteration 100 - loss: 1.47e-03
[Train] Iteration 150 - loss: 1.48e-03
Epoch 112 finished in 32.85s - Train loss: 1.49e-03 - Val loss: 1.21e-03
[Train] Iteration  50 - loss: 1.46e-03
[Train] Iteration 100 - loss: 1.46e-03
[Train] Iteration 150 - loss: 1.45e-03
Epoch 113 finished in 32.78s - Train loss: 1.46e-03 - Val loss: 1.23e-03
[Train] Iteration  50 - loss: 1.46e-03
[Train] Iteration 100 - loss: 1.47e-03
[Train] Iteration 150 - loss: 1.46e-03
Epoch 114 finished in 32.89s - Train loss: 1.46e-03 - Val loss: 1.24e-03
[Train] Iteration  50 - loss: 1.44e-03
[Train] Iteration 100 - loss: 1.44e-03
[Train] Iteration 150 - loss: 1.44e-03
Epoch 115 finished in 32.86s - Train loss: 1.45e-03 - Val loss: 1.24e-03
[Train] Iteration  50 - loss: 1.43e-03
[Train] Iteration 100 - loss: 1.44e-03
[Train] Iteration 150 - loss: 1.44e-03
Epoch 116 finished in 32.76s - Train loss: 1.44e-03 - Val loss: 1.13e-03
[Train] Iteration  50 - loss: 1.43e-03
[Train] Iteration 100 - loss: 1.45e-03
[Train] Iteration 150 - loss: 1.44e-03
Epoch 117 finished in 32.84s - Train loss: 1.44e-03 - Val loss: 1.17e-03
[Train] Iteration  50 - loss: 1.44e-03
[Train] Iteration 100 - loss: 1.44e-03
[Train] Iteration 150 - loss: 1.44e-03
Epoch 118 finished in 32.85s - Train loss: 1.44e-03 - Val loss: 1.14e-03
[Train] Iteration  50 - loss: 1.42e-03
[Train] Iteration 100 - loss: 1.42e-03
[Train] Iteration 150 - loss: 1.42e-03
Epoch 119 finished in 32.83s - Train loss: 1.43e-03 - Val loss: 1.13e-03
[Train] Iteration  50 - loss: 1.40e-03
[Train] Iteration 100 - loss: 1.41e-03
[Train] Iteration 150 - loss: 1.43e-03
Epoch 120 finished in 32.89s - Train loss: 1.43e-03 - Val loss: 1.29e-03
[Train] Iteration  50 - loss: 1.42e-03
[Train] Iteration 100 - loss: 1.41e-03
[Train] Iteration 150 - loss: 1.41e-03
Epoch 121 finished in 32.81s - Train loss: 1.42e-03 - Val loss: 1.22e-03
[Train] Iteration  50 - loss: 1.41e-03
[Train] Iteration 100 - loss: 1.41e-03
[Train] Iteration 150 - loss: 1.42e-03
Epoch 122 finished in 32.82s - Train loss: 1.41e-03 - Val loss: 1.11e-03
[Train] Iteration  50 - loss: 1.42e-03
[Train] Iteration 100 - loss: 1.40e-03
[Train] Iteration 150 - loss: 1.39e-03
Epoch 123 finished in 32.80s - Train loss: 1.39e-03 - Val loss: 1.14e-03
[Train] Iteration  50 - loss: 1.38e-03
[Train] Iteration 100 - loss: 1.40e-03
[Train] Iteration 150 - loss: 1.40e-03
Epoch 124 finished in 32.77s - Train loss: 1.41e-03 - Val loss: 1.13e-03
[Train] Iteration  50 - loss: 1.38e-03
[Train] Iteration 100 - loss: 1.38e-03
[Train] Iteration 150 - loss: 1.40e-03
Epoch 125 finished in 32.78s - Train loss: 1.40e-03 - Val loss: 1.11e-03
[Train] Iteration  50 - loss: 1.39e-03
[Train] Iteration 100 - loss: 1.40e-03
[Train] Iteration 150 - loss: 1.41e-03
Epoch 126 finished in 32.83s - Train loss: 1.40e-03 - Val loss: 1.09e-03
[Train] Iteration  50 - loss: 1.38e-03
[Train] Iteration 100 - loss: 1.37e-03
[Train] Iteration 150 - loss: 1.38e-03
Epoch 127 finished in 32.78s - Train loss: 1.38e-03 - Val loss: 1.11e-03
[Train] Iteration  50 - loss: 1.38e-03
[Train] Iteration 100 - loss: 1.38e-03
[Train] Iteration 150 - loss: 1.38e-03
Epoch 128 finished in 32.84s - Train loss: 1.39e-03 - Val loss: 1.12e-03
[Train] Iteration  50 - loss: 1.37e-03
[Train] Iteration 100 - loss: 1.37e-03
[Train] Iteration 150 - loss: 1.37e-03
Epoch 129 finished in 32.90s - Train loss: 1.37e-03 - Val loss: 1.16e-03
[Train] Iteration  50 - loss: 1.39e-03
[Train] Iteration 100 - loss: 1.38e-03
[Train] Iteration 150 - loss: 1.37e-03
Epoch 130 finished in 32.81s - Train loss: 1.37e-03 - Val loss: 1.09e-03
[Train] Iteration  50 - loss: 1.36e-03
[Train] Iteration 100 - loss: 1.36e-03
[Train] Iteration 150 - loss: 1.37e-03
Epoch 131 finished in 32.84s - Train loss: 1.36e-03 - Val loss: 1.09e-03
[Train] Iteration  50 - loss: 1.34e-03
[Train] Iteration 100 - loss: 1.35e-03
[Train] Iteration 150 - loss: 1.35e-03
Epoch 132 finished in 32.78s - Train loss: 1.35e-03 - Val loss: 1.06e-03
[Train] Iteration  50 - loss: 1.33e-03
[Train] Iteration 100 - loss: 1.33e-03
[Train] Iteration 150 - loss: 1.34e-03
Epoch 133 finished in 32.81s - Train loss: 1.35e-03 - Val loss: 1.08e-03
[Train] Iteration  50 - loss: 1.35e-03
[Train] Iteration 100 - loss: 1.35e-03
[Train] Iteration 150 - loss: 1.35e-03
Epoch 134 finished in 32.85s - Train loss: 1.35e-03 - Val loss: 1.05e-03
[Train] Iteration  50 - loss: 1.34e-03
[Train] Iteration 100 - loss: 1.33e-03
[Train] Iteration 150 - loss: 1.33e-03
Epoch 135 finished in 32.84s - Train loss: 1.33e-03 - Val loss: 1.02e-03
[Train] Iteration  50 - loss: 1.35e-03
[Train] Iteration 100 - loss: 1.34e-03
[Train] Iteration 150 - loss: 1.34e-03
Epoch 136 finished in 32.81s - Train loss: 1.34e-03 - Val loss: 1.10e-03
[Train] Iteration  50 - loss: 1.32e-03
[Train] Iteration 100 - loss: 1.35e-03
[Train] Iteration 150 - loss: 1.35e-03
Epoch 137 finished in 32.86s - Train loss: 1.35e-03 - Val loss: 1.08e-03
[Train] Iteration  50 - loss: 1.35e-03
[Train] Iteration 100 - loss: 1.35e-03
[Train] Iteration 150 - loss: 1.33e-03
Epoch 138 finished in 32.82s - Train loss: 1.33e-03 - Val loss: 1.04e-03
[Train] Iteration  50 - loss: 1.34e-03
[Train] Iteration 100 - loss: 1.34e-03
[Train] Iteration 150 - loss: 1.33e-03
Epoch 139 finished in 32.86s - Train loss: 1.33e-03 - Val loss: 1.02e-03
[Train] Iteration  50 - loss: 1.31e-03
[Train] Iteration 100 - loss: 1.31e-03
[Train] Iteration 150 - loss: 1.32e-03
Epoch 140 finished in 32.87s - Train loss: 1.32e-03 - Val loss: 1.05e-03
[Train] Iteration  50 - loss: 1.30e-03
[Train] Iteration 100 - loss: 1.31e-03
[Train] Iteration 150 - loss: 1.31e-03
Epoch 141 finished in 32.87s - Train loss: 1.32e-03 - Val loss: 1.09e-03
[Train] Iteration  50 - loss: 1.34e-03
[Train] Iteration 100 - loss: 1.32e-03
[Train] Iteration 150 - loss: 1.32e-03
Epoch 142 finished in 32.86s - Train loss: 1.31e-03 - Val loss: 1.07e-03
[Train] Iteration  50 - loss: 1.30e-03
[Train] Iteration 100 - loss: 1.30e-03
[Train] Iteration 150 - loss: 1.30e-03
Epoch 143 finished in 32.90s - Train loss: 1.30e-03 - Val loss: 1.08e-03
[Train] Iteration  50 - loss: 1.32e-03
[Train] Iteration 100 - loss: 1.32e-03
[Train] Iteration 150 - loss: 1.32e-03
Epoch 144 finished in 32.90s - Train loss: 1.31e-03 - Val loss: 1.13e-03
[Train] Iteration  50 - loss: 1.29e-03
[Train] Iteration 100 - loss: 1.29e-03
[Train] Iteration 150 - loss: 1.31e-03
Epoch 145 finished in 32.82s - Train loss: 1.31e-03 - Val loss: 1.04e-03
[Train] Iteration  50 - loss: 1.27e-03
[Train] Iteration 100 - loss: 1.26e-03
[Train] Iteration 150 - loss: 1.28e-03
Epoch 146 finished in 32.79s - Train loss: 1.28e-03 - Val loss: 1.01e-03
[Train] Iteration  50 - loss: 1.29e-03
[Train] Iteration 100 - loss: 1.29e-03
[Train] Iteration 150 - loss: 1.29e-03
Epoch 147 finished in 32.87s - Train loss: 1.29e-03 - Val loss: 1.01e-03
[Train] Iteration  50 - loss: 1.27e-03
[Train] Iteration 100 - loss: 1.28e-03
[Train] Iteration 150 - loss: 1.29e-03
Epoch 148 finished in 32.94s - Train loss: 1.29e-03 - Val loss: 1.02e-03
[Train] Iteration  50 - loss: 1.28e-03
[Train] Iteration 100 - loss: 1.29e-03
[Train] Iteration 150 - loss: 1.28e-03
Epoch 149 finished in 32.94s - Train loss: 1.28e-03 - Val loss: 9.93e-04
[Train] Iteration  50 - loss: 1.26e-03
[Train] Iteration 100 - loss: 1.27e-03
[Train] Iteration 150 - loss: 1.28e-03
Epoch 150 finished in 32.93s - Train loss: 1.28e-03 - Val loss: 1.04e-03
[Train] Iteration  50 - loss: 1.33e-03
[Train] Iteration 100 - loss: 1.30e-03
[Train] Iteration 150 - loss: 1.29e-03
Epoch 151 finished in 32.88s - Train loss: 1.28e-03 - Val loss: 1.01e-03
[Train] Iteration  50 - loss: 1.27e-03
[Train] Iteration 100 - loss: 1.27e-03
[Train] Iteration 150 - loss: 1.26e-03
Epoch 152 finished in 32.83s - Train loss: 1.26e-03 - Val loss: 1.08e-03
[Train] Iteration  50 - loss: 1.25e-03
[Train] Iteration 100 - loss: 1.25e-03
[Train] Iteration 150 - loss: 1.26e-03
Epoch 153 finished in 32.84s - Train loss: 1.26e-03 - Val loss: 1.05e-03
[Train] Iteration  50 - loss: 1.25e-03
[Train] Iteration 100 - loss: 1.25e-03
[Train] Iteration 150 - loss: 1.26e-03
Epoch 154 finished in 32.88s - Train loss: 1.25e-03 - Val loss: 1.02e-03
[Train] Iteration  50 - loss: 1.26e-03
[Train] Iteration 100 - loss: 1.25e-03
[Train] Iteration 150 - loss: 1.26e-03
Epoch 155 finished in 32.89s - Train loss: 1.26e-03 - Val loss: 9.93e-04
[Train] Iteration  50 - loss: 1.23e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.25e-03
Epoch 156 finished in 32.90s - Train loss: 1.25e-03 - Val loss: 1.05e-03
[Train] Iteration  50 - loss: 1.24e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.24e-03
Epoch 157 finished in 32.90s - Train loss: 1.24e-03 - Val loss: 1.03e-03
[Train] Iteration  50 - loss: 1.25e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.25e-03
Epoch 158 finished in 32.88s - Train loss: 1.25e-03 - Val loss: 1.03e-03
[Train] Iteration  50 - loss: 1.24e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.24e-03
Epoch 159 finished in 32.92s - Train loss: 1.23e-03 - Val loss: 9.66e-04
[Train] Iteration  50 - loss: 1.24e-03
[Train] Iteration 100 - loss: 1.23e-03
[Train] Iteration 150 - loss: 1.24e-03
Epoch 160 finished in 32.95s - Train loss: 1.24e-03 - Val loss: 1.07e-03
[Train] Iteration  50 - loss: 1.24e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.24e-03
Epoch 161 finished in 32.88s - Train loss: 1.25e-03 - Val loss: 1.01e-03
[Train] Iteration  50 - loss: 1.21e-03
[Train] Iteration 100 - loss: 1.21e-03
[Train] Iteration 150 - loss: 1.21e-03
Epoch 162 finished in 32.94s - Train loss: 1.22e-03 - Val loss: 9.87e-04
[Train] Iteration  50 - loss: 1.23e-03
[Train] Iteration 100 - loss: 1.24e-03
[Train] Iteration 150 - loss: 1.24e-03
Epoch 163 finished in 32.97s - Train loss: 1.23e-03 - Val loss: 1.09e-03
[Train] Iteration  50 - loss: 1.21e-03
[Train] Iteration 100 - loss: 1.22e-03
[Train] Iteration 150 - loss: 1.22e-03
Epoch 164 finished in 32.94s - Train loss: 1.23e-03 - Val loss: 9.63e-04
[Train] Iteration  50 - loss: 1.22e-03
[Train] Iteration 100 - loss: 1.22e-03
[Train] Iteration 150 - loss: 1.22e-03
Epoch 165 finished in 32.81s - Train loss: 1.21e-03 - Val loss: 9.74e-04
[Train] Iteration  50 - loss: 1.21e-03
[Train] Iteration 100 - loss: 1.22e-03
[Train] Iteration 150 - loss: 1.22e-03
Epoch 166 finished in 32.88s - Train loss: 1.22e-03 - Val loss: 9.78e-04
[Train] Iteration  50 - loss: 1.20e-03
[Train] Iteration 100 - loss: 1.21e-03
[Train] Iteration 150 - loss: 1.21e-03
Epoch 167 finished in 32.90s - Train loss: 1.22e-03 - Val loss: 1.01e-03
[Train] Iteration  50 - loss: 1.22e-03
[Train] Iteration 100 - loss: 1.21e-03
[Train] Iteration 150 - loss: 1.20e-03
Epoch 168 finished in 32.90s - Train loss: 1.21e-03 - Val loss: 9.48e-04
[Train] Iteration  50 - loss: 1.21e-03
[Train] Iteration 100 - loss: 1.21e-03
[Train] Iteration 150 - loss: 1.21e-03
Epoch 169 finished in 32.80s - Train loss: 1.20e-03 - Val loss: 9.25e-04
[Train] Iteration  50 - loss: 1.19e-03
[Train] Iteration 100 - loss: 1.19e-03
[Train] Iteration 150 - loss: 1.20e-03
Epoch 170 finished in 32.97s - Train loss: 1.20e-03 - Val loss: 9.62e-04
[Train] Iteration  50 - loss: 1.21e-03
[Train] Iteration 100 - loss: 1.20e-03
[Train] Iteration 150 - loss: 1.20e-03
Epoch 171 finished in 32.91s - Train loss: 1.20e-03 - Val loss: 9.52e-04
[Train] Iteration  50 - loss: 1.20e-03
[Train] Iteration 100 - loss: 1.20e-03
[Train] Iteration 150 - loss: 1.20e-03
Epoch 172 finished in 32.97s - Train loss: 1.19e-03 - Val loss: 9.29e-04
[Train] Iteration  50 - loss: 1.20e-03
[Train] Iteration 100 - loss: 1.20e-03
[Train] Iteration 150 - loss: 1.19e-03
Epoch 173 finished in 32.84s - Train loss: 1.20e-03 - Val loss: 9.61e-04
[Train] Iteration  50 - loss: 1.20e-03
[Train] Iteration 100 - loss: 1.20e-03
[Train] Iteration 150 - loss: 1.19e-03
Epoch 174 finished in 32.88s - Train loss: 1.19e-03 - Val loss: 9.34e-04
[Train] Iteration  50 - loss: 1.17e-03
[Train] Iteration 100 - loss: 1.19e-03
[Train] Iteration 150 - loss: 1.19e-03
Epoch 175 finished in 32.89s - Train loss: 1.19e-03 - Val loss: 9.39e-04
[Train] Iteration  50 - loss: 1.18e-03
[Train] Iteration 100 - loss: 1.20e-03
[Train] Iteration 150 - loss: 1.19e-03
Epoch 176 finished in 32.89s - Train loss: 1.20e-03 - Val loss: 9.60e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.17e-03
[Train] Iteration 150 - loss: 1.17e-03
Epoch 177 finished in 32.98s - Train loss: 1.17e-03 - Val loss: 9.90e-04
[Train] Iteration  50 - loss: 1.19e-03
[Train] Iteration 100 - loss: 1.19e-03
[Train] Iteration 150 - loss: 1.19e-03
Epoch 178 finished in 32.85s - Train loss: 1.18e-03 - Val loss: 9.14e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.16e-03
[Train] Iteration 150 - loss: 1.17e-03
Epoch 179 finished in 32.94s - Train loss: 1.17e-03 - Val loss: 9.18e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.17e-03
[Train] Iteration 150 - loss: 1.18e-03
Epoch 180 finished in 32.83s - Train loss: 1.18e-03 - Val loss: 9.62e-04
[Train] Iteration  50 - loss: 1.17e-03
[Train] Iteration 100 - loss: 1.16e-03
[Train] Iteration 150 - loss: 1.16e-03
Epoch 181 finished in 32.84s - Train loss: 1.17e-03 - Val loss: 9.65e-04
[Train] Iteration  50 - loss: 1.17e-03
[Train] Iteration 100 - loss: 1.17e-03
[Train] Iteration 150 - loss: 1.16e-03
Epoch 182 finished in 32.86s - Train loss: 1.17e-03 - Val loss: 9.85e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.17e-03
[Train] Iteration 150 - loss: 1.17e-03
Epoch 183 finished in 32.77s - Train loss: 1.17e-03 - Val loss: 9.56e-04
[Train] Iteration  50 - loss: 1.18e-03
[Train] Iteration 100 - loss: 1.19e-03
[Train] Iteration 150 - loss: 1.18e-03
Epoch 184 finished in 32.89s - Train loss: 1.18e-03 - Val loss: 9.34e-04
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.16e-03
Epoch 185 finished in 32.80s - Train loss: 1.16e-03 - Val loss: 9.37e-04
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.15e-03
Epoch 186 finished in 32.88s - Train loss: 1.16e-03 - Val loss: 9.31e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.16e-03
Epoch 187 finished in 32.90s - Train loss: 1.16e-03 - Val loss: 9.18e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.15e-03
Epoch 188 finished in 32.83s - Train loss: 1.15e-03 - Val loss: 9.06e-04
[Train] Iteration  50 - loss: 1.16e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.15e-03
Epoch 189 finished in 32.92s - Train loss: 1.15e-03 - Val loss: 9.16e-04
[Train] Iteration  50 - loss: 1.15e-03
[Train] Iteration 100 - loss: 1.15e-03
[Train] Iteration 150 - loss: 1.15e-03
Epoch 190 finished in 32.90s - Train loss: 1.15e-03 - Val loss: 9.18e-04
[Train] Iteration  50 - loss: 1.11e-03
[Train] Iteration 100 - loss: 1.12e-03
[Train] Iteration 150 - loss: 1.12e-03
Epoch 191 finished in 32.82s - Train loss: 1.13e-03 - Val loss: 9.07e-04
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.14e-03
[Train] Iteration 150 - loss: 1.14e-03
Epoch 192 finished in 32.83s - Train loss: 1.14e-03 - Val loss: 9.29e-04
[Train] Iteration  50 - loss: 1.13e-03
[Train] Iteration 100 - loss: 1.13e-03
[Train] Iteration 150 - loss: 1.13e-03
Epoch 193 finished in 32.80s - Train loss: 1.13e-03 - Val loss: 9.19e-04
[Train] Iteration  50 - loss: 1.12e-03
[Train] Iteration 100 - loss: 1.12e-03
[Train] Iteration 150 - loss: 1.13e-03
Epoch 194 finished in 32.89s - Train loss: 1.13e-03 - Val loss: 9.63e-04
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.14e-03
[Train] Iteration 150 - loss: 1.14e-03
Epoch 195 finished in 32.81s - Train loss: 1.13e-03 - Val loss: 8.93e-04
[Train] Iteration  50 - loss: 1.11e-03
[Train] Iteration 100 - loss: 1.12e-03
[Train] Iteration 150 - loss: 1.12e-03
Epoch 196 finished in 32.89s - Train loss: 1.13e-03 - Val loss: 1.05e-03
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.13e-03
[Train] Iteration 150 - loss: 1.13e-03
Epoch 197 finished in 32.92s - Train loss: 1.13e-03 - Val loss: 8.90e-04
[Train] Iteration  50 - loss: 1.14e-03
[Train] Iteration 100 - loss: 1.13e-03
[Train] Iteration 150 - loss: 1.12e-03
Epoch 198 finished in 32.84s - Train loss: 1.13e-03 - Val loss: 9.25e-04
[Train] Iteration  50 - loss: 1.13e-03
[Train] Iteration 100 - loss: 1.12e-03
[Train] Iteration 150 - loss: 1.12e-03
Epoch 199 finished in 32.90s - Train loss: 1.12e-03 - Val loss: 8.92e-04
[Train] Iteration  50 - loss: 1.11e-03
[Train] Iteration 100 - loss: 1.11e-03
[Train] Iteration 150 - loss: 1.11e-03
Epoch 200 finished in 32.80s - Train loss: 1.11e-03 - Val loss: 1.04e-03
In [ ]:
fig = plt.figure(figsize=(7.5, 7))

# loss
ax = fig.add_subplot(111)
ax.set_title('Loss / Epoch')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')

plt.plot(train_loss, label='Train', color='green', linewidth=3)
plt.plot(val_loss, label='Validation', color='red', linewidth=3)
plt.ylim((0, 2e-2))


plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f06607c91d0>
In [ ]:
predictions, test_loss = evaluate_autoencoder(autoencoder, images_test, batch_size=batch_size, return_predictions=True)
test_loss
[Valid] Iteration  50 - loss: 1.04e-03
Out[ ]:
0.0010420581728050654
In [ ]:
rows = 8
columns = 10
indices = np.random.choice(np.arange(len(images_test)), size=rows * columns // 2)
images = np.zeros((rows * columns // 2, 2, 28, 28))
images[:,0] = images_test[indices].reshape(rows * columns // 2, 28, 28)
images[:,1] = predictions[indices].reshape(rows * columns // 2, 28, 28)
images = images.reshape(rows * columns, 28, 28) * 255

texts = np.zeros((rows * columns // 2, 2), dtype=object)
texts[:,0] = np.array(['real' for _ in range(rows * columns // 2)])
texts[:,1] = np.array(['reconstructed' for _ in range(rows * columns // 2)])
texts = texts.reshape(rows * columns)

draw(images, texts, columns, rows)

Train the conditional DCGAN to construct the right half of an image given its left half
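Training alternates standard non-saturating GAN steps with BCE targets: the generator is pushed toward label 1 on its completed images, while the discriminator sees real images labeled 1 and completed ones labeled 0. Schematically (a shape-free sketch of the two steps implemented in the next cells, with cat, ones, and zeros abbreviating the torch calls):

# generator step: make D accept the completed image as real
g_loss = criterion(D(cat(left, G(z, left))), ones)
# discriminator step: real images -> 1, completed images -> 0
d_loss = criterion(D(cat(real, completed)), cat(ones, zeros))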

In [ ]:
def generator_train_step(generator, discriminator, b_images, g_optimizer, criterion):
    g_optimizer.zero_grad()
    B = b_images.shape[0]

    z = torch.randn(B, 100, device='cuda')
    cuda_images = torch.tensor(b_images, device='cuda')

    left = cuda_images[:, :, :14]                                       # B 28  14
    right = cuda_images[:, :, 14:]                                      # B 28  14

    fake_right = generator(z, left)                                     # B 28  14
    validity = discriminator(torch.cat((left, fake_right), dim=2))      # B

    g_loss = criterion(validity, torch.ones(B, device='cuda'))
    g_loss.backward()
    g_optimizer.step()
    return float(g_loss)
In [ ]:
def discriminator_train_step(generator, discriminator, b_images, d_optimizer, criterion):
    d_optimizer.zero_grad()
    B = b_images.shape[0]

    # The first half of the batch is shown to D as-is ("real"); the second
    # half only supplies left halves for the generator to complete ("fake").
    real_images = torch.tensor(b_images[:B//2], device='cuda')                  # B/2   28  28
    condition_images = torch.tensor(b_images[B//2:], device='cuda')             # B/2   28  28

    z = torch.randn(condition_images.shape[0], 100, device='cuda')
    left = condition_images[:, :, :14]
    with torch.no_grad():
        fake_right = generator(z, left)                                         # B/2   28  14
    generated_images = torch.cat((left, fake_right), dim=2)                     # B/2   28  28

    discriminator_inputs = torch.cat((real_images, generated_images), dim=0)    # B     28  28
    expected = torch.cat((torch.ones(real_images.shape[0], device='cuda'), torch.zeros(generated_images.shape[0], device='cuda')), dim=0)   # B

    validity = discriminator(discriminator_inputs)                              # B
    d_loss = criterion(validity, expected)

    d_loss.backward()
    d_optimizer.step()

    return float(d_loss)
In [ ]:
def train_GAN(generator, discriminator, images, g_optimizer, d_optimizer, criterion, batch_size, discriminator_step):
    epoch_g_loss = 0
    g_count = 0
    epoch_d_loss = 0
    d_count = 0
    
    generator.train()
    discriminator.train()

    for iter, b_images in get_batches(images, batch_size=batch_size):
        if (iter + 1) % discriminator_step == 0:   # update D on every discriminator_step-th batch, G on the rest
            d_loss = discriminator_train_step(generator, discriminator, b_images, d_optimizer, criterion)
            d_count += 1
            epoch_d_loss += d_loss
        else:
            g_loss = generator_train_step(generator, discriminator, b_images, g_optimizer, criterion)
            g_count += 1
            epoch_g_loss += g_loss
    
    epoch_g_loss /= g_count
    epoch_d_loss /= d_count

    return epoch_g_loss, epoch_d_loss
In [ ]:
def display_GAN(generator, images, columns, rows):
    with torch.no_grad():
        generator.eval()
        for _, batch in get_batches(images, batch_size=columns * rows):
            z = torch.randn(columns * rows, 100, device='cuda')
            batch_images = torch.tensor(batch, device='cuda')
            left = batch_images[:, :, :14]
            right = generator(z, left)
            generated_images = torch.cat((left, right), dim=2).cpu().numpy() * 255      # row*col   28  28

            images_to_show = np.zeros((rows * columns, 2, 28, 28))
            images_to_show[:,0] = batch * 255       # scale real images to match the generated ones
            images_to_show[:,1] = generated_images
            images_to_show = images_to_show.reshape(rows * columns * 2, 28, 28)

            texts = np.zeros((rows * columns, 2), dtype=object)
            texts[:,0] = np.array(['real' for _ in range(rows * columns)])
            texts[:,1] = np.array(['generated' for _ in range(rows * columns)])
            texts = texts.reshape(rows * columns * 2)

            draw(images_to_show, texts, columns * 2, rows)
            break
In [ ]:
imageconv = ImageConv()
imageconv.load_state_dict(torch.load('drive/MyDrive/mnist-autoencoder-conv-model.pt'))
Out[ ]:
<All keys matched successfully>
In [ ]:
generator = Generator(imageconv).cuda()
discriminator = Discriminator().cuda()
In [ ]:
discriminator.conv.load_state_dict(torch.load('drive/MyDrive/mnist-autoencoder-conv-model.pt'))
Out[ ]:
<All keys matched successfully>
In [ ]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
In [ ]:
batch_size = 256
epochs = 20
display_step = 1
discriminator_step = 6

g_loss = []
d_loss = []
generator_best_loss = float('inf')
discriminator_best_loss = float('inf')

for e in range(epochs):

    start_time = time()

    epoch_g_loss, epoch_d_loss = train_GAN(generator, discriminator, images_train, g_optimizer, d_optimizer, criterion, batch_size, discriminator_step)
    
    end_time = time()
    
    torch.save(discriminator.state_dict(), f'drive/MyDrive/gan-discriminator-model-e{e:02d}.pt')
    torch.save(generator.state_dict(), f'drive/MyDrive/gan-generator-model-e{e:02d}.pt')

    print(f'Epoch {e+1:3} finished in {end_time - start_time:.2f}s - Generator loss: {epoch_g_loss:.2e} - Discriminator loss: {epoch_d_loss:.2e}')

    g_loss.append(epoch_g_loss)
    d_loss.append(epoch_d_loss)

    if e % display_step == 0:
        display_GAN(generator, images_val, 5, 3)
Epoch   1 finished in 56.48s - Generator loss: 6.71e-01 - Discriminator loss: 7.45e-01
Epoch   2 finished in 57.21s - Generator loss: 7.24e-01 - Discriminator loss: 6.96e-01
Epoch   3 finished in 57.12s - Generator loss: 7.17e-01 - Discriminator loss: 6.92e-01
Epoch   4 finished in 57.07s - Generator loss: 8.80e-01 - Discriminator loss: 6.19e-01
Epoch   5 finished in 57.10s - Generator loss: 6.93e-01 - Discriminator loss: 7.11e-01
Epoch   6 finished in 57.07s - Generator loss: 6.92e-01 - Discriminator loss: 6.95e-01
Epoch   7 finished in 57.13s - Generator loss: 6.79e-01 - Discriminator loss: 6.65e-01
Epoch   8 finished in 57.10s - Generator loss: 7.23e-01 - Discriminator loss: 7.03e-01
Epoch   9 finished in 57.05s - Generator loss: 7.20e-01 - Discriminator loss: 6.96e-01
Epoch  10 finished in 57.16s - Generator loss: 7.05e-01 - Discriminator loss: 6.84e-01
Epoch  11 finished in 57.09s - Generator loss: 7.36e-01 - Discriminator loss: 6.73e-01
Epoch  12 finished in 57.04s - Generator loss: 8.17e-01 - Discriminator loss: 6.46e-01
Epoch  13 finished in 57.09s - Generator loss: 7.34e-01 - Discriminator loss: 7.09e-01
Epoch  14 finished in 57.12s - Generator loss: 7.96e-01 - Discriminator loss: 6.44e-01
Epoch  15 finished in 57.10s - Generator loss: 1.47e+00 - Discriminator loss: 3.37e-01
Epoch  16 finished in 57.08s - Generator loss: 9.52e-01 - Discriminator loss: 7.08e-01
Epoch  17 finished in 57.05s - Generator loss: 9.96e-01 - Discriminator loss: 5.89e-01
Epoch  18 finished in 57.09s - Generator loss: 1.13e+00 - Discriminator loss: 6.15e-01
Epoch  19 finished in 57.06s - Generator loss: 9.68e-01 - Discriminator loss: 4.88e-01
Epoch  20 finished in 57.10s - Generator loss: 1.26e+00 - Discriminator loss: 6.88e-02

Plot Generator/Discriminator losses

In [ ]:
fig = plt.figure(figsize=(7.5, 7))

# loss
ax = fig.add_subplot(111)
ax.set_title('Loss / Epoch')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.set_aspect('auto')

plt.plot(g_loss, label='Generator', color='green', linewidth=3)
plt.plot(d_loss, label='Discriminator', color='red', linewidth=3)

plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f05fe43e4a8>

Draw examples

Draw real vs. generated images below, using the saved checkpoint from the best epoch (e09).

In [ ]:
generator.load_state_dict(torch.load('drive/MyDrive/gan-generator-model-e09.pt'))
Out[ ]:
<All keys matched successfully>
In [ ]:
display_GAN(generator, images_test, 5, 20)