About "A Deep Reinforced Model for Abstractive Summarization"

I’ve been working on the paper Paulus et al. (2017), A Deep Reinforced Model for Abstractive Summarization, for quite a while now; it’s time to present it. The code currently lives on the drmas branch of my fork https://github.com/pltrdy/OpenNMT-py/tree/drmas (there’s also a PR: https://github.com/OpenNMT/OpenNMT-py/pull/319).


TL;DR: the model uses two kinds of attention: i) temporal attention over encoder states; ii) attention over previous decoder states; plus a pointer generator similar to See et al. (2017). The decoder is run twice to get two predicted sequences and two losses. The RL objective is to maximise the ROUGE score of the sampled output. The model finally optimizes a mixed objective that combines both losses.

Disclaimer: this code has been developed in a very experimental way. For convenience, I tried to put everything I needed in a single file in order not to interfere too much with the rest of the project. This means that some refactoring needs to be discussed and done, in order to split the file into several ones and factor out some functions.

Notations


  • transpose with T: ((X)T)T = X
  • encoder RNN states: he_i for each input i
  • decoder RNN states: hd_t for each timestep t

Model


  • base attention: the model uses a bilinear attention: attention(hd_t, he_i) = (hd_t)T . W_attn . he_i (a rough sketch is given after this list).

  • Intra-encoder temporal attention: attention between the decoder state hd_t and each of the encoder states he_i. The scores are additionally divided by the accumulated scores of the previous decoding steps before normalization (temporal attention), which penalizes input positions that have already received a lot of attention.

  • Intra-decoder attention: attention between the decoder state hd_t and each of the previous decoder states hd_t' (t' < t).

  • Pointer generator: similar to the CopyGenerator; this model shares weights with the decoder embedding as described in sect. 2.4 (also sketched after the list).

  • Exposure bias reduction: at each timestep, we feed the ground-truth token (i.e. the target) into the decoder. To reduce exposure bias, we sometimes feed the predicted token instead (with a probability of 0.25 by default). The prediction is the token with the highest probability (greedy search).

  • Sampled output: for the reinforcement-learning part, we run the decoder a second time and use its own prediction as the decoder input at each timestep. The prediction here is sampled from the output distribution. The loss at each timestep is based on the probability of the sampled token.

  • Mixed objective: we compute the ROUGE score of both predictions (greedy and sampled) and weight the likelihood of the sampled one by the score difference. In short, if the sampled prediction scores higher than the greedy one, we want to maximise its probability. The final loss is loss = gamma * loss_rl + (1 - gamma) * loss_ml; gamma == 0 means ML only, gamma == 1 means RL only (see the objective sketch after the list).

  • Trigram repetition avoidance: at test time, we penalize (set to 0) output probabilities that would produce a trigram that has already been decoded (see the last sketch below).
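
To make the attention part more concrete, here is a rough PyTorch sketch of the bilinear scoring with the temporal normalization (intra-decoder attention is the same machinery without the temporal trick). Names like BilinearTemporalAttention are mine; the actual code lives in IntraAttention.

    import torch
    import torch.nn as nn

    class BilinearTemporalAttention(nn.Module):
        """Sketch: score(hd_t, he_i) = (hd_t)T . W_attn . he_i, with scores divided
        by the accumulated (exponentiated) scores of the previous decoding steps."""
        def __init__(self, dim):
            super(BilinearTemporalAttention, self).__init__()
            self.w = nn.Linear(dim, dim, bias=False)  # W_attn

        def forward(self, hd_t, enc_states, sum_exp_prev=None):
            # hd_t: (batch, dim), enc_states: (batch, src_len, dim)
            scores = torch.bmm(enc_states, self.w(hd_t).unsqueeze(2)).squeeze(2)
            exp_scores = scores.exp()
            if sum_exp_prev is None:        # first decoding step
                temporal, sum_exp = exp_scores, exp_scores
            else:                           # penalize already-attended positions
                temporal = exp_scores / sum_exp_prev
                sum_exp = sum_exp_prev + exp_scores
            alpha = temporal / temporal.sum(1, keepdim=True)                 # attention weights
            context = torch.bmm(alpha.unsqueeze(1), enc_states).squeeze(1)  # (batch, dim)
            return context, alpha, sum_exp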
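
The pointer generator follows the usual copy-mechanism recipe; this is the generic See et al. / Paulus et al. formulation, not a verbatim copy of the PointerGenerator class:

    import torch

    def pointer_distribution(gen_probs, attn, src_map, p_copy):
        """gen_probs: (batch, vocab) softmax over the target vocabulary
        attn:      (batch, src_len) attention weights over the source
        src_map:   (batch, src_len, vocab) one-hot map from source positions to vocab ids
        p_copy:    (batch, 1) copy probability, e.g. sigmoid(W . [hd_t; context] + b)"""
        # scatter the attention mass onto the vocabulary, then mix both distributions
        copy_probs = torch.bmm(attn.unsqueeze(1), src_map).squeeze(1)
        return (1.0 - p_copy) * gen_probs + p_copy * copy_probs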
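
And here is a sketch of the training objective: the decoder-input trick from the exposure-bias bullet, and the mixed ML/RL loss computed from the two decoder passes. The function names and the way the log-probabilities are aggregated are illustrative, not the branch's API.

    import torch

    def next_decoder_input(gold_tok, pred_tok, p=0.25):
        # with probability p, feed the greedy prediction instead of the ground truth
        use_pred = torch.rand(gold_tok.shape, device=gold_tok.device) < p
        return torch.where(use_pred, pred_tok, gold_tok)

    def mixed_objective(logp_ml, logp_sampled, reward_sampled, reward_greedy, gamma):
        """logp_ml:       (batch,) sum of log-probs of the ground-truth tokens (first pass)
        logp_sampled:  (batch,) sum of log-probs of the sampled sequence (second pass)
        reward_*:      (batch,) ROUGE of the sampled / greedy output against the reference"""
        loss_ml = -logp_ml.mean()
        # self-critical policy gradient: push the sampled sequence up when it beats the greedy baseline
        advantage = (reward_sampled - reward_greedy).detach()
        loss_rl = -(advantage * logp_sampled).mean()
        return gamma * loss_rl + (1.0 - gamma) * loss_ml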
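
Finally, a sketch of the trigram-repetition trick used at inference time: before picking the next token, any candidate that would complete a trigram already present in the hypothesis has its probability set to 0 (the function name is mine):

    def block_repeated_trigrams(probs, decoded_ids):
        """probs: (vocab,) next-token probabilities; decoded_ids: tokens decoded so far."""
        if len(decoded_ids) < 2:
            return probs
        seen = set(tuple(decoded_ids[i:i + 3]) for i in range(len(decoded_ids) - 2))
        prefix = tuple(decoded_ids[-2:])
        for tok in range(probs.size(0)):
            if prefix + (tok,) in seen:
                probs[tok] = 0.0  # this token would repeat an already-decoded trigram
        return probs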

Classes


  • RTrainer: a modified version of the trainer. There aren’t that many changes, so the two could be merged.
  • ReinforcedModel: takes an encoder, a decoder (which must be a ReinforcedDecoder) and gamma, and runs the model.
  • PartialEmbedding: useful to share part of the source embedding. It allows information to be shared while using a smaller target embedding.
  • ReinforcedDecoder
  • IntraAttention: the same class is used for both attentions, the difference being the temporal parameter.
  • PointerGenerator: similar to CopyAttention. It involves the target embedding weights.
  • EachStepGeneratorLossCompute: similar to CopyGeneratorLossCompute, in the sense that it is called at each decoding step. Initially developed to implement a better way of collapsing scores (which is now implemented, I guess).
  • RougeScorer: utility class to calculate ROUGE scores. It requires rouge (a Python implementation), which is faster than the Perl script (a minimal usage example follows this list).
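
For reference, a minimal usage example of the rouge package wrapped by RougeScorer; the reward I use is the product of the three F1 scores (the exact calls in the class may differ slightly):

    from rouge import Rouge  # pip install rouge

    scorer = Rouge()
    scores = scorer.get_scores("the cat sat on the mat", "a cat was sitting on the mat")[0]
    # scores is a dict like {'rouge-1': {'f': ..., 'p': ..., 'r': ...}, 'rouge-2': ..., 'rouge-l': ...}
    reward = scores["rouge-1"]["f"] * scores["rouge-2"]["f"] * scores["rouge-l"]["f"]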

Options


Preprocessing

  • -trunc_tgt_vocab TGT_VOCAB_SIZE: used to truncate the merged vocabulary (only effective with -share_vocab).

Training

  • -reinforced (bool)
  • -partial_embedding (bool)
  • -reinforced_gamma GAMMA (value in [0, 1])

Inference

  • -avoid_trigram_repetition (bool)

Usage


I set some variables:

  • data: path to the data folder
  • root: path to the experiment folder (to save checkpoint, processed data and inference output)
  • gpu: gpu ID
  • bs: batch size

Preprocessing

  python preprocess.py \
      -train_src $data/train.src.txt \
      -train_tgt $data/train.tgt.txt \
      -valid_src $data/valid.src.txt \
      -valid_tgt $data/valid.tgt.txt \
      -src_seq_length 800 \
      -tgt_seq_length 100 \
      -dynamic_dict \
      -share_vocab \
      -src_vocab_size 150000 \
      -trunc_tgt_vocab 50000 \
      -save_data $root/data

Training

    python train.py -save_model $root/model \
                    -word_vec_size 256 \
                    -rnn_size 256 \
                    -layers 1 \
                    -rnn_type "LSTM" \
                    -encoder_type brnn \
                    -epochs 16 \
                    -seed 777 \
                    -batch_size $bs \
                    -max_grad_norm 2 \
                    -gpuid $gpu \
                    -optim "adam" \
                    -learning_rate 0.001 \
                    -reinforced \
                    -partial_embedding \
                    -reinforced_gamma 0.9984 \
                    -data $root/data 

Inference

  best_model=$(ls -lsrt $root/model*.pt | tail -n 1 | awk '{print $NF}')
  echo "Loading: $best_model"
  python translate.py -model "$best_model" \
                      -gpu "$gpu" \
                      -batch_size 1 \
                      -verbose \
                      -beam_size 5 \
                      -output $root/pred.txt \
                      -avoid_trigram_repetition \
                      -src $data/test.src.txt

(I’m using batch_size 1 for historical reasons; higher values may work now.)

Scoring


The scores below are computed with my wrapper files2rouge: files2rouge $data/test.tgt.txt $root/pred.txt

With @entity tags (anonymised):

---------------------------------------------
1 ROUGE-1 Average_R: 0.38155 (95%-conf.int. 0.37910 - 0.38405)
1 ROUGE-1 Average_P: 0.39103 (95%-conf.int. 0.38833 - 0.39361)
1 ROUGE-1 Average_F: 0.37370 (95%-conf.int. 0.37173 - 0.37578)
---------------------------------------------
1 ROUGE-2 Average_R: 0.15279 (95%-conf.int. 0.15069 - 0.15509)
1 ROUGE-2 Average_P: 0.15880 (95%-conf.int. 0.15633 - 0.16114)
1 ROUGE-2 Average_F: 0.15050 (95%-conf.int. 0.14840 - 0.15274)
---------------------------------------------
1 ROUGE-L Average_R: 0.35107 (95%-conf.int. 0.34863 - 0.35337)
1 ROUGE-L Average_P: 0.36024 (95%-conf.int. 0.35769 - 0.36281)
1 ROUGE-L Average_F: 0.34406 (95%-conf.int. 0.34208 - 0.34610)

With entity tags replaced:

---------------------------------------------
1 ROUGE-1 Average_R: 0.38964 (95%-conf.int. 0.38706 - 0.39215)
1 ROUGE-1 Average_P: 0.40079 (95%-conf.int. 0.39813 - 0.40336)
1 ROUGE-1 Average_F: 0.38277 (95%-conf.int. 0.38073 - 0.38486)
---------------------------------------------
1 ROUGE-2 Average_R: 0.16819 (95%-conf.int. 0.16608 - 0.17052)
1 ROUGE-2 Average_P: 0.17450 (95%-conf.int. 0.17197 - 0.17681)
1 ROUGE-2 Average_F: 0.16572 (95%-conf.int. 0.16356 - 0.16782)
---------------------------------------------
1 ROUGE-L Average_R: 0.35898 (95%-conf.int. 0.35664 - 0.36132)
1 ROUGE-L Average_P: 0.36974 (95%-conf.int. 0.36721 - 0.37224)
1 ROUGE-L Average_F: 0.35290 (95%-conf.int. 0.35088 - 0.35501)

Results


I ran it on the anonymised CNN/DM dataset (with @entity tags, as described in Nallapati et al., 2016); I may release the code I used for this dataset. I scored both outputs (anonymised, with @entity tags, and un-anonymised, with entities replaced by the corresponding words). The scores reported above were obtained after 4 epochs (without improvement afterward).


Discussions


  • about ROUGE scoring: I only ran experiments using the product of ROUGE-1, ROUGE-2 and ROUGE-L (F1) as the reward. I’m not sure how relevant it is / whether something else would make more sense.
  • I haven’t tried many values of gamma; I just took the one suggested in the paper. Also, using RL only (i.e. gamma == 1) raised some issues (I need to try again).
  • some other tricks (not yet implemented) are described in the paper and may lead to further improvements.
  • there are some notes in the code, i.e. comments like # NOTE (or even # TODO). They usually point to parts of the code that may need to be discussed, i.e. hacky things or code I had to set up to get the whole thing to work.
  • contributions are welcome, both by suggesting edits and by sharing your experiments!

Hi, pltrdy!
Interesting model! What was your experience with it? Did you get good results?
What order of perplexity values did you reach?

Bye,

Manuel

Hi,

I’ve been discussing stuff about it in the related PR: https://github.com/OpenNMT/OpenNMT-py/pull/319.
Long story short, I’m having trouble reproducing it. I’ve been experimenting with it again recently since someone got interested. This is super experimental; there’s still some work to be done to make it work.

The base model (either RL only or ML only) may work. There’s not much I’m really sure about.