From reading the OpenNMT-py documentation, my understanding was that with phrase_table, if the predicted target token is unknown, it is replaced with the target token that the table gives for the corresponding source token. If that source token does not exist in the table, the source token with the highest attention weight is copied instead.
But looking at the code, it seems my understanding was wrong. It seems that if the predicted target token is unknown, it is replaced with the source token that had the highest attention weight, or with that token's phrase table entry if one exists.
The relevant code snippet is below:
if self.replace_unk and attn is not None and src is not None:
    for i in range(len(tokens)):
        if tokens[i] == tgt_field.unk_token:
            _, max_index = attn[i][:len(src_raw)].max(0)
            tokens[i] = src_raw[max_index.item()]
            if self.phrase_table != "":
                with open(self.phrase_table, "r") as f:
                    for line in f:
                        if line.startswith(src_raw[max_index.item()]):
                            tokens[i] = line.split('|||')[1].strip()
I had assumed it would be something like this instead:
if self.replace_unk and attn is not None and src is not None:
    for i in range(len(tokens)):
        if tokens[i] == tgt_field.unk_token:
            _, max_index = attn[i][:len(src_raw)].max(0)
            tokens[i] = src_raw[max_index.item()]
            if self.phrase_table != "":
                with open(self.phrase_table, "r") as f:
                    for line in f:
                        if line.startswith(src_raw[i]):
                            tokens[i] = line.split('|||')[1].strip()
Notice that the only change is in the phrase table lookup: the line if line.startswith(src_raw[max_index.item()]): is replaced with if line.startswith(src_raw[i]):
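To make the difference between the two lookups concrete, here is a minimal sketch that mimics both with plain Python lists instead of tensors. Everything in it is made up for illustration and is not part of OpenNMT-py: the helper name replace_unk_sketch, the use_target_position flag, the toy attention matrix, and the dict used as a phrase table. It also simplifies the matching to an exact dict lookup, whereas the real code does a startswith prefix match against each line of the file.

```python
def replace_unk_sketch(tokens, src_raw, attn, phrase_table,
                       unk="<unk>", use_target_position=False):
    """Toy version of the UNK replacement logic (hypothetical helper).

    attn[i][j] is the attention weight of target step i on source token j.
    phrase_table maps a source token to its replacement (exact match only).
    """
    out = list(tokens)
    for i, tok in enumerate(out):
        if tok != unk:
            continue
        # Source position with the highest attention weight for target step i.
        weights = attn[i][:len(src_raw)]
        max_index = max(range(len(weights)), key=weights.__getitem__)
        # Default replacement: copy the most-attended source token.
        out[i] = src_raw[max_index]
        if use_target_position:
            # Variant from my assumption: index the source by target position.
            # Guard needed because the target can be longer than the source.
            if i >= len(src_raw):
                continue
            key = src_raw[i]
        else:
            # Variant in the current code: look up the most-attended token.
            key = src_raw[max_index]
        if key in phrase_table:
            out[i] = phrase_table[key]
    return out


# Toy data where target word order differs from source word order.
src_raw = ["Haus", "klein"]
tokens = ["small", "<unk>"]
attn = [[0.1, 0.9],   # target step 0 attends mostly to "klein"
        [0.8, 0.2]]   # target step 1 attends mostly to "Haus"
table = {"Haus": "house", "klein": "small"}

print(replace_unk_sketch(tokens, src_raw, attn, table))
# ['small', 'house']
print(replace_unk_sketch(tokens, src_raw, attn, table, use_target_position=True))
# ['small', 'small']
```

With this toy input the two variants disagree only because the source and target word orders differ, so position i and max_index point at different source tokens. Note also that src_raw[i] can go out of range when the target is longer than the source, which is why the sketch adds a guard for that case.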
Am I wrong in my assumption? Is this the desired behavior?