Image Captioning: A Solution
Image captioning using a CNN to understand the image and an LSTM to generate a caption
In this notebook, you will implement an exciting task: writing captions for images with an intelligent agent!
We use the COCO dataset for this purpose. COCO is a large-scale object detection, segmentation, and captioning dataset. We also use the pycocotools library for some data-related work, so you should install it first. It may require dependencies that are not available on your PC, so we recommend running this notebook on Google Colab. If you do so, upload data_related.py to the content folder on Colab.
!pip install pycocotools
!pip install nltk
import os
os.makedirs('opt', exist_ok=True)
os.chdir('opt')
!git clone 'https://github.com/cocodataset/cocoapi.git'
The following command imports some data-related functions; it takes about 10 minutes to run.
import data_related
Your network should have two parts: a CNN for understanding the image and an LSTM for generating the corresponding sentence.
import torch
import torch.nn as nn
import torchvision.models as models
import math
class Encoder(nn.Module):
    def __init__(self, embed_size):
        super(Encoder, self).__init__()
        # todo: Define a CNN extended with a fully-connected layer. Your output should be of the shape Batch_Size x embed_size.
        # Make sure that your model is strong enough to encode the image properly.
        #######################
        # Pretrained ResNet-50 backbone with frozen weights; only the final
        # linear projection to embed_size is trainable.
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)
        modules = list(resnet.children())[:-1]  # drop the original classification head
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        #######################

    def forward(self, images):
        features = None
        #######################
        features = self.resnet(images)                   # Batch_Size x 2048 x 1 x 1
        features = features.view(features.size(0), -1)   # flatten to Batch_Size x 2048
        features = self.embed(features)                  # Batch_Size x embed_size
        #######################
        return features
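As a quick optional check (a small sketch; the 224x224 input size and embed_size of 256 are assumptions for illustration), the encoder should map a batch of images to features of shape Batch_Size x embed_size:
toy_encoder = Encoder(embed_size=256)
toy_images = torch.randn(2, 3, 224, 224)  # two fake RGB images
print(toy_encoder(toy_images).shape)      # expected: torch.Size([2, 256])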
class Decoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        # todo: Define an embedding layer to transform inputs from "vocab_size" dim to "embed_size" dim.
        #######################
        self.word_embedding = nn.Embedding(self.vocab_size, self.embed_size)
        #######################
        # todo: Define an LSTM decoder with the input size, hidden size, and num layers specified in the arguments.
        #######################
        self.lstm = nn.LSTM(input_size=self.embed_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)
        #######################
        # todo: Define a fully-connected layer to transform the LSTM hidden state into a "vocab_size"-dim vector.
        #######################
        self.fc = nn.Linear(self.hidden_size, self.vocab_size)
        #######################

    def init_hidden(self, batch_size):
        # Zero-initialized hidden and cell states for the LSTM.
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))

    def forward(self, features, captions):
        captions = captions[:, :-1]  # the last token is never fed as input
        self.batch_size = features.shape[0]
        self.hidden = self.init_hidden(self.batch_size)
        outputs = None
        # todo: Compute the output of the model.
        #######################
        embeds = self.word_embedding(captions)
        # Prepend the image features as the first "word" of the input sequence.
        inputs = torch.cat((features.unsqueeze(dim=1), embeds), dim=1)
        lstm_out, self.hidden = self.lstm(inputs, self.hidden)
        outputs = self.fc(lstm_out)
        #######################
        return outputs

    def generate(self, inputs, max_len=20):
        final_output = []
        batch_size = inputs.shape[0]
        hidden = self.init_hidden(batch_size)
        # todo: Pass the hidden state and the previous word to the LSTM step by step, and stop generating when
        # the sentence length exceeds max_len or the EOS token (end of sentence, index 1) occurs.
        # Just return the indexes in final_output.
        #######################
        while True:
            lstm_out, hidden = self.lstm(inputs, hidden)
            outputs = self.fc(lstm_out)
            outputs = outputs.squeeze(1)
            _, max_idx = torch.max(outputs, dim=1)  # greedy choice of the next word
            final_output.append(max_idx.cpu().numpy()[0].item())
            if max_idx == 1 or len(final_output) >= max_len:  # EOS token or length limit reached
                break
            # Feed the predicted word back in as the next input.
            inputs = self.word_embedding(max_idx)
            inputs = inputs.unsqueeze(1)
        #######################
        return final_output
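Similarly, a minimal optional check of the decoder with made-up sizes (note that init_hidden reads the global device, which is also set in the hyperparameter cell below): the output should have shape Batch_Size x caption_length x vocab_size.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # init_hidden needs this global
toy_decoder = Decoder(embed_size=256, hidden_size=100, vocab_size=50).to(device)
toy_features = torch.randn(2, 256).to(device)            # fake image features
toy_captions = torch.randint(0, 50, (2, 12)).to(device)  # fake caption word indices
print(toy_decoder(toy_features, toy_captions).shape)     # expected: torch.Size([2, 12, 50])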
embed_size = 256
hidden_size = 100
num_layers = 1
num_epochs = 4
print_every = 150
save_every = 1
vocab_size = len(data_related.data_loader_train.dataset.vocab)
total_step = math.ceil(len(data_related.data_loader_train.dataset.caption_lengths) /
                       data_related.data_loader_train.batch_sampler.batch_size)
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_save_path = 'model_weights/'
os.makedirs(model_save_path, exist_ok=True)
encoder = Encoder(embed_size)
decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
# todo: Define loss function and optimizer for encoder and decoder weights.
#######################
criterion = nn.CrossEntropyLoss()
# The ResNet backbone is frozen, so only the decoder and the encoder's final
# linear (embed) layer are passed to the optimizer.
all_params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params=all_params, lr=lr)
#######################
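As an optional sanity check (a small sketch using only the names defined above, nothing the assignment asks for), you can print how many parameters will actually be updated versus how many stay frozen:
n_trainable = sum(p.numel() for p in all_params if p.requires_grad)
n_frozen = sum(p.numel() for p in encoder.resnet.parameters())
print('trainable parameters: %d | frozen ResNet parameters: %d' % (n_trainable, n_frozen))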
The training process may take up to 10 hours, so save the model frequently. If training stops for some reason, continue from the last saved checkpoint.
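A minimal resume sketch, assuming checkpoints with the naming scheme used below already exist in model_save_path (resume_epoch is a hypothetical switch introduced here for illustration only):
resume_epoch = 0  # hypothetical: set to the epoch number of the last saved checkpoint, or 0 to start fresh
if resume_epoch > 0:
    encoder.load_state_dict(torch.load(os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(resume_epoch)),
                                       map_location=device))
    decoder.load_state_dict(torch.load(os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(resume_epoch)),
                                       map_location=device))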
import sys
encoder.train()
decoder.train()
for e in range(num_epochs):
    for step in range(total_step):
        # Randomly sample a batch of captions that all have the same length.
        indices = data_related.data_loader_train.dataset.get_train_indices()
        new_sampler = data_related.data.sampler.SubsetRandomSampler(indices)
        data_related.data_loader_train.batch_sampler.sampler = new_sampler
        images, captions = next(iter(data_related.data_loader_train))
        images, captions = images.to(device), captions.to(device)
        encoder, decoder = encoder.to(device), decoder.to(device)
        encoder.zero_grad()
        decoder.zero_grad()
        # todo: Compute output and loss.
        #######################
        features = encoder(images)
        output = decoder(features, captions)
        loss = criterion(output.view(-1, vocab_size), captions.view(-1))
        #######################
        loss.backward()
        optimizer.step()
        stat_vals = 'Epochs [%d/%d] Step [%d/%d] Loss [%.4f] ' % (e + 1, num_epochs, step, total_step, loss.item())
        if step % print_every == 0:
            print(stat_vals)
            sys.stdout.flush()
            # Save intermediate weights so training can be resumed if it stops.
            torch.save(encoder.state_dict(), os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(e + 1)))
            torch.save(decoder.state_dict(), os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(e + 1)))
    if e % save_every == 0:
        torch.save(encoder.state_dict(), os.path.join(model_save_path, 'encoderdata_{}.pkl'.format(e + 1)))
        torch.save(decoder.state_dict(), os.path.join(model_save_path, 'decoderdata_{}.pkl'.format(e + 1)))
encoder.to(device)
decoder.to(device)
encoder.eval()
decoder.eval()
# Caption one image from the validation iterator.
original_img, processed_img = next(data_related.data_iter)
features = encoder(processed_img.to(device)).unsqueeze(1)
final_output = decoder.generate(features, max_len=20)
data_related.get_sentences(original_img, final_output)