Attention on a specific word in the context

I want to modify the code so that, when generating the word at target position t, the network focuses its attention on a specific word, say position k, on the source side. This is different from the soft attention currently implemented.

Can you provide some pointers on what needs to be modified? So far I have only figured out that this code might need to be changed:

local targetT = nn.Linear(dim, dim, false)(inputs[1]) -- batchL x dim
  local context = inputs[2] -- batchL x sourceTimesteps x dim
  -- Get attention.
  local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
  attn = nn.Sum(3)(attn)
  local softmaxAttn = nn.SoftMax()
  softmaxAttn.name = 'softmaxAttn'
  attn = softmaxAttn(attn)
  attn = nn.Replicate(1,2)(attn) -- batchL x 1 x sourceL

Hello, how will you decide which word k to focus on? The attention vector is used to build a new context here:

local contextCombined = nn.MM()({attn, context}) -- batchL x 1 x dim

If you want to force the attention to a specific word, then you just have to bypass everything and get contextCombined as a selection in the context tensor - something like:

local contextCombined = context:narrow(2, k, 1)

The other possibility is that you pass to the function a hand-built attention vector full of zeros and ones (or 1/n values, if you want to look at several words) and then do an MM with the context tensor the same way as currently.

In any case, practically, do not modify this function - instead, create another module that you can easily plug into the model building.
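
For concreteness, here is a minimal sketch of that second option with plain tensors (outside nngraph, and with made-up sizes and position k), just to show that a one-hot attention vector multiplied against the context is equivalent to selecting the k-th source vector:

require 'nn'

local batchL, sourceL, dim = 2, 5, 4
local context = torch.randn(batchL, sourceL, dim)  -- batchL x sourceL x dim
local k = 3                                        -- source position to focus on

-- hand-built attention weights: batchL x 1 x sourceL, one-hot at position k
local attn = torch.zeros(batchL, 1, sourceL)
attn[{{}, 1, k}] = 1

-- same batched MM as the soft-attention code
local contextCombined = nn.MM():forward({attn, context})  -- batchL x 1 x dim

-- equivalent to directly selecting the k-th source vector
print(torch.dist(contextCombined, context:narrow(2, k, 1)))  -- ~0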

Hi Wabbit,

We implemented this in our old code base. I will try to put you in touch with the student who did it.

Hard attention is a tough thing to implement because you will need to change the loss from cross-entropy to REINFORCE. A nice compromise is SparseMax attention (https://github.com/gokceneraslan/SparseMax.torch), which you could add as a new module to substitute for global attention.

@srush: I might have put you off track because the original title of the post said “hard attention.” My bad. I didn’t mean non-differentiable hard attention as the term is used in the literature.

All I meant was that when predicting the word in position “k” I want to focus attention on the word in position “k” on the source side, so the attention in my case is a simple differentiable function of the context.

Oh I see! Makes sense. Yeah, so make a new module with the interface of GlobalAttention.lua that does what you would like, and then in model.lua use that module instead. I am not sure if we currently pass the position “t” through, but that should be possible to do.
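
For concreteness, here is a hypothetical skeleton of such a module; the class name, the onmt.Network base class, and the third input carrying externally provided attention weights are assumptions pieced together from the GlobalAttention code quoted above, not the actual OpenNMT implementation:

require 'nn'
require 'nngraph'

-- Hypothetical drop-in replacement for GlobalAttention: instead of computing
-- attention from the decoder state, it consumes attention weights passed in
-- as a third graph input.
local FocusAttention, parent = torch.class('onmt.FocusAttention', 'onmt.Network')

function FocusAttention:__init(dim)
  parent.__init(self, self:_buildModel(dim))
end

function FocusAttention:_buildModel(dim)
  local inputs = {}
  table.insert(inputs, nn.Identity()())  -- 1: decoder state, batchL x dim
  table.insert(inputs, nn.Identity()())  -- 2: context, batchL x sourceL x dim
  table.insert(inputs, nn.Identity()())  -- 3: attention weights, batchL x sourceL

  local attn = nn.Replicate(1, 2)(inputs[3])          -- batchL x 1 x sourceL
  local contextCombined = nn.MM()({attn, inputs[2]})  -- batchL x 1 x dim
  contextCombined = nn.Sum(2)(contextCombined)        -- batchL x dim
  contextCombined = nn.JoinTable(2)({contextCombined, inputs[1]})  -- batchL x dim*2
  local contextOutput = nn.Tanh()(nn.Linear(dim * 2, dim, false)(contextCombined))

  return nn.gModule(inputs, {contextOutput})
end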

@srush, @jean.senellart: thanks for the pointers. I think I’m stuck and need some more help.
Referring to line 237 in onmt/modules/Decoder.lua:

  local outputs = self:net(t):forward(inputs) 
  local out = outputs[#outputs]

I think this is the place where I’ll need to pass in the step t. I’m unable to understand the flow here. What function does self:net(t):forward(inputs) call? It looks like it calls Decoder:forward(), but that doesn’t make sense to me since it would then be an endless loop.

Also I tried using mobdebug to trace the flow but if I put a breakpoint in GlobalAttention.lua at

local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
  attn = nn.Sum(3)(attn)

that breakpoint is never triggered. Maybe I’m missing something trivial.

Hello! The logic of working with nngraph is a bit unsettling at first: the function in GlobalAttention is called very early, only to build the compute graph, and that’s all - so you cannot put any logic there that is not part of the compute graph.
When we call self:net(t):forward(inputs), inputs is propagated through the compute graph and eventually reaches the attention part of the network. The only way to pass in some information is through “inputs”, which is why I was suggesting adding a normalized vector of source sentence length (where you put 1 for the word you want attention on) to the inputs.
I will draft that in a branch today.
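
To illustrate that point with a toy example (this is not OpenNMT code): the builder function runs once and only wires up graph nodes; plain Lua statements inside it execute at build time, and forward() afterwards just pushes tensors through the finished graph.

require 'nn'
require 'nngraph'

local function buildToyAttention(dim)
  local target = nn.Identity()()   -- graph input 1: batchL x dim
  local context = nn.Identity()()  -- graph input 2: batchL x sourceL x dim
  print('building graph (printed once, at construction time)')  -- not part of the graph

  local targetT = nn.Linear(dim, dim, false)(target)
  local scores = nn.MM()({context, nn.Replicate(1, 3)(targetT)})  -- batchL x sourceL x 1
  local attn = nn.SoftMax()(nn.Sum(3)(scores))                    -- batchL x sourceL
  return nn.gModule({target, context}, {attn})
end

local net = buildToyAttention(4)

-- forward() propagates tensors through the already-built graph; the print
-- above does not fire again here, and neither would a breakpoint.
local attn = net:forward({torch.randn(2, 4), torch.randn(2, 5, 4)})
print(attn:size())  -- 2 x 5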

@jean.senellart: I’ve created the gist below. I’m planning to pass the current time step as inputs[4] through
self:net(t):forward(inputs)
I will also need to pass a FOCUS_WORD_k flag as an argument to train.lua, but first I wanted to get the logic right.

GlobalAttention.lua lines 43:66

 

-  -- Apply attention to context.
-  local contextCombined = nn.MM()({attn, context}) -- batchL x 1 x dim
+  -- TODO: pass this as a global var in configuration
+  FOCUS_WORD_k = true
+  local contextCombined
+  if FOCUS_WORD_k ~= true then
+    local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
+    attn = nn.Sum(3)(attn)
+    local softmaxAttn = nn.SoftMax()
+    softmaxAttn.name = 'softmaxAttn'
+    attn = softmaxAttn(attn)
+    attn = nn.Replicate(1,2)(attn) -- batchL x 1 x sourceL
+    -- Apply attention to context.
+    contextCombined = nn.MM()({attn, context}) -- batchL x 1 x dim
+  else
+    -- Focus on the source position given by the target time step (inputs[4]).
+    contextCombined = context:narrow(2, inputs[4], 1) -- batchL x 1 x dim
+  end
   contextCombined = nn.Sum(2)(contextCombined) -- batchL x dim
   contextCombined = nn.JoinTable(2)({contextCombined, inputs[1]}) -- batchL x dim*2
   local contextOutput = nn.Tanh()(nn.Linear(dim*2, dim, false)(contextCombined))

Decoder.lua lines 233:242

  if self.train then
    self.inputs[t] = inputs
  end
  table.insert(inputs, t)
  local outputs = self:net(t):forward(inputs)
  local out = outputs[#outputs]
  local states = {}
  for i = 1, #outputs - 1 do
    table.insert(states, outputs[i])
  end

Hi Wabbit - thanks. Give me one day and I will come back to you!
What I would like is to make sure we change as little as possible in the flow. (Interestingly, when we worked on modularizing OpenNMT, the internal example we used for testing the code organization was about changing the attention module - so it is a good time to practice on that!)

@jean.senellart - I’m not sure what the priority for this is in the team’s plan, so I was thinking I could implement it (maybe by breaking the modularity in my own branch) and have it code reviewed by someone more familiar with the setup. I need to move fast on it for one of my projects.

If that sounds OK, please let me know if the changes in my earlier post are fine. Also, I tried using mobdebug to understand the flow, but breakpoints in GlobalAttention.lua are never triggered. What setup do you folks use?

Hi @Wabbit, I committed a draft implementation here: https://github.com/jsenellart-systran/OpenNMT/tree/HardAttention - so that you can move ahead.

What I did is the standard way: we add an additional tensor to the inputs of the decoder, which then reaches the attention module. For the moment this tensor is initialized to 1 at position t and 0 otherwise (note that the dimension of this tensor is batch_size x source_length, and the timestep t iterates over the target sentence, so it might go beyond the source sentence length).
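
Here is a small sketch of how such a tensor could be initialized; the function name is illustrative, and clamping t at the source length is my assumption about one way to handle the target running past the source (the branch may handle it differently):

-- Extra decoder input: 1 at the focused source position, 0 elsewhere.
local function buildFixedAttention(batchSize, sourceLength, t)
  local fixedAttn = torch.zeros(batchSize, sourceLength)
  -- t iterates over the target sentence, so clamp it to the source length
  local pos = math.min(t, sourceLength)
  fixedAttn[{{}, pos}] = 1
  return fixedAttn  -- batchSize x sourceLength
end

-- Example: target step 25 on a source of length 20 falls back to position 20.
local fixedAttn = buildFixedAttention(64, 20, 25)
print(fixedAttn[1][20])  -- 1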

It is not cleaned up - I just modified the GlobalAttention class, whereas you should create a new HardAttention class. Also, we need to condition the selection of GlobalAttention or HardAttention when building the decoder, and also the passing of this additional input.

We would like to find a less intrusive implementation - but that is on our side! On yours, you can go ahead with this for the moment; it should do the job.

Let me know if you have any questions!

(Let’s not call this hard attention in the final version. Maybe IdentityAttention or InOrderAttention.)

I agree - I was thinking about FixedAttention, since it would allow injecting soft alignments from an external alignment tool.
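
As an illustration of that use case (this is not code from the branch, and the Pharaoh alignment format plus the normalization of one-to-many links are assumptions), an external word alignment could be turned into per-target-step attention rows and fed in as the fixed attention input:

-- Convert a Pharaoh-style alignment string ("srcIdx-tgtIdx", 0-based) into a
-- targetL x sourceL matrix of attention weights, one row per target position.
local function alignmentToAttention(alignStr, sourceL, targetL)
  local attn = torch.zeros(targetL, sourceL)
  for s, t in alignStr:gmatch('(%d+)%-(%d+)') do
    attn[tonumber(t) + 1][tonumber(s) + 1] = 1
  end
  -- spread the weight when a target word aligns to several source words
  for t = 1, targetL do
    local total = attn[t]:sum()
    if total > 0 then attn[t]:div(total) end
  end
  return attn
end

print(alignmentToAttention('0-0 1-2 2-1', 3, 3))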


@jean.senellart Thanks for the FixedAttention branch!

@jean.senellart: I’m having problems generating translations on the master branch after I merged in https://github.com/jsenellart-systran/OpenNMT/tree/HardAttention

I also tried running translate using the code in https://github.com/jsenellart-systran/OpenNMT/tree/HardAttention and even that has the same problem.

In your code you have removed softmaxAttn, since HardAttention is just a vector with a single 1 and 0s elsewhere. So I had to modify DecoderAdvancer.lua:80 as

  local softmaxOut = _

Also, the HardAttention code needs the position t, so I had to pass it in DecoderAdvancer.lua:78

decOut, decStates = self.decoder:forwardOne(inputs, decStates, context, decOut, t)

I’m still trying to understand the entire flow, but could you please confirm whether these changes make sense? Let me know if you want to see the files on GitHub so that you can diff.

For adding the t parameter you are right; for the softmaxOut we need to do something a little bit different: we need to fetch the fixedAttn instead so that the beam search works correctly. I am on a plane for the next 11 hours but will try to commit a patch when I land.

fixedAttention has dimensions batchL x sourceL.
What if we pass fixedAttention through a softmax, like:



  local softmaxAttn = nn.SoftMax()
  softmaxAttn.name = 'softmaxAttn'
  attn = softmaxAttn(fixedAttention) -- batchL x sourceL
  attn = nn.Replicate(1,2)(attn) -- batchL x 1 x sourceL

Then we can use this softmaxAttn in DecoderAdvancer.lua:80 as

local softmaxOut = self.decoder.softmaxAttn.output

The only issue is that exp(0)=1 and exp(1)=2.72 are not very different, so ideally we might want the most negative float value torch allows instead of 0, so that the attention goes to almost 0 in the right places.
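
A quick check with plain nn shows the difference (the exact masking constant is up to us; anything very negative will do):

require 'nn'

local softmax = nn.SoftMax()

-- With 0/1 scores the softmax stays fairly flat: attention leaks everywhere.
print(softmax:forward(torch.Tensor({0, 0, 1, 0, 0})))
-- roughly 0.149, 0.149, 0.405, 0.149, 0.149

-- Masking the non-focus positions with a very negative value makes the
-- distribution effectively one-hot.
print(softmax:forward(torch.Tensor({-1e9, -1e9, 1, -1e9, -1e9})))
-- 0, 0, 1, 0, 0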

Hi @Wabbit - I put a patch on the branch; it just fetches the fixed attention vector. It seems very hard to test, though, because for a regular translation task we cannot really expect any good result. Do you have a specific use case for this?
Note also that the target sentence has a <BOS> token - so if you plan to align source and target manually, you need to take this into account.

My task is (potentially) simpler. It’s about tagging each word in the source to say whether it represents a particular kind of entity: essentially Named Entity Recognition.

An example source-target pair will be:
Source: This shirt is blue
Target: 0,0,0,1

This assumes that we are tagging for color. The motivation for fixed attention here is that there’s a 1:1 alignment between source and target (unlike in MT). Fixed attention is a way of pushing more information into the decoder (beyond what’s captured by the final state of the encoder RNN).

Thanks for the explanation. I think you should explore something simpler than seq2seq here - even with fixed attention, you won’t be able to control the target sentence length, and the encoder-decoder-attention approach is overkill.
Basically what you need is a sequence tagger - so you only need an encoding layer + softmax, just like the language model implementation, except that the target is your NER tags. I am copying @josep.crego, who is starting to work on exactly that on our side.
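
For reference, here is a generic sketch of such a tagger using the Element-Research rnn package; it is independent of OpenNMT, not the language-model code referred to above, and all sizes are made up:

require 'nn'
require 'rnn'  -- Element-Research rnn package

-- Sequence tagger: embedding + one recurrent layer + per-position softmax
-- over the tag set. No decoder, no attention; output length = input length.
local vocabSize, numTags, embSize, hidSize = 10000, 2, 64, 128

local tagger = nn.Sequential()
tagger:add(nn.LookupTable(vocabSize, embSize))           -- seqLen x batch -> seqLen x batch x emb
tagger:add(nn.SplitTable(1))                             -- table of seqLen tensors (batch x emb)
tagger:add(nn.Sequencer(nn.FastLSTM(embSize, hidSize)))  -- table of (batch x hid)
tagger:add(nn.Sequencer(nn.Sequential()
  :add(nn.Linear(hidSize, numTags))
  :add(nn.LogSoftMax())))                                -- table of (batch x numTags)

local seqLen, batch = 4, 3
local words = torch.LongTensor(seqLen, batch):random(1, vocabSize)
local scores = tagger:forward(words)
print(#scores, scores[1]:size())  -- one tag distribution per source position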