:capital_abcd: Modelling Context in Word Embeddings

Research project on word embeddings for text classification

A word embedding is a parameterized function which maps words in some language to high-dimensional vectors. Converting words to such vector embeddings before passing them into deep neural networks has proved to be a highly effective technique for text classification tasks, and is, in my opinion, one of the most fascinating topics in NLP research.

Presented here is a method to modify the word embeddings of a word in a sentence with its surrounding context using a bidirectional Recurrent Neural Network (RNN). The hypothesis is that these modified embeddings are a better input for performing text classification tasks like sentiment analysis or polarity detection.

The intuitive explanation

Given the embeddings for all the words in a sentence like ‘the quick brown fox jumps over the lazy dog’, the proposed model modifies the existing embedding for ‘fox’ to incorporate information about it being ‘quick’ and ‘brown’, and the fact that it ‘jumps over’ the ‘dog’ (whose updated embedding now reflects that it’s ‘lazy’ and was ‘jumped over’ by the ‘fox’).

Applied in combination with pre-trained word embeddings (like word2vec or GloVe) which encode global syntactic and semantic information about words such as ‘fox’ and ‘dog’, the method adds local context to these embeddings based on surrounding words. The new embeddings can then be fed into pipelines for text classification or be trained further in an end-to-end fashion.
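As a minimal sketch of this first stage, a pre-trained embedding table can be treated as a matrix indexed by word id. The vocabulary and random matrix below are hypothetical stand-ins; in practice the rows would be loaded from word2vec or GloVe:

```python
import numpy as np

# Hypothetical toy vocabulary; real rows would come from word2vec/GloVe.
vocab = {"the": 0, "quick": 1, "brown": 2, "fox": 3}
embedding_dim = 4
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((len(vocab), embedding_dim))

def embed(sentence):
    """Map a list of tokens to a (seq_len, embedding_dim) array."""
    ids = [vocab[w] for w in sentence]
    return embeddings[ids]

x = embed(["the", "quick", "brown", "fox"])
print(x.shape)  # (4, 4)
```

The resulting (seq_len, embedding_dim) array is exactly the input the bidirectional RNN layer below expects.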

Bi-directional LSTM model

Bidirectional RNN layer

Given the word embeddings for each word in a sentence/sequence of words, the sequence can be represented as a 2-D tensor of shape (seq_len, embedding_dim). The following steps can be performed to add information about the surrounding words to each embedding:

  1. Pass the embedding of each word sequentially into a forward-directional RNN (fRNN). For each sequential timestep, we obtain the hidden state of the fRNN, a tensor of shape (hidden_size). The hidden state encodes information about the current word and all the words previously encountered in the sequence. Our final output from the fRNN is a 2-D tensor of shape (seq_len, hidden_size).

  2. Pass the embedding of each word sequentially (after reversing the sequence of words) into a backward-directional RNN (bRNN). For each sequential timestep, we again obtain the hidden state of the bRNN, a tensor of shape (hidden_size). This time, the hidden state encodes information about the current word and all the words that follow it in the original sentence. Our output is a 2-D tensor of shape (seq_len, hidden_size). This output is reversed again to obtain the final output of the bRNN.

  3. Concatenate the fRNN and bRNN outputs at each of the seq_len timesteps. The final output is a 2-D tensor of shape (seq_len, 2 * hidden_size).

The fRNN and bRNN together form a bidirectional RNN. The difference between the final outputs of fRNN and bRNN is that at each timestep they are encoding information about two different sub-sequences (which are formed by splitting the sequence at the word at that timestep). Concatenating these outputs at each timestep results in a tensor encoding information about the word at that timestep and all the words in the sequence to its left and right.
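The three steps above can be sketched in a few lines of numpy. For brevity this uses a vanilla tanh RNN cell rather than the LSTM cells described next, and all weights are random placeholders:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, embedding_dim, hidden_size = 9, 8, 6

# Hypothetical sentence embeddings and randomly initialised RNN weights.
x = rng.standard_normal((seq_len, embedding_dim))
W_f = rng.standard_normal((embedding_dim, hidden_size)) * 0.1
U_f = rng.standard_normal((hidden_size, hidden_size)) * 0.1
W_b = rng.standard_normal((embedding_dim, hidden_size)) * 0.1
U_b = rng.standard_normal((hidden_size, hidden_size)) * 0.1

def run_rnn(inputs, W, U):
    h = np.zeros(hidden_size)
    states = []
    for x_t in inputs:                   # one timestep per word
        h = np.tanh(x_t @ W + h @ U)     # hidden state update
        states.append(h)
    return np.stack(states)              # (seq_len, hidden_size)

h_f = run_rnn(x, W_f, U_f)               # step 1: forward pass
h_b = run_rnn(x[::-1], W_b, U_b)[::-1]   # step 2: reversed pass, re-reversed
out = np.concatenate([h_f, h_b], axis=1) # step 3: concatenate per timestep
print(out.shape)  # (9, 12)
```

Note that the output width is 2 * hidden_size: each row now summarises the word itself plus everything to its left (from h_f) and everything to its right (from h_b).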

The cells used in the RNNs are Long Short-Term Memory (LSTM) cells, which are better at capturing long-term dependencies than vanilla RNN cells. This ensures our model doesn’t just consider the nearest neighbours while modifying a word’s embedding.
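For reference, a single LSTM timestep looks like the sketch below (a standard formulation, not the exact code from this repository). The gates let the cell state carry information across many timesteps, which is what makes distant words influence an embedding:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h, c, W, U, b):
    """One LSTM timestep: gates decide what to forget, write and expose."""
    z = x_t @ W + h @ U + b                       # all four gate pre-activations
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input, forget, output gates
    g = np.tanh(g)                                # candidate cell update
    c = f * c + i * g                             # long-term cell state
    h = o * np.tanh(c)                            # hidden state (the output)
    return h, c

# Hypothetical sizes matching the shapes used in the text.
embedding_dim, hidden_size = 8, 6
rng = np.random.default_rng(0)
W = rng.standard_normal((embedding_dim, 4 * hidden_size)) * 0.1
U = rng.standard_normal((hidden_size, 4 * hidden_size)) * 0.1
b = np.zeros(4 * hidden_size)

h = c = np.zeros(hidden_size)
h, c = lstm_step(rng.standard_normal(embedding_dim), h, c, W, U, b)
print(h.shape)  # (6,)
```

The forget gate f is the key piece: when it stays close to 1, the cell state c persists almost unchanged, so information from the start of the sentence can survive to the end.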


This repository contains the code implementing the proposed model as a pre-processing layer before feeding it into a Convolutional Neural Network for Sentence Classification (Kim, 2014). Training happens end-to-end in a supervised manner: the RNN layer is simply inserted as part of the existing model’s architecture for text classification.

The code is built on top of Denny Britz’s implementation of Kim’s CNN, and also allows loading pre-trained word2vec embeddings. For details and instructions on usage, please visit the repository.

Experiments on text classification

The following models were trained and evaluated on the IMDb Movie Reviews Dataset by UMontreal (for detecting the polarity of reviews):

  1. Yoon Kim’s baseline CNN model without the RNN layer, embedding_dim = 128, num_filters = 128 [ORANGE]
  2. The proposed model, embedding_dim = 128, rnn_hidden_size = 128, num_filters = 128 [PURPLE]
  3. The proposed model with more capacity, embedding_dim = 300, rnn_hidden_size = 300, num_filters = 150 [BLUE]

All models were trained with the following hyperparameters using the Adam optimizer: num_epochs = 100, batch_size = 32, learning_rate = 0.001. Ten percent of the data was held out for validation.


Training accuracy
Training loss

It’s clear that training converges for all three models before 100 epochs.

Validation accuracy
Validation loss

Ideas and Next Steps

A note on novelty…

I had this idea when I had just started studying Deep Learning and NLP. I started working on the code in July 2016 and learnt much later that the technique I had come up with was already very popular in neural machine translation. It’s part of the architecture of Google Translate’s new Zero-Shot Multilingual Translation system (see Section 3.2 of this paper) and was first proposed by Bahdanau et al. in 2014 (see Section 3).

Quoting the Bahdanau et al. paper:

We obtain an annotation for each word by concatenating the forward and backward states of a bidirectional RNN. The annotation contains the summaries of both the preceding words and the following words. Due to the tendency of RNNs to better represent recent inputs, the annotation will contain information about the whole input sequence with a strong focus on the parts surrounding the annotated word.

Google Translate is probably the best example of an amazing product which is powered by neural networks, and I’m happy I was thinking along the right path.

Chaitanya K. Joshi

Research Engineer at A*STAR, Singapore
