Jaewoo Song


Following the introduction I posted last time, today let’s talk about the transformer model using the ReCoSa (Relevant Contexts with Self-attention) structure, which is the first model I tried for the multi-turn chatbot.

First, I’m gonna briefly talk about the data processing method, then we are going to look at the model code I implemented myself.

After that, I will show the experimental results and wrap up this post with the analysis and possible improvements.




Data processing

As I mentioned before, I used $4$ multi-turn dialogue datasets for this project.

But the format and purpose of each dataset are quite different, so I had to implement additional extraction code to get pure utterances.

The details would be too long, so I will omit how I extracted only the dialogues from the raw data. (You can easily figure it out by checking Huggingface’s dataset viewer page.)

In this post, I will show the final format of processed data files.


First, I split each dataset into a train and validation set with a ratio of $0.85:0.15$ based on the number of conversations.
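The split itself is nothing special; a minimal sketch of a conversation-level split could look like the following (the dialogues variable, holding one list of utterances per conversation, is a hypothetical name).

import random

def split_dialogues(dialogues, train_ratio=0.85, seed=0):
    # Shuffle and split at the conversation level so that no dialogue
    # ends up in both the train set and the validation set.
    random.seed(seed)
    shuffled = dialogues[:]
    random.shuffle(shuffled)
    boundary = int(len(shuffled) * train_ratio)
    return shuffled[:boundary], shuffled[boundary:]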

Then I combined the extracted dialogues from the $4$ different datasets into data files with the format below. (Eventually we have $2$ files, one for training and one for validation.)

The format of processed data file.


As you can see, I tokenized all sentences, converted them into token ids, and saved them before the actual training.

The GPT2 Tokenizer was used as I mentioned in the previous post.

When making a CustomDataset object later, I loaded the files and made pairs by matching consecutive utterances.
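For example, pairing consecutive utterances from one loaded conversation can be as simple as this (utterances here is a hypothetical list of tokenized lines from a single dialogue).

# Each consecutive (previous utterance, response) pair becomes one training sample.
pairs = [(utterances[i], utterances[i+1]) for i in range(len(utterances)-1)]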


A tricky part was that each dataset follows slightly different grammatical conventions.

In particular, DailyDialog and PersonaChat consist of lower-case letters only, and every punctuation mark has whitespace on both sides, which is far from standard written English.

I suspect a model trained only on lower-cased data might have been used to build these datasets, but I needed to adjust this since I wanted all utterances to follow standard English conventions.

Therefore, I added an extra step that handles these deviations by tokenizing each utterance and modifying each token representation.

The function for this is as follows.

space = 'Ġ'  # The GPT-2 tokenizer marks a preceding whitespace with this character.
pre_quote = '’'
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

def process_token_list(token_list):
    # Capitalize the first token of the utterance.
    token_list[0] = token_list[0].capitalize()
    
    quote_count = 0
    for i, token in enumerate(token_list):
        if space in token:
            # Remove the Ġ prefix from punctuation marks and abbreviation fragments.
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]
                
            # Attach a single quote to the following abbreviation (e.g. "it ' s" -> "it's").
            if token[1:] == quotes[1]:
                if i < len(token_list)-1:
                    if token_list[i+1] in abbreviations \
                            or (token_list[i+1][0] == space
                                and token_list[i+1][1:] in abbreviations):
                        token_list[i] = token[1:]
                        
        # Handle quotation marks: a closing quote loses its leading space,
        # an opening quote removes the space of the token right after it.
        if token[0] == space and token[1:] in quotes:
            if quote_count % 2 == 1:
                token_list[i] = token[1:]
                quote_count = 0
            else:
                if i < len(token_list)-1 \
                        and token_list[i+1][0] == space:
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1
                
        # Capitalize the token that comes right after a sentence-ending mark.
        if token in end_marks or token[1:] in end_marks:
            if i < len(token_list)-1:
                if token_list[i+1][0] != space:
                    token_list[i+1] = space + token_list[i+1].capitalize()
                else:
                    token_list[i+1] = space + token_list[i+1][1:].capitalize()
                
    # Drop empty tokens and lone spaces, and make sure the utterance ends with a period.
    new_token_list = [token for token in token_list if token != space and len(token) > 0]
    if new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])
        
    return new_token_list


Even if you can’t follow exactly what is happening above, that’s no problem at all.

It is quite hard-coded and inefficient even from my own point of view…

The point is that since the GPT-2 tokenizer uses 'Ġ' to mark whitespace, these corrections can be made without too much trouble by manipulating that space character directly.
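For example, here is roughly how the Ġ tokens look and what the function above does to them (a small illustration using Huggingface’s GPT2Tokenizer; the exact token split may vary slightly).

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

tokens = tokenizer.tokenize("hello . how are you ?")
# ['hello', 'Ġ.', 'Ġhow', 'Ġare', 'Ġyou', 'Ġ?']

fixed = process_token_list(tokens)
# ['Hello', '.', 'ĠHow', 'Ġare', 'Ġyou', '?']

print(tokenizer.convert_tokens_to_string(fixed))  # "Hello. How are you?"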


Anyway, after this whole process, I could eventually get data files consisting only of token ids.



Model

Let’s see the ReCoSa architecture again.

The description of the ReCoSa structure.


As I stated last time, the difference between the original transformer and ReCoSa is the encoder, which performs the word-level encoding and the history-level (here, I’m gonna call it “time-level”) encoding separately.

The basic modules such as the multi-head attention, the encoder layer and the decoder layer, etc. are almost the same as the code I used for my Neural Machine Translation project with the transformer.

If you are not familiar with the basic transformer code, it might be helpful to check my previous post (Neural Machine Translation with Transformer in Pytorch) and the GitHub repository (transformer-translator-pytorch).


Then let’s see the encoder more specifically.

I used a GRU (Gated Recurrent Unit) for the word-level encoding, taking the last hidden state from the final layer as the context vector of each utterance.

Then I made the encoder class, which conducts multi-head attention over the context vectors concatenated with the positional encoding vectors.

Unlike in the original transformer, the positional encoding vector is concatenated, not added.

I think this might be to prevent the utterance embedding from being contaminated by the positional embedding, since an utterance embedding carries much more information than a single word embedding does.

# Word-level GRU component
word_level_rnn = nn.GRU(
    input_size=d_model,
    hidden_size=hidden_size,
    num_layers=gru_num_layers,
    dropout=(0.0 if gru_num_layers == 1 else gru_dropout),
    batch_first=True,
)
        
# Time-level encoder: the input dimension is hidden_size + d_model
# because the positional encoding is concatenated, not added.
encoder = Encoder(
    hidden_size + d_model,
    d_ff,
    num_heads,
    dropout,
    encoder_num_layers
)


The process of building the context representation from the actual input can be implemented as below.

def src_embed(src_input):
    src_emb = embedding(src_input)  # (B, T, L, d_model)
    if use_gpt:
        src_emb = embedding_linear(src_emb)  # (B, T, L, d_model)
    max_len, d_model = src_emb.shape[2], src_emb.shape[3]
    last_hiddens = word_level_rnn(
        src_emb.view(-1, max_len, d_model)
    )[1][-1]  # (B*T, d_model)

    batch_size = src_emb.shape[0]
    src_emb = last_hiddens.view(batch_size, -1, d_model)  # (B, T, d_model)
    src_emb = time_pembedding(src_emb, cal='concat')  # (B, T, 2*d_model)

    return src_emb  # (B, T, 2*d_model)

src_emb = src_embed(src_input)  # (B, T, 2*d_model)
e_output = encoder(src_emb, e_mask)  # (B, T, 2*d_model)


I think it is not that difficult to understand this.

The reason why the shape of src_input is $(B, T, L)$ ($B$: batch size, $T$: max time step, $L$: max sequence length) will be explained in the next section.
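By the way, time_pembedding itself is not shown in the snippet above. As a rough sketch of what a concatenating positional encoder might look like (my own guess assuming a standard sinusoidal table and an even d_model, not the exact code from the repository):

import math
import torch
from torch import nn

class TimePositionalEncoder(nn.Module):
    def __init__(self, max_time, d_model):
        super().__init__()
        # Standard sinusoidal table of shape (1, max_time, d_model).
        pe = torch.zeros(max_time, d_model)
        pos = torch.arange(0, max_time, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x, cal='concat'):
        pe = self.pe[:, :x.size(1)].expand(x.size(0), -1, -1)
        if cal == 'concat':
            return torch.cat([x, pe], dim=-1)  # (B, T, d_model) -> (B, T, 2*d_model)
        return x + pe  # the usual additive variant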


Anyway, after we get the context representation, the decoder operates like the basic transformer.

The decoder conducts masked multi-head attention over the sequence generated so far (during training, this is the ground-truth response for teacher forcing) and produces the response representation.

This goes into another multi-head attention with the encoder output.

This time, the response representation becomes the query and the encoder output becomes the key and the value.

The notable point is that the encoder output holds a representation of each time step (history), which differs from the original transformer, where both the encoder and the decoder deal with token-level representations.

That is, we get attention scores between the vectors at token positions in the response representation and the vectors at history positions in the context representation.

This can be described as follows.

The process of self-attention mechanism between the context representation and the response representation.
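To make the shapes concrete, here is a tiny sketch of that scaled dot-product step (variable names are mine, and in the real model the vectors are 2*d_model wide); the point is simply that the attention map is token-by-history.

import torch
import torch.nn.functional as F

B, T, L, d = 2, 10, 25, 512      # batch size, max time steps, max response length, model dim
q = torch.randn(B, L, d)         # response representation from the decoder (one vector per token)
k = v = torch.randn(B, T, d)     # encoder output (one vector per history utterance)

scores = torch.matmul(q, k.transpose(-2, -1)) / d ** 0.5  # (B, L, T): each token attends over histories
attn = F.softmax(scores, dim=-1)
context = torch.matmul(attn, v)  # (B, L, d)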


After this context-response attention and the concatenation of the results from each attention head, the output is passed to the rest of the layers, which are the same as in the vanilla transformer.

I will omit the later parts since they work just like in the original generation tasks.



Training

As you know, in traditional NLG tasks we just have to make the model generate the target sentence given the source sentence.

But in the multi-turn setting, this approach cannot take the previous histories into account.

At first, I thought of simply using one dialogue itself as a unit and making several conversations into one batch, but to match the same number of time steps between dialogues, I had to include a lot of dummy utterances.

This eventually led to poor optimization and unnecessarily large data size, causing severe problems with the training time and memory.

So I looked into several other implementations and concluded that I should process the model’s input as follows.

The format for multi-turn dialogue generation training using the ReCoSa structure.


This is quite an intuitive approach.

Assuming that the maximum number of time steps is $T$ and the current input is the $t$th utterance in a dialogue, the preceding utterances $1$ ~ $t-1$ become the dialogue context.

I made one input by stacking these $t$ utterances and adding dummies so that the number of utterances becomes $T$.

And to prevent the model from attending to these dummies, I implemented a mask that covers the time steps after $t$.

The target is shaped just like in the single-turn case, since it is simply the response that follows the $t$th utterance.

The same goes for the decoder mask.

The encoder mask is a little different since it should be history-level, not token-level, but it can be implemented easily: we just fill in boolean values distinguishing the positions up to $t$ from those after it.
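For example, if the maximum number of time steps is $5$ and the current utterance sits at the third slot (index $2$), the encoder mask is simply:

t, max_time = 2, 5  # zero-indexed current time step, maximum number of time steps
e_mask = [1] * (t + 1) + [0] * (max_time - t - 1)
# [1, 1, 1, 0, 0] -> the model may attend to the first three histories only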


Other additional steps, such as adding the bos & eos tokens and padding, are the same as in the previous transformer cases.

And separating the target input and the target output for teacher forcing is also the same.

The data processing code in the CustomDataset class is as follows.

def process_src(src_sent, max_len, bos_id, eos_id, pad_id):
    # Append the eos token and pad (or truncate) the source utterance to max_len.
    if len(src_sent) < max_len:
        src_input = src_sent + [eos_id]
        src_input += [pad_id] * (max_len - len(src_input))
    else:
        src_input = src_sent[:max_len]
        src_input[-1] = eos_id
            
    return src_input
    
    
def process_trg(trg_sent, max_len, bos_id, eos_id, pad_id):
    # The target output ends with the eos token...
    if len(trg_sent) < max_len:
        trg_output = trg_sent + [eos_id]
        trg_output += [pad_id] * (max_len - len(trg_output))
    else:
        trg_output = trg_sent[:max_len]
        trg_output[-1] = eos_id
            
    # ...while the target input starts with the bos token (teacher forcing).
    if len(trg_sent) < max_len:
        trg_input = [bos_id] + trg_sent
        trg_input += [pad_id] * (max_len - len(trg_input))
    else:
        trg_input = [bos_id] + trg_sent
        trg_input = trg_input[:max_len]
            
    return trg_input, trg_output


def make_encoder_mask(t, max_time):
    # 1 for the time steps that actually contain an utterance, 0 for the dummies.
    e_mask = [1 for i in range(t+1)] + [0 for i in range(max_time-t-1)]  # (T)

    return e_mask
    
def make_decoder_mask(trg_inputs, pad_id, max_len):
    d_masks = (trg_inputs != pad_id).unsqueeze(1)  # (N, 1, L)

    # Combine the padding mask with a lower-triangular "no peeking" mask.
    nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool)  # (1, L, L)
    nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L), lower-triangular
    d_masks = d_masks & nopeak_mask  # (N, L, L), padding positions are False

    return d_masks


src_inputs = []  # (N, T, L)
trg_inputs = []  # (N, L)
trg_outputs = []  # (N, L)
e_masks = []  # (N, 1, T)
        
init = [pad_id] * max_len
history = [init for t in range(max_time)]  # (T, L)
num_time = 0

# lines: The list of total lines in the file.
for i, line in enumerate(tqdm(lines)):
    if line.strip() == dialogue_split_line:
        # A new dialogue starts: reset the history buffer.
        history = [init for t in range(max_time)]
        num_time = 0
    elif i+1 < len(lines) and lines[i+1].strip() != dialogue_split_line:
        src_sent = [int(token) for token in line.strip().split(' ')]
        trg_sent = [int(token) for token in lines[i+1].strip().split(' ')]

        src_input = process_src(src_sent, max_len, bos_id, eos_id, pad_id)
        trg_input, trg_output = process_trg(trg_sent, max_len, bos_id, eos_id, pad_id)

        if num_time < max_time:
            # There is still an empty slot: fill the next time step.
            history[num_time] = src_input
            e_mask = make_encoder_mask(num_time, max_time)
        else:
            # The history is full: slide the window by dropping the oldest utterance.
            history = history[1:] + [src_input]
            e_mask = make_encoder_mask(max_time-1, max_time)

        num_time += 1

        src_inputs.append(list(history))  # copy so later updates don't change stored samples
        trg_inputs.append(trg_input)
        trg_outputs.append(trg_output)
        e_masks.append(e_mask)

src_inputs = torch.LongTensor(src_inputs)  # (N, T, L)
trg_inputs = torch.LongTensor(trg_inputs)  # (N, L)
trg_outputs = torch.LongTensor(trg_outputs)  # (N, L)
e_masks = torch.BoolTensor(e_masks).unsqueeze(1)  # (N, 1, T)
d_masks = make_decoder_mask(trg_inputs, pad_id, max_len)  # (N, L, L)


In conclusion, after batching, the shape of each tensor is as follows.

  • src_inputs: $(B, T, L)$
  • trg_inputs: $(B, L)$
  • trg_outputs: $(B, L)$
  • e_masks: $(B, 1, T)$
  • d_masks: $(B, L, L)$
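As a quick sanity check, wrapping these tensors in a DataLoader and inspecting one batch could look like this (a minimal sketch; train_dataset is assumed to be a CustomDataset returning one tuple of the five tensors per index).

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

src_inputs, trg_inputs, trg_outputs, e_masks, d_masks = next(iter(loader))
print(src_inputs.shape)  # torch.Size([16, T, L])
print(e_masks.shape)     # torch.Size([16, 1, T])
print(d_masks.shape)     # torch.Size([16, L, L])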


I think there is no need to go over the details of the training code since it is no different from basic generation tasks.

I used the cross-entropy loss to train the model for next-token prediction on the train set and evaluated it on the validation set.

And I saved the best model based on the mean value of validation losses.

You can check each hyperparameter’s setting from README.md in the repository of this project linked above.
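For reference, one training step would look roughly like this (a sketch, not the exact code from the repository: I assume the model returns logits of shape (B, L, vocab_size) and that padded positions are ignored via ignore_index).

import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

model.train()
for src_inputs, trg_inputs, trg_outputs, e_masks, d_masks in loader:
    logits = model(src_inputs, trg_inputs, e_masks, d_masks)  # (B, L, vocab_size)
    loss = criterion(logits.view(-1, logits.size(-1)), trg_outputs.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()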


Now let’s look at the logs and the plot of the train/validation loss.

The screenshot of training logs.


The changes in training and validation loss per epoch.


Originally the total number of epochs was set to $20$, but I increased it to $40$ to train the model further.

Nevertheless, as you can see, neither the training loss nor the validation loss fully converged.

This is a very challenging task and, in my opinion, the amount of data is not large, so the model may still be under-fitted.

First, let’s see what the results look like in the next section and then discuss this in more detail.



Results

Now, let’s move on to the actual inference step.

We need a proper decoding algorithm for this, and I used Nucleus Sampling (Top-$p$ Sampling).

Since this project deals with an open-ended generation task, where the output is far less constrained by the input than in directed generation, I wanted to take the diversity of the generated results into account.


Nucleus Sampling chooses the next word by taking the top words that together account for a certain probability mass $p$ of the total distribution and sampling randomly from that subset.

In this way, we can account for both language-modeling quality and word-selection diversity at the same time, since the sampling space is determined dynamically.

A more detailed explanation can be found in my previous post, which might help you a lot.


The implementation of this algorithm is as follows.

trg_input = [bos_id]
trg_input += [pad_id] * (max_len-len(trg_input))
trg_input = torch.LongTensor(trg_input).unsqueeze(0)  # (1, L)

output_ids = []

for pos in range(max_len):
    trg_emb = model.trg_embed(trg_input)  # (1, L, 2*d_model)
    d_mask = make_decoder_mask(trg_input, pad_id, max_len)

    # e_output: the output from the encoder (1, T, 2*d_model)
    # e_mask: the encoder mask (1, 1, T)
    d_output = model.decoder(trg_emb, e_output, e_mask, d_mask)  # (1, L, 2*d_model)

    # Project the decoder output to the vocabulary size.
    output = F.softmax(model.output_linear(d_output), dim=-1)  # (1, L, vocab_size)
    output = output[:,pos]  # (1, vocab_size)

    # Keep only the nucleus: the top words whose cumulative probability stays within p.
    sorted_probs, sorted_idxs = torch.sort(output, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, vocab_size)
    idx_remove = cumsum_probs > p
    sorted_probs[idx_remove] = 1e-8  # substantially small value
    sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, vocab_size)

    # Random sampling
    # (Re-seeding with the current time at every step seems to be the cause of the
    #  repetition problem described in the update at the end of this post.)
    seed = int(time.time())
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    probs = torch.zeros(output.shape).scatter_(-1, sorted_idxs, sorted_probs)  # (1, vocab_size)
    idxs = torch.multinomial(probs, 1).squeeze(-1)

    if pos < max_len-1:
        trg_input[:, pos+1] = idxs

    output_ids.append(idxs.squeeze(0).item())
    if idxs.squeeze(0).item() == eos_id:
        break

if output_ids[-1] == eos_id:
    output_ids = output_ids[:-1]


It is not that complicated, so let me explain it briefly.

First, I selected the word distribution at the position I want and sorted it by probability in descending order.

Then, using the torch.cumsum method, I calculated the cumulative sums and excluded the indexes where this sum exceeds $p$ by assigning a substantially small number to them.

Next, I obtained a new sampling space by normalizing the distribution, dividing by the total remaining probability mass, as described in the original paper.

Then I used the torch.multinomial function to sample a word based on these probabilities and repeated this procedure until the end token appeared.


It is time to see the results.

I talked with the bot several times, and let me show you $5$ of those conversations.

The results of the conversations with the trained chatbot.


I already felt uneasy during the training process, and as expected, the results were not very good.

Multi-turn context understanding aside, even the single-turn responses are too unstable: they are either irrelevant or suffer from severe repetition, just like the outputs of high-probability decoding methods.

I tried several more times but stopped eventually because I thought I wouldn’t get any more meaningful outcomes.



Discussions

Let’s talk about the most important things here.

Why did we get results like these?

Before we start, I’m gonna assume that there is no error in my understanding of the paper or in my PyTorch implementation. (Because if one of these is the major cause, then this is simply my own stupid mistake and there is no point in discussing it at all…)

I summarized three possible reasons.


  1. Every setting is fine. But simply the model is under-trained.

    As we can see from the loss chart, the training loss and the validation loss have not converged enough, and the model is in fact still being trained. So it might become more competent if I invest more time. I will update the results after more training.

  2. The training data is not enough.

    This should be taken seriously in most deep learning projects, but I don’t think it applies this time, because the authors report quite decent results in the paper with a similar amount of data.

  3. There can be a problem with the hyperparameter setting.

    At first, I was also suspicious of this, but I don’t see a noticeable difference from other repositories that have produced fine results. Even when I changed several arguments such as max_len and max_time, the improvement was not that remarkable. My model is somewhat more complex than the other models I referred to, but again I don’t think that is the cause, since not even overfitting is happening. In conclusion, this is probably not the main issue.


In the end, possibility #1 is the most suspicious, so I plan to check the results after additional training and finish this post here.

And once I experiment with the GPT-2 implementation, we can check how much faster and more thoroughly its optimization proceeds.

That should tell us whether under-fitting to the response generation task itself is the main reason or not.

Additionally, I asked the author of the repository I referred to, and he said that he also got unsatisfying results when he first applied the code from the original authors of the paper, but that after replacing the transformer decoder with a GRU decoder, he was able to improve the model’s performance significantly.

Honestly, I don’t understand why that change made such a difference, so I will think about it more thoroughly later.



So this is the end of the post about my first try at building a multi-turn dialogue generation model.

Although the results are not satisfactory, I spent a meaningful time thinking about various things, and I don’t consider this approach a dead end.

As mentioned earlier, implementing the other model may give another explanation, and additional training or structural changes may bring unexpected improvements, so I will keep trying and update the outcomes if they are noticeable.

For now, I will wrap up this post and talk about the chatbot model with GPT-2 next time.



(Updated at 2020-11-12)

I am adding this section to report the additional results.

Before talking about the additional training, I have to tell you honestly that there was an error in my Nucleus Sampling code.

Because the random seed was being updated incorrectly, the same word tended to come out repeatedly when sampling; after fixing it, the number of repetitions decreased greatly.
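A minimal sketch of the kind of fix, assuming the culprit was the per-step re-seeding inside the sampling loop shown earlier: seed once before decoding (or not at all) instead of re-seeding at every step.

import time
import torch

# Seed once before the decoding loop (or skip manual seeding entirely),
# instead of re-seeding with the current time at every sampling step.
seed = int(time.time())
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

for pos in range(max_len):
    ...  # build probs and call torch.multinomial(probs, 1) exactly as before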


Anyway, I trained the model a little more and saw that the losses converged at around epoch $53$, with a train loss of $2.7714$ and a validation loss of $3.7167$.

And I could get these conversations as results.

The results of the conversations with the trained chatbot after additional training.


Although these were hand-picked as relatively decent dialogues after several tries, we can see that more complete sentences were generated.

Of course, there are still some expressions that are out of context, and only short interactions are possible.

But obviously improvements have been made.

I think I’m done here and it’s time to move on to the next approach, the multi-turn chatbot using GPT-2.



Li, Y., Su, H., Shen, X., Li, W., Cao, Z., & Niu, S. (2017). DailyDialog: A manually labelled multi-turn dialogue dataset. arXiv preprint arXiv:1710.03957. https://arxiv.org/abs/1710.03957.
Rashkin, H., Smith, E. M., Li, M., & Boureau, Y. L. (2018). Towards empathetic open-domain conversation models: A new benchmark and dataset. arXiv preprint arXiv:1811.00207. https://arxiv.org/abs/1811.00207.
Zhang, S., Dinan, E., Urbanek, J., Szlam, A., Kiela, D., & Weston, J. (2018). Personalizing dialogue agents: I have a dog, do you have pets too?. arXiv preprint arXiv:1801.07243. https://arxiv.org/abs/1801.07243.
Smith, E. M., Williamson, M., Shuster, K., Weston, J., & Boureau, Y. L. (2020). Can You Put it All Together: Evaluating Conversational Agents' Ability to Blend Skills. arXiv preprint arXiv:2004.08449. https://arxiv.org/abs/2004.08449.
Zhang, H., Lan, Y., Pang, L., Guo, J., & Cheng, X. (2019). ReCoSa: Detecting the relevant contexts with self-attention for multi-turn dialogue generation. arXiv preprint arXiv:1907.05339. https://arxiv.org/abs/1907.05339.
devJWSong/transformer-translator-pytorch. https://github.com/devJWSong/transformer-translator-pytorch.
gmftbyGMFTBY/MultiTurnDialogZoo. https://github.com/gmftbyGMFTBY/MultiTurnDialogZoo.