Implementation of SimCSE for the unsupervised approach in PyTorch

Bhuvana Kundumani
Nov 30, 2021


Hi everyone! Welcome to my blog. In this post, I am going to show a simple implementation of SimCSE: Simple Contrastive Learning of Sentence Embeddings for the unsupervised approach. In SimCSE, the authors use a simple contrastive learning framework to generate state-of-the-art sentence embeddings. In the supervised approach, natural language inference datasets are used, where the ‘entailment’ pairs serve as positives and the ‘contradiction’ pairs serve as hard negatives. In this blog, we will discuss in detail only the implementation of the unsupervised approach. In the unsupervised approach, since we do not have labelled inputs, positive pairs are generated by passing the same sentence through the encoder twice and letting dropout produce two slightly different representations. The contrastive objective then trains the model to identify, for a given input sentence, its own dropout-augmented copy among the other sentences in the batch. The code for this implementation is available at https://github.com/bhuvanakundumani/SimCSE_unsupervised.git

Sentence Embeddings: Sentence embeddings represent sentences, i.e. pieces of text, numerically as vectors. These embeddings, which map sentences into a vector space, are used for semantic search, clustering and semantic similarity comparisons. The idea is that the embeddings of semantically similar sentences lie close to each other in that space. To get good performance from our machine learning models, it is very important to generate high-quality sentence embeddings, and generating embeddings that capture the semantics of a given text is a challenging task.

Before getting into the implementation of SimCSE for generating sentence embeddings, let us see how we can generate sentence embeddings using BERT. One simple approach is to take the token embeddings produced by BERT and average them to derive a sentence embedding. Refer to BERT_sentenceembedding.ipynb in the GitHub repo for more details. However, the authors of Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks show in their experiments that sentence embeddings obtained by averaging BERT outputs perform worse than averaging GloVe embeddings.
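For illustration, here is a minimal sketch of one way to do this averaging (mask-aware mean pooling over BERT's token embeddings); the example sentences are arbitrary and the notebook in the repo may differ in the details:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

sentences = ["The cat sits on the mat.", "A dog is playing in the park."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    last_hidden = model(**batch).last_hidden_state      # (batch, seq_len, 768)

# Mask-aware mean pooling: ignore the [PAD] positions when averaging.
mask = batch["attention_mask"].unsqueeze(-1).float()    # (batch, seq_len, 1)
sentence_embeddings = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(sentence_embeddings.shape)                        # torch.Size([2, 768])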

Now, let us go into the details of implementing the unsupervised approach using contrastive learning in PyTorch.

Unsupervised approach:

In the unsupervised approach, contrastive learning is used: semantically similar sentences are pulled close together while dissimilar sentences are pushed apart. The dropout layers in the BERT architecture produce two slightly different representations for the same input sentence, so BERT’s dropout is used to generate positive pairs from the unlabelled data. The other sentences in the same batch act as in-batch negatives, and cross entropy loss over the similarity scores is used as the training objective.
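For reference, the training objective from the SimCSE paper for the i-th sentence in a batch of N sentences is:

\ell_i = -\log \frac{\exp(\mathrm{sim}(h_i, h_i^{+}) / \tau)}{\sum_{j=1}^{N} \exp(\mathrm{sim}(h_i, h_j^{+}) / \tau)}

where h_i and h_i^+ are the two dropout-induced representations of the i-th sentence, sim(·,·) is cosine similarity and τ is a temperature hyperparameter.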

Dataset:

For training, we are going to use one million sentences from English Wikipedia. The data can be downloaded from here. Let us take a look at the first few lines of the dataset.
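For example, assuming the downloaded file is saved as wiki1m_for_simcse.txt (the name here is just a placeholder for wherever you stored it), the first few sentences can be printed as follows:

# Peek at the first few sentences of the Wikipedia dump.
with open("wiki1m_for_simcse.txt", encoding="utf-8") as f:
    for _, line in zip(range(5), f):
        print(line.strip())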

Preprocessing Dataset:

Let us look at the wikiDataset class, which loads the English Wikipedia data as a csv file using Pandas and splits it into train and test sets. In our implementation, we take the full data for training, so we set the parameter full=True while creating the dataset.
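A minimal sketch of what such a dataset class could look like is shown below; apart from the full flag, the parameter names, the read_csv options and the 90/10 split ratio are my assumptions, and the exact code is in the repo:

import pandas as pd
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class wikiDataset(Dataset):
    def __init__(self, csv_path, training=True, full=False):
        # One sentence per row; the exact read options may differ from the repo.
        dataset_df = pd.read_csv(csv_path, sep="\t", names=["text"],
                                 quoting=3, on_bad_lines="skip")
        texts = dataset_df["text"].astype(str).tolist()
        if full:
            self.data = texts
        else:
            train_data, test_data = train_test_split(texts, test_size=0.1, random_state=42)
            self.data = train_data if training else test_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Source text and target text are the same sentence; the two different
        # representations come later from BERT's dropout.
        return self.data[idx], self.data[idx]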

In the unsupervised approach, our source_texts and target_texts are the same. When they are passed into the BERT encoder, due to dropout we get two different representations. The following screenshot shows the zipped source_texts and target_texts.

The function process_batch tokenizes the txt_list (source texts and target texts zipped together) using the tokenizer and returns the input_ids, attention_mask and token_type_ids.

For our understanding, let us consider the inputs shown in img3 for a batch_size of 2. In the process_batch function, we separate the source texts (source_ls) and the target texts (target_ls). We tokenize the source and target texts one at a time and store the results in the input_ids, token_type_ids and attention_mask lists. Please note that input_ids[0] contains the tokenized input for the first source text and input_ids[1] contains the tokenized input for its target text. It is important to understand how the inputs are preprocessed before they are sent to the BERT model: img4 shows the input_ids for a batch_size of 2, and the length of input_ids, token_type_ids and attention_mask is twice the batch_size.
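A minimal sketch of what process_batch could look like under this description (the max_len, padding strategy and exact return format are my guesses; the actual code is in the repo):

import torch

def process_batch(txt_list, tokenizer, max_len=64):
    # txt_list is a list of (source_text, target_text) pairs from wikiDataset.
    source_ls = [pair[0] for pair in txt_list]
    target_ls = [pair[1] for pair in txt_list]

    input_ids, attention_mask, token_type_ids = [], [], []
    for source, target in zip(source_ls, target_ls):
        # Interleave source and target so that rows (2i, 2i+1) belong to pair i;
        # the view(batch_size, 2, hidden) reshape in the model relies on this order.
        for text in (source, target):
            enc = tokenizer(text, padding="max_length", truncation=True, max_length=max_len)
            input_ids.append(enc["input_ids"])
            attention_mask.append(enc["attention_mask"])
            token_type_ids.append(enc["token_type_ids"])

    # Each returned tensor has 2 * batch_size rows.
    return (torch.tensor(input_ids),
            torch.tensor(attention_mask),
            torch.tensor(token_type_ids))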

Data Loader:

We use PyTorch’s DataLoader for loading the training data in the train_dataloader function. This function uses a RandomSampler to sample the data and preprocesses each batch with the process_batch function.
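A sketch of how train_dataloader could wire these pieces together; the batch_size of 8 (chosen to match the tensor shapes discussed later) and the collate_fn wiring are assumptions:

from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer

def train_dataloader(dataset, tokenizer, batch_size=8):
    sampler = RandomSampler(dataset)            # shuffles the one million sentences
    return DataLoader(dataset,
                      sampler=sampler,
                      batch_size=batch_size,
                      collate_fn=lambda batch: process_batch(batch, tokenizer))

# Example usage with the dataset class sketched above.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = wikiDataset("wiki1m_for_simcse.txt", full=True)
train_loader = train_dataloader(dataset, tokenizer, batch_size=8)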

Optimizer and Scheduler:

Similar to the paper, we use the Adam optimizer with warmup steps and a linear learning rate scheduler.
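A sketch of this setup, assuming model is the BertForCL model discussed in the next section and train_loader is the dataloader from the previous step; the learning rate, warmup ratio and number of epochs here are illustrative, not necessarily the repo's values, and AdamW is used in place of plain Adam, as is common for BERT fine-tuning:

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

num_epochs = 1                      # illustrative value
total_steps = len(train_loader) * num_epochs

optimizer = AdamW(model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=int(0.1 * total_steps),
                                            num_training_steps=total_steps)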

Model:

In the BertForCL class, we initialise the BERT model and call the cl_init function in the init function. We call cl_forward in the forward function for training the SimCSE model on the unsupervised data. So we will be discussing only the cl_init and cl_forward functions in detail in this blog.

The cl_init function is the init function for contrastive learning. In cl_init we can choose the pooler_type (‘cls’, ‘cls_before_pooler’, etc.). In this blog, we will be discussing the ‘cls’ pooler_type, where we use the [CLS] representation along with BERT/RoBERTa’s MLP pooler: a simple MLP layer acts as the head over BERT/RoBERTa’s [CLS] representation for getting sentence representations. The Similarity function is used to calculate the cosine similarity between the embeddings. Please refer to the repo for the code of the MLPLayer, Pooler and Similarity modules; they are simple and straightforward, so they are only sketched briefly below.
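Roughly, these three helpers look as follows (a sketch closely following the SimCSE code; the hidden size of 768 corresponds to bert-base, and 0.05 is the temperature used in the paper):

import torch.nn as nn

class MLPLayer(nn.Module):
    # Linear + Tanh head over the [CLS] representation, like BERT's original pooler.
    def __init__(self, hidden_size=768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, features):
        return self.activation(self.dense(features))

class Similarity(nn.Module):
    # Cosine similarity scaled by a temperature (0.05 in the paper).
    def __init__(self, temp=0.05):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp

class Pooler(nn.Module):
    # For pooler_type == "cls", take the last hidden state of the [CLS] token.
    def __init__(self, pooler_type="cls"):
        super().__init__()
        self.pooler_type = pooler_type

    def forward(self, attention_mask, outputs):
        if self.pooler_type in ("cls", "cls_before_pooler"):
            return outputs.last_hidden_state[:, 0]
        raise NotImplementedError(self.pooler_type)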

https://gist.github.com/bhuvanakundumani/dc933afece8dcebdcf841a73b84666e1

For training, we call the cl_forward function in BertForCL, which we will now discuss in detail. Let us look at the code snippet from cl_forward.

The BERT model is passed to the encoder variable in the cl_forward function. The inputs to the encoder are the input_ids, token_type_ids and attention_mask from the train_dataloader. For pooler_type == ‘cls’, the encoder outputs contain the ‘last_hidden_state’ and ‘pooler_output’. Based on the pooler type, the Pooler module returns the pooled output; in our case, since we are using ‘cls’, it returns last_hidden[:, 0] as the pooler_output with shape torch.Size([16, 768]).

As shown in img4, input_ids[0] holds the ids for the first source text and input_ids[1] holds the ids for its target text (source and target rows alternate). The pooler_output with shape torch.Size([16, 768]) is reshaped to torch.Size([8, 2, 768]) using this code,

pooler_output = pooler_output.view((batch_size, args.num_sent, pooler_output.size(-1))).

Please note that pooler_output[:, 0] holds the outputs for the source texts and pooler_output[:, 1] holds the outputs for the target texts.

Since we are using the ‘cls’ pooler type, an extra MLP layer (similar to BERT’s original pooler) is applied to the pooler_output. The output of the MLP layer is then separated into the source and target representations using the code,

z1, z2 = pooler_output[:,0], pooler_output[:,1]

Loss function and its parameters:

We calculate the cosine similarity between z1 and z2. Contrastive learning uses cross entropy loss as the loss function. Here cos_sim is a (batch_size x batch_size) matrix of similarity scores, and the label for row i is i: the diagonal entries (each sentence paired with its own dropout-augmented copy) are the correct class, while all the other entries in the row act as in-batch negatives. We compute the cross entropy loss over the cosine similarity scores and the labels using the code snippet shown below.

cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(cos_sim, labels)
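Putting the pieces together, a simplified sketch of cl_forward is shown below. The signature and return values are simplified here; the full version, which returns Hugging Face-style output objects, is in the gist linked above:

import torch
import torch.nn as nn

def cl_forward_sketch(cls, encoder, input_ids, attention_mask, token_type_ids,
                      batch_size, num_sent=2):
    # Run all 2 * batch_size sentences (source and target interleaved) through BERT.
    outputs = encoder(input_ids=input_ids,
                      attention_mask=attention_mask,
                      token_type_ids=token_type_ids,
                      return_dict=True)

    # pooler_type == "cls": take the [CLS] hidden state -> (2 * batch_size, 768)
    pooler_output = cls.pooler(attention_mask, outputs)

    # Group each source/target pair together -> (batch_size, 2, 768)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))

    # Extra MLP head over the [CLS] representation for the "cls" pooler type
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    # z1: source representations, z2: dropout-augmented target representations
    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

    # In-batch contrastive loss over the (batch_size x batch_size) similarity matrix
    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
    loss = nn.CrossEntropyLoss()(cos_sim, labels)
    return loss, cos_sim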

Train Loop:

The train loop is shown below. We use the train_dataloader to load the data and feed the processed inputs to the model. The loss from the model is back-propagated and the weights are updated using the optimizer and scheduler.
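A sketch of such a train loop, reusing the model, optimizer, scheduler and dataloader from the snippets above; gradient clipping and logging are omitted, and .loss assumes the model's forward returns an output object carrying the contrastive loss, as in the gist:

import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

for epoch in range(num_epochs):
    for batch in tqdm(train_loader):
        # batch = (input_ids, attention_mask, token_type_ids), each with 2 * batch_size rows
        input_ids, attention_mask, token_type_ids = (t.to(device) for t in batch)

        optimizer.zero_grad()
        loss = model(input_ids=input_ids,
                     attention_mask=attention_mask,
                     token_type_ids=token_type_ids).loss

        loss.backward()
        optimizer.step()
        scheduler.step()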

The model thus trained is saved along with the tokenizer and config details in args.output_dir (specified in the Arguments class) as shown below:

model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
config.save_pretrained(args.output_dir)

The model thus trained can be used for deriving sentence embeddings for our use cases, for example as sketched below.
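Here the saved encoder is loaded back with the transformers library and the [CLS] hidden state is used as the sentence embedding; the output directory name and the example sentences are placeholders:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("output/")
model = AutoModel.from_pretrained("output/")
model.eval()

sentences = ["A man is playing a guitar.", "Someone is playing an instrument."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    # Use the [CLS] token's last hidden state as the sentence embedding.
    embeddings = model(**batch).last_hidden_state[:, 0]

sim = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
print(sim.item())

Thanks for reading my blog!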

References:

https://arxiv.org/pdf/2104.08821.pdf

https://towardsdatascience.com/understanding-contrastive-learning-d5b19fd96607
