Image Captioning: A Solution

Image captioning using a CNN to understand the image and an LSTM to generate a caption

In this notebook, you will implement an exciting task: writing captions for images with an intelligent agent!

We use the COCO dataset for this purpose. COCO is a large-scale object detection, segmentation, and captioning dataset. We also use the pycocotools library for some data-related work, so you should install it first. It may need dependencies that are not available on your machine, so we recommend running this notebook on Google Colab. If you do, upload data_related.py to the content folder on Colab.

In [ ]:
!pip install pycocotools
!pip install nltk
In [ ]:
import os

os.makedirs('opt', exist_ok=True)
os.chdir('opt')
!git clone 'https://github.com/cocodataset/cocoapi.git'

The following cell imports some data-related functions and takes about 10 minutes to run.

In [ ]:
import data_related
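
The rest of the notebook relies on only a few members of data_related (its training data loader, data_iter, and get_sentences). The cell below is an optional sanity check that prints some basic statistics; it assumes nothing beyond the attributes that later cells already access.

In [ ]:
# Optional: a quick look at the data prepared by data_related (uses only attributes referenced later in this notebook).
print('vocabulary size:', len(data_related.data_loader_train.dataset.vocab))
print('number of training captions:', len(data_related.data_loader_train.dataset.caption_lengths))
print('batch size:', data_related.data_loader_train.batch_sampler.batch_size)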

Your network should have two parts: a CNN for understanding the image and an LSTM for generating the corresponding sentence.

Model

In [ ]:
import torch
import torch.nn as nn
import torchvision.models as models
import math
In [ ]:
class Encoder(nn.Module):
    def __init__(self, embed_size):
        super(Encoder, self).__init__()
        # todo: Define a CNN with an extended fully-connected. Your output should be of the shape Batch_Size x embed_size.
        # Make sure that your model is strong enough to encode the image properly.
        #######################
        # Use a pretrained ResNet-50 as a frozen feature extractor.
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final classification layer and map the 2048-d feature vector to embed_size.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        
        #######################
    
    def forward(self, images):
        features = None
        #######################
        features = self.resnet(images)                   # (batch, 2048, 1, 1)
        features = features.view(features.size(0), -1)   # flatten to (batch, 2048)
        features = self.embed(features)                  # (batch, embed_size)

        #######################
        return features
In [ ]:
class Decoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size 
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        # todo: Define an embedding layer to transform inputs from "vocab_size" dim to "embed size" dim.
        #######################
        self.word_embedding = nn.Embedding(self.vocab_size, self.embed_size)

        #######################
        
        # todo: Define an LSTM decoder with input size, hidden size, and num layers specified in the input.  
        #######################
        self.lstm = nn.LSTM(input_size=self.embed_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)
        
        #######################
        
        # todo: Define a fully-connected layer to transform the output hidden size of LSTM to a "vocab_size" dim vector.
        #######################
        
        
        self.fc = nn.Linear(self.hidden_size, self.vocab_size)
        #######################
    def init_hidden(self, batch_size):
        # Zero-initialized hidden and cell states of shape (num_layers, batch_size, hidden_size).
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))
    
    def forward(self, features, captions):
        captions = captions[:, :-1]   # drop the last token; the image feature will occupy the first position
        self.batch_size = features.shape[0]
        self.hidden = self.init_hidden(self.batch_size)
        outputs = None

        # todo: Compute the output of the model.
        #######################
        embeds = self.word_embedding(captions)                            # (batch, seq_len - 1, embed_size)
        inputs = torch.cat((features.unsqueeze(dim=1), embeds), dim=1)    # prepend the image feature
        lstm_out, self.hidden = self.lstm(inputs, self.hidden)
        outputs = self.fc(lstm_out)                                       # (batch, seq_len, vocab_size)
        
        
        #######################
        return outputs

    def generate(self, inputs, max_len=20):
        final_output = []
        batch_size = inputs.shape[0]         
        hidden = self.init_hidden(batch_size)

        # todo: Pass the hidden state and the previously predicted token to the LSTM step by step; stop generating
        # when the sentence length reaches max_len or the EOS token (end of sentence, index 1) is produced.
        # Return only the predicted indices in final_output.
        #######################
        while True:
            lstm_out, hidden = self.lstm(inputs, hidden)
            outputs = self.fc(lstm_out)               # (batch, 1, vocab_size)
            outputs = outputs.squeeze(1)              # (batch, vocab_size)
            _, max_idx = torch.max(outputs, dim=1)    # greedy choice of the next token
            final_output.append(max_idx.item())
            if max_idx == 1 or len(final_output) >= max_len:
                break

            inputs = self.word_embedding(max_idx)     # (batch, embed_size)
            inputs = inputs.unsqueeze(1)              # (batch, 1, embed_size)
        
        
        
        #######################
        return final_output  
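
Before training on COCO, it can help to sanity-check the tensor shapes with tiny dummy inputs. The cell below is only an optional sketch: it builds throwaway Encoder/Decoder instances with made-up sizes, runs on the CPU, and sets the global device variable (which Decoder.init_hidden reads) ahead of the Train section.

In [ ]:
# Optional shape check with dummy data (a sketch, not part of the assignment).
device = torch.device('cpu')                      # Decoder.init_hidden reads this global; the Train section redefines it
enc_check = Encoder(embed_size=256)
dec_check = Decoder(embed_size=256, hidden_size=100, vocab_size=5000)

dummy_images = torch.randn(2, 3, 224, 224)        # a batch of 2 RGB images
dummy_captions = torch.randint(0, 5000, (2, 12))  # a batch of 2 captions of length 12

feats = enc_check(dummy_images)                   # expected shape: (2, 256)
logits = dec_check(feats, dummy_captions)         # expected shape: (2, 12, 5000)
print(feats.shape, logits.shape)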

Train

In [ ]:
embed_size = 256
hidden_size = 100
num_layers = 1
num_epochs = 4
print_every = 150
save_every = 1
vocab_size = len(data_related.data_loader_train.dataset.vocab)
total_step = math.ceil(len(data_related.data_loader_train.dataset.caption_lengths) / 
                       data_related.data_loader_train.batch_sampler.batch_size)
lr = 0.001
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_save_path = 'model_weights/'
os.makedirs(model_save_path, exist_ok=True)

encoder = Encoder(embed_size)
decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
In [ ]:
# todo: Define loss function and optimizer for encoder and decoder weights.
#######################
criterion = nn.CrossEntropyLoss()
# The ResNet backbone is frozen, so only the decoder and the encoder's embed layer are optimized.
all_params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params=all_params, lr=lr)
#######################
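
As an optional check, you can verify that the frozen ResNet backbone is excluded from the optimizer and that only the encoder's embed layer and the decoder will be updated:

In [ ]:
# Count trainable vs. frozen parameters (the backbone was frozen in Encoder.__init__).
num_trainable = sum(p.numel() for p in all_params if p.requires_grad)
num_frozen = sum(p.numel() for p in encoder.resnet.parameters())
print('trainable parameters: {:,} | frozen backbone parameters: {:,}'.format(num_trainable, num_frozen))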

The training process may take up to 10 hours, so save the model frequently. If training stops for any reason, continue from the last saved checkpoint (a minimal resume sketch follows the training loop below).

In [ ]:
import sys
encoder.train()
decoder.train()
encoder, decoder = encoder.to(device), decoder.to(device)
for e in range(1, num_epochs + 1):
    for step in range(total_step):
        # Re-sample the training indices and plug them into the loader's sampler.
        indices = data_related.data_loader_train.dataset.get_train_indices()
        new_sampler = data_related.data.sampler.SubsetRandomSampler(indices)
        data_related.data_loader_train.batch_sampler.sampler = new_sampler
        images, captions = next(iter(data_related.data_loader_train))
        images, captions = images.to(device), captions.to(device)
        encoder.zero_grad()
        decoder.zero_grad()
        # todo: Compute output and loss.
        #######################
        features = encoder(images)
        output = decoder(features, captions)                        # (batch, seq_len, vocab_size)
        loss = criterion(output.view(-1, vocab_size), captions.view(-1))
        #######################
        loss.backward()
        optimizer.step()
        stat_vals = 'Epoch [%d/%d] Step [%d/%d] Loss [%.4f]' % (e, num_epochs, step, total_step, loss.item())
        if step % print_every == 0:
            print(stat_vals)
            sys.stdout.flush()
            torch.save(encoder.state_dict(), os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(e)))
            torch.save(decoder.state_dict(), os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(e)))
    if e % save_every == 0:
        torch.save(encoder.state_dict(), os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(e)))
        torch.save(decoder.state_dict(), os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(e)))
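
If a run is interrupted, you can resume from the most recent checkpoint instead of retraining from scratch. The cell below is a minimal sketch under the assumption that the files saved above exist in model_save_path; set resume_epoch to the last epoch that was written to disk.

In [ ]:
resume_epoch = 1  # set this to the last saved epoch
encoder.load_state_dict(torch.load(os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(resume_epoch)), map_location=device))
decoder.load_state_dict(torch.load(os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(resume_epoch)), map_location=device))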

Test

In [ ]:
encoder.to(device)
decoder.to(device)
encoder.eval()
decoder.eval()
original_img, processed_img = next(data_related.data_iter)

with torch.no_grad():
    features = encoder(processed_img.to(device)).unsqueeze(1)   # (1, 1, embed_size)
    final_output = decoder.generate(features, max_len=20)
data_related.get_sentences(original_img, final_output)