Jaewoo Song



(This post was modified on December 2nd, after the project was re-implemented, to avoid confusion caused by differences between the repository and the contents of the post.)

This project builds a multi-turn open-domain dialogue generation model by fine-tuning the pre-trained GPT-2 (Generative Pre-trained Transformer 2)[1].

In the last post, we found several limitations in the results from the Relevant Contexts with Self-attention (ReCoSa) model.

This time, I expect better outputs, since GPT-2 has been pre-trained extensively on language modeling.

Let’s start.

LM Head vs Double Heads

In the introduction post, I described the fine-tuning method that the Huggingface team applied[2].

They fine-tuned GPT-2 with multi-task learning: in addition to the original language modeling (LM) task, the model was trained on a classification task that determines whether a given response is appropriate.

This configuration is illustrated below.

The description of Huggingface's transfer learning structure using GPT-2 for ConvAI2.

As we can see, the model takes two inputs, the gold reply and a distractor, and classifies which one is the correct response.

With this multi-task setting, the model learns not only how to generate a response but also how to choose an appropriate, on-topic response by considering the dialogue context.

I adopted this method at first, but after an experiment, I dropped the classification task and switched to language modeling only.

The reasons I changed my mind are as follows.

  1. Increase in training time

    Including the additional distractor makes each batch larger, so I could not make the batch size sufficiently large. Even with only one distractor, I had to set the batch size to $2$ in my resource environment, and one epoch took about $32$ hours.

  2. Less meaningful classification training

    The Huggingface team used the PersonaChat data and took each distractor from the candidate responses included in the dataset itself. In my case, however, I combined various datasets, and it was difficult to build such candidate sets from them. So I randomly sampled an utterance from an entirely different dialogue and used it as the distractor. But I noticed that the loss for the multiple-choice classification hardly decreased during training. In my opinion, most of the sampled distractors are generic, which means that many context + distractor pairs make reasonable sense on their own. Of course, I could have looked for another solution, but because of reason #$1$ above, I stopped there.

So I decided to use GPT-2 with a single LM head instead of the double-heads version, fine-tuning it only on the response generation task.

Data pre-processing

Next, let’s talk about data processing.

Unlike the ReCoSa structure from last time, the GPT-2 approach concatenates the entire dialogue history into one sequence and gives it to the model, which then generates the proper response.

This is because, as I stated before, GPT-2 was pre-trained for unidirectional language modeling with the Transformer’s decoder layers.

That is, this fine-tuning approach is a natural fit: the model considers the overall context and generates a reply through next-word prediction, which is essentially the same as GPT-2’s pre-training objective.

First, following the Huggingface team’s idea, I added $3$ special tokens that are not included in the original GPT-2 vocabulary.

They help the model recognize the beginning of a sequence and distinguish each speaker’s utterances.

These are the special tokens I added.

  • bos: "<bos>" (the beginning of sentence token)
  • speaker1: "<speaker1>" (the first speaker token)
  • speaker2: "<speaker2>" (the second speaker token)

Since GPT-2 already has an end token (<|endoftext|>), I did not add another one.

I did not add a pad token either. Since the model only attends to tokens on the left of the current position, padding appended to the right of a sequence cannot change the representations of the real tokens before it.

In other words, since these padded positions do not affect the real tokens, and their labels are masked out of the loss, any token can be used for padding.
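To make the padding argument concrete, here is a minimal sketch of single-head causal self-attention in plain Python. It is a toy without learned projections, not GPT-2’s actual implementation, and the vectors are made up. It checks that appending arbitrary padding vectors to the right of a sequence leaves the outputs at the original positions unchanged.

```python
import math

def causal_self_attention(xs):
    """Toy single-head causal self-attention without learned projections.

    xs: list of vectors (lists of floats). Position t attends only to positions 0..t.
    """
    d = len(xs[0])
    outputs = []
    for t in range(len(xs)):
        # Scaled dot-product scores against positions 0..t only (the causal mask)
        scores = [sum(a * b for a, b in zip(xs[t], xs[j])) / math.sqrt(d) for j in range(t + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the visible vectors
        outputs.append([sum(w * xs[j][k] for j, w in enumerate(weights)) for k in range(d)])
    return outputs

tokens = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
pad_a = [[9.0, -3.0], [9.0, -3.0]]  # one arbitrary choice of padding vectors
pad_b = [[0.0, 0.0], [-7.0, 2.0]]   # a different arbitrary choice

out_plain = causal_self_attention(tokens)
out_pad_a = causal_self_attention(tokens + pad_a)
out_pad_b = causal_self_attention(tokens + pad_b)

# The first len(tokens) outputs are identical, whatever the padding contains
print(out_plain == out_pad_a[:3] == out_pad_b[:3])  # True
```

The outputs at the padded positions themselves are meaningless, which is why their labels are set to $-100$: that is the default `ignore_index` of PyTorch’s cross-entropy loss, so those positions do not contribute to the LM loss.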

Huggingface’s Transformers library lets users add new tokens to the tokenizer’s vocabulary and resize the model’s embedding layer accordingly.

If the new size is larger than the original vocabulary size, newly initialized vectors fill the last rows of the embedding lookup table.

This can be easily implemented as follows.

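```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Tokenizer & Model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# The dictionary for defining special tokens
special_tokens = {
    'bos_token': "<bos>",
    'additional_special_tokens': ["<speaker1>", "<speaker2>"]
}

num_new_tokens = tokenizer.add_special_tokens(special_tokens)
vocab = tokenizer.get_vocab()

# Enlarge the embedding lookup table so that the new token ids get their own vectors
model.resize_token_embeddings(len(tokenizer))

# Token ids used in the pre-processing code
bos_id = vocab["<bos>"]
sp1_id = vocab["<speaker1>"]
sp2_id = vocab["<speaker2>"]
eos_id = tokenizer.eos_token_id  # GPT-2's existing <|endoftext|> token
```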

Next, let’s look at the composition of the inputs and outputs included in each batch.

There are $3$ components: input_ids, token_type_ids, and labels.

  • input_ids: This is the main input and consists of token ids. During training, it is the entire sequence, with the response concatenated after the context. At inference time, only the system speaker token is appended after the dialogue history so far, and each generated token is appended in turn until the end token appears.
  • token_type_ids: This is an additional input that specifies the speaker of each segment in input_ids. It distinguishes the utterances of different turns and consists only of the ids of the speaker1 and speaker2 tokens.
  • labels: This is the gold reply to be generated. It is built by setting every position except the response (and the final end token) to the mask value, $-100$. You might think the labels should be shifted by one position relative to input_ids, but this is unnecessary, since the LM head models in the Transformers library shift them internally when computing the LM loss. Naturally, labels are not available at inference time.
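As a small illustration, the sketch below builds the three components for a toy two-turn dialogue. The token ids and the `BOS`, `EOS`, `SP1`, `SP2` constants are hypothetical placeholders, not GPT-2’s real ids. It then mimics the shift that the LM head model applies internally when computing the loss: the logits at position $t$ are compared with the label at position $t+1$.

```python
# Hypothetical token ids, for illustration only
BOS, EOS, SP1, SP2 = 100, 101, 102, 103

user_utter = [11, 12, 13]  # tokenized user utterance
system_utter = [21, 22]    # tokenized system response (the target)

context = [SP1] + user_utter
response = [SP2] + system_utter

input_ids = [BOS] + context + response + [EOS]
# <bos> and the context get the user's type; the response and <eos> get the system's type
token_type_ids = [SP1] + [SP1] * len(context) + [SP2] * len(response) + [SP2]
# Only the response tokens (after its speaker token) and <eos> are learned; the rest is -100
labels = [-100] + [-100] * len(context) + [-100] + system_utter + [EOS]

assert len(input_ids) == len(token_type_ids) == len(labels)

# The internal shift: position t is trained to predict labels[t + 1]
pairs = [(input_ids[t], labels[t + 1]) for t in range(len(input_ids) - 1) if labels[t + 1] != -100]
print(pairs)  # [(103, 21), (21, 22), (22, 101)]
```

So the system speaker token predicts the first response token, and the last response token predicts the end token, which is exactly the behavior needed at inference time.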

The details are shown in the figure below.

The details of data composition in GPT-2 fine-tuning for multi-turn dialogue generation.

One thing we should consider, however, is the maximum length.

The maximum input length is limited (1024 tokens by default for GPT-2), and since each input is built by concatenating all utterances, in some cases we cannot include every utterance within the specified number of turns, depending on how long each utterance is.

There are several ways to handle this issue. I chose to include as many of the most recent turns as possible within the pre-defined maximum number of turns, dropping the earliest utterances whenever the input sequence would exceed the maximum length.

In addition, I defined a collate function that pads each batch when it is loaded, so that every input in the batch has the same length as the longest one.

This can be implemented with PyTorch’s torch.nn.utils.rnn.pad_sequence function by passing the desired pad value as an argument. Since the padding is created dynamically for each batch, this reduces unnecessary memory usage.
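To give a rough sense of the saving, the sketch below counts how many token slots are allocated when every sequence is padded to a fixed maximum length versus only to the longest sequence in its batch. The sequence lengths and the batch size are made-up numbers, and the helper `padded_slots` is hypothetical.

```python
def padded_slots(lengths, batch_size, pad_to=None):
    """Count token slots after padding.

    If pad_to is given, every sequence is padded to that fixed length;
    otherwise each batch is padded to its own longest sequence (dynamic padding).
    """
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        width = pad_to if pad_to is not None else max(batch)
        total += width * len(batch)
    return total

# Hypothetical sequence lengths, sorted here for clarity
lengths = [40, 52, 61, 75, 120, 133, 410, 512]

static = padded_slots(lengths, batch_size=2, pad_to=1024)
dynamic = padded_slots(lengths, batch_size=2)
print(static, dynamic)  # 8192 1544
```

With shuffled batches the gap is usually smaller, since short and long sequences end up together, but dynamic padding never allocates more than padding everything to the maximum length.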

The data processing code is as follows.

```python
from itertools import chain

# Assumes dials, max_turns, max_len, and the token ids (bos_id, eos_id, sp1_id, sp2_id) are already defined.
all_input_ids = []  # (N, L)
all_token_type_ids = []  # (N, L)
all_labels = []  # (N, L)

# "dials" is a list of dialogues; each dialogue is a list of tokenized utterances (lists of token ids).
for dial in dials:
    hists = []
    for u, utter in enumerate(dial):
        if u % 2 == 0:
            hists.append([sp1_id] + utter)  # Speaker 1: User
        else:
            hists.append([sp2_id] + utter)  # Speaker 2: System

    for h in range(len(hists)):
        if hists[h][0] == sp2_id:  # Build one sample per system response
            start = max(0, h-max_turns+1)
            # Try the longest history first and drop the earliest turns until it fits
            for s in range(start, h):
                contexts = hists[s:h+1]
                input_ids = [bos_id] + list(chain.from_iterable(contexts)) + [eos_id]

                if len(input_ids) <= max_len:
                    start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
                    token_type_ids = [
                        [start_sp_id] * len(ctx) if c % 2 == 0 else [next_sp_id] * len(ctx)
                        for c, ctx in enumerate(contexts)
                    ]
                    assert token_type_ids[-1][0] == sp2_id
                    # <bos> takes the first speaker's type, <eos> takes the system's type
                    token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [sp2_id]
                    assert len(input_ids) == len(token_type_ids)

                    # Mask everything except the response tokens and the final <eos>
                    labels = [
                        [-100] * len(ctx) if c < len(contexts)-1 else [-100] + ctx[1:]
                        for c, ctx in enumerate(contexts)
                    ]
                    assert labels[-1][1:] == contexts[-1][1:]
                    labels = [-100] + list(chain.from_iterable(labels)) + [eos_id]
                    assert len(input_ids) == len(labels)

                    all_input_ids.append(input_ids)
                    all_token_type_ids.append(token_type_ids)
                    all_labels.append(labels)
                    break
```


The pad collate function can be implemented like this.

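```python
import torch

class PadCollate():
    def __init__(self, eos_id):
        self.eos_id = eos_id

    def pad_collate(self, batch):
        # Each item in the batch is an (input_ids, token_type_ids, labels) tuple of lists
        input_ids, token_type_ids, labels = [], [], []
        for seqs in batch:
            input_ids.append(torch.LongTensor(seqs[0]))
            token_type_ids.append(torch.LongTensor(seqs[1]))
            labels.append(torch.LongTensor(seqs[2]))

        # Pad every sequence to the longest one in this batch
        input_ids = torch.nn.utils.rnn.pad_sequence(
          input_ids, batch_first=True, padding_value=self.eos_id
        )
        token_type_ids = torch.nn.utils.rnn.pad_sequence(
          token_type_ids, batch_first=True, padding_value=self.eos_id
        )
        # -100 keeps the padded positions out of the LM loss
        labels = torch.nn.utils.rnn.pad_sequence(
          labels, batch_first=True, padding_value=-100
        )
        return input_ids, token_type_ids, labels
```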


Actually, training itself is not difficult.

It is not much different from previous implementations: after pre-processing the data as described above, we just feed the inputs to the GPT-2 LM head model.

One thing I added this time is tracking the “perplexity” along with the train/validation losses to evaluate the model during training.

The calculation is simple: the perplexity is obtained by applying the exponential function to the loss, i.e., using the loss as the exponent.

This becomes obvious once we recall the formula of the perplexity.

We covered this before, but let’s look at the definition and formula of perplexity again.

Perplexity is an evaluation metric for language models that indicates how well the model assigns high probabilities to the actual next tokens.

It is calculated by taking the reciprocal of the joint probability of the sequence and normalizing it by the sequence length, i.e., taking its $n$-th root.

\[PPL = \sqrt[n]{\frac{1}{P(w_1, w_2, \dots, w_n)}} = \sqrt[n]{\frac{1}{\prod_{i=1}^{n}P(w_i \mid w_1, w_2, \dots, w_{i-1})}}\]

As we can see, the higher the probability, the lower the perplexity, which means the LM performs better.

Then what is the relation between the perplexity and the loss function?

We can derive this easily, given that the loss function normally used for next-word prediction is the cross-entropy loss.

Given the input sequences and the labels, it computes the negative log-likelihood averaged over the sequence length.

The value inside this negative log is the joint probability of the sequence, built from per-token probabilities that have already passed through the softmax function.

So by using this loss as the exponent, the procedure becomes as follows.
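Assuming the loss $\mathcal{L}$ is the cross-entropy averaged over the $n$ target tokens, with the same symbols as the perplexity definition, the derivation can be written as:

\[
\begin{aligned}
\mathcal{L} &= -\frac{1}{n}\sum_{i=1}^{n}\log P(w_i \mid w_1, \dots, w_{i-1}) = -\frac{1}{n}\log P(w_1, w_2, \dots, w_n) \\
e^{\mathcal{L}} &= P(w_1, w_2, \dots, w_n)^{-\frac{1}{n}} = \sqrt[n]{\frac{1}{P(w_1, w_2, \dots, w_n)}} = PPL
\end{aligned}
\]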


Therefore, we can get the perplexity by applying the torch.exp() function to the loss.
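A quick numeric check of this equivalence in plain Python, using made-up per-token probabilities (in practice they come from the softmax over the model’s logits, and only the positions whose label is not $-100$ are counted): exponentiating the mean negative log-likelihood gives the same value as the $n$-th root of the inverse joint probability.

```python
import math

# Hypothetical probabilities the model assigns to each gold target token
token_probs = [0.5, 0.25, 0.8, 0.1]
n = len(token_probs)

# Cross-entropy loss: mean negative log-likelihood over the target tokens
loss = -sum(math.log(prob) for prob in token_probs) / n

# Perplexity by definition: n-th root of the reciprocal of the joint probability
joint_prob = math.prod(token_probs)
ppl_definition = (1.0 / joint_prob) ** (1.0 / n)

# Perplexity from the loss, which is what torch.exp(loss) computes on a tensor
ppl_from_loss = math.exp(loss)

print(round(ppl_definition, 4), round(ppl_from_loss, 4))  # 3.1623 3.1623
```

Here the joint probability is $0.01$, so both routes give $\sqrt[4]{100} = \sqrt{10}$.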

During training, the inputs to the model are input_ids, token_type_ids, and labels.

The GPT-2 LM head model returns an output tuple that contains the loss at index $0$ and the logits tensor at index $1$.

I trained the model for $10$ epochs and used TensorBoard to record the loss and perplexity after each epoch.

The training perplexities were also recorded, but some values were extremely high because the model had not been tuned yet at the beginning, so I only present the validation perplexity here.

Since an underfitted model, or a few sequences with relatively high perplexity, can distort the overall average, I concluded that presenting the training perplexities would not be very meaningful.

Let’s see the below charts.

The TensorBoard charts of losses & perplexity.

As we can see, the validation loss and perplexity stop decreasing at a certain point, which can be considered the optimal point.

Now that training is finished, let’s look at the inference results from actual conversations between me and the model.


As before, I used Nucleus Sampling (Top-$p$ Sampling) as the decoding algorithm.

At inference time, the labels argument is not provided, so only input_ids and token_type_ids are fed into the model.

The output is also different: since no loss is computed, the logits are at index $0$.

After applying softmax to the logits at the target position, I sampled the next word with Nucleus Sampling.

I modified my Nucleus Sampling implementation, so here is the updated version.

The code is now cleaner, and it also handles an edge case: the word with the highest probability is always kept, so the probabilities can never all be set to $0$, which would otherwise happen when the top probability alone exceeds $p$.

I added this part by referring to the implementation by Thomas Wolf, the science lead at Huggingface, Inc[3].

```python
import torch
from torch.nn import functional as F

# Assumes model, p (the top-p threshold), max_len, eos_id, and sp2_id are defined beforehand.
def nucleus_sampling(input_ids, token_type_ids, input_len):
    output_ids = []
    for pos in range(input_len, max_len):
        with torch.no_grad():
            output = model(
                input_ids=input_ids, token_type_ids=token_type_ids
            )[0][:, pos-1]  # (1, V): logits for the next position
        output = F.softmax(output, dim=-1)  # (1, V)

        sorted_probs, sorted_idxs = torch.sort(output, descending=True)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
        idx_remove = cumsum_probs > p
        # Shift right so the token that crosses p is kept, and always keep the top-1 token
        idx_remove[:, 1:] = idx_remove[:, :-1].clone()
        idx_remove[:, 0] = False
        sorted_probs[idx_remove] = 0.0
        sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

        # Map the filtered probabilities back to the original vocabulary order
        probs = torch.zeros(output.shape, device=output.device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
        idx = torch.multinomial(probs, 1)  # (1, 1)

        idx_item = idx.squeeze(-1).squeeze(-1).item()
        if idx_item == eos_id:
            break

        input_ids = torch.cat((input_ids, idx), dim=-1)
        next_type_id = torch.LongTensor([[sp2_id]]).to(input_ids.device)
        token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
        assert input_ids.shape == token_type_ids.shape

        output_ids.append(idx_item)

    return output_ids
```
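To check the filtering logic in isolation, the sketch below reimplements only the top-$p$ filtering step in plain Python on toy distributions; the function name `top_p_filter` is hypothetical and no model is involved. It keeps the highest-probability tokens until their cumulative probability first exceeds $p$, always keeps the top token, and renormalizes.

```python
import random

def top_p_filter(probs, p):
    """Return a renormalized copy of probs with only the nucleus kept.

    Mirrors the tensor logic above: a token is removed when the cumulative
    probability of the tokens ranked before it already exceeds p.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    filtered = [0.0] * len(probs)
    cumsum = 0.0
    for rank, i in enumerate(order):
        # Equivalent to the shift-right trick; rank 0 (the top token) is always kept
        if rank > 0 and cumsum > p:
            break
        filtered[i] = probs[i]
        cumsum += probs[i]
    total = sum(filtered)
    return [x / total for x in filtered]

probs = [0.125, 0.5, 0.25, 0.125]
filtered = top_p_filter(probs, p=0.7)
print([round(x, 4) for x in filtered])  # [0.0, 0.6667, 0.3333, 0.0]

# Edge case: the top token alone exceeds p, yet it is still kept
print(top_p_filter([0.0625, 0.9375], p=0.5))  # [0.0, 1.0]

# Sampling from the filtered distribution (the counterpart of torch.multinomial)
next_id = random.choices(range(len(filtered)), weights=filtered)[0]
```

Without the shift and the forced top-1 token, the edge case above would remove every token and the renormalization would divide by zero.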

First, I evaluated the model with a top-$p$ value of $0.9$.

Below are the dialogues I actually had with the model.

The results when p is set to 0.9.

We can see that the results are quite decent, actually much better than I expected, but the main problem is that the model lacks long-term memory.

That is, it cannot generate responses coherent with previous utterances it has produced, which significantly degrades the overall engagement.

This is inevitable, since the model only takes a limited number of utterances as input.

However, I can say that this is a good conversation model, considering the model size and the amount of training data.

Additionally, I set $p$ to $0.8$ and checked the results, starting with the same topics and introductions.

The results when p is set to 0.8.

Well… The interpretation of the results may vary from person to person, but at least from my perspective, slightly more coherent responses came out when $p$ was $0.8$.

However, the limitations are still evident; for example, awkward responses appear after a few turns.

Overcoming these limitations is still being actively studied in the open-domain dialogue field, and I am also interested in these obstacles.


Ok, let’s wrap up.

We saw that the pre-trained GPT-2 produced quite decent outputs, as expected, and that the top-$p$ sampling I implemented myself worked properly.

However, we also saw that several shortcomings remain.

I will finish this post after discussing the current limitations and possible solutions briefly.

  1. The dialogue history is not used effectively.

    A common way to consider the dialogue history is to concatenate previous utterances into one input sequence. However, this increases memory usage when too many utterances are included. In addition, given how the attention mechanism works, attending to a long context at once might dilute the attention paid to each word and degrade performance. There are other approaches, such as encoding each utterance hierarchically like ReCoSa (Zhang et al., 2019[4]), but the information lost during encoding is known to be a cause of worse results than simple concatenation. To cope with this, one option is to first summarize the history and concatenate the summaries; another is to select only a few of the most relevant utterances by computing similarity-based scores over their vector representations and then concatenate them in natural language form. We can get some insights from several works, such as Lan et al., 2020[5], Zhang et al., 2021[6], and Xu et al., 2021[7], which are also listed in the references.

  2. A method for knowledge- and common-sense-grounded generation is needed.

    Unfortunately, I have not studied this topic deeply. However, knowledge-grounded dialogue generation is one of the most actively studied fields, and I am also trying to keep up. There are several existing methods for injecting knowledge into models, such as Kim et al., 2020[8], Zhao et al., 2020[9], Lewis et al., 2020[10], and Izacard et al., 2020[11]. Alternatively, we can consider large-scale pre-trained models with more parameters to take advantage of the knowledge learned during pre-training.
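The relevance-based history selection mentioned in the first limitation could be prototyped roughly as below. This is a hypothetical sketch, not something implemented in this project: it assumes each utterance already has a vector representation (here, tiny made-up vectors), scores past utterances by cosine similarity to the current one, keeps the top $k$, and concatenates them back in their original order.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0

def select_relevant_history(history, history_vecs, query_vec, k):
    """Keep the k past utterances most similar to the query, in chronological order."""
    scores = [cosine(vec, query_vec) for vec in history_vecs]
    top = sorted(range(len(history)), key=lambda i: scores[i], reverse=True)[:k]
    return [history[i] for i in sorted(top)]

# Hypothetical utterances and embeddings, for illustration only
history = ["I have a dog.", "The weather is nice.", "He loves to play fetch.", "I had pasta for lunch."]
history_vecs = [[0.9, 0.1, 0.0], [0.0, 0.2, 0.9], [0.8, 0.3, 0.1], [0.1, 0.9, 0.2]]
query_vec = [1.0, 0.2, 0.0]  # e.g. an embedding of "What breed is your dog?"

selected = select_relevant_history(history, history_vecs, query_vec, k=2)
print(" ".join(selected))  # I have a dog. He loves to play fetch.
```

Restoring the chronological order before concatenation keeps the selected context readable as a natural dialogue.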

This is it for the open-domain multi-turn dialogue generation project using the GPT-2.

Although I’ll focus on other research and TA work for the time being, I will try various improvements, such as different decoding structures, adding user information, and proper negative sampling, if I get the chance.

I always welcome any feedback about what I need to improve or what I should fix.

Thank you.

[1] How to build a State-of-the-Art Conversational AI with Transfer Learning. (2019, May 9). https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313.
[2] huggingface/transfer-learning-conv-ai. https://github.com/huggingface/transfer-learning-conv-ai.
[4] Zhang, H., Lan, Y., Pang, L., Guo, J., & Cheng, X. (2019). Recosa: Detecting the relevant contexts with self-attention for multi-turn dialogue generation. arXiv preprint arXiv:1907.05339. https://arxiv.org/abs/1907.05339.
[5] Lan, T., Mao, X. L., Wei, W., & Huang, H. (2020). Which Kind Is Better in Open-domain Multi-turn Dialog, Hierarchical or Non-hierarchical Models? An Empirical Study. arXiv preprint arXiv:2008.02964. https://arxiv.org/abs/2008.02964.
[6] Zhang, Y., Ni, A., Yu, T., Zhang, R., Zhu, C., Deb, B., ... & Radev, D. (2021). An Exploratory Study on Long Dialogue Summarization: What Works and What's Next. arXiv preprint arXiv:2109.04609. https://arxiv.org/abs/2109.04609.
[7] Xu, J., Szlam, A., & Weston, J. (2021). Beyond goldfish memory: Long-term open-domain conversation. arXiv preprint arXiv:2107.07567. https://arxiv.org/abs/2107.07567.
[8] Kim, B., Ahn, J., & Kim, G. (2020). Sequential latent knowledge selection for knowledge-grounded dialogue. arXiv preprint arXiv:2002.07510. https://arxiv.org/abs/2002.07510.
[9] Zhao, X., Wu, W., Xu, C., Tao, C., Zhao, D., & Yan, R. (2020). Knowledge-grounded dialogue generation with pre-trained language models. arXiv preprint arXiv:2010.08824. https://arxiv.org/abs/2010.08824.
[10] Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goyal, N., ... & Kiela, D. (2020). Retrieval-augmented generation for knowledge-intensive nlp tasks. arXiv preprint arXiv:2005.11401. https://arxiv.org/abs/2005.11401.
[11] Izacard, G., & Grave, E. (2020). Leveraging passage retrieval with generative models for open domain question answering. arXiv preprint arXiv:2007.01282. https://arxiv.org/abs/2007.01282.