OpenNMT Forum

How to visualize attention weights?

Hi, I am pretty new to seq2seq models and OpenNMT-py. I am using OpenNMT for a summarization problem and was able to train a basic model by following the examples. However, when I tried to visualize the attention weights using the code mentioned in this thread, I got the following error:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

This is my code:

```python
import onmt
import onmt.inputters
import onmt.translate
import onmt.model_builder
from collections import namedtuple

# Load the model.
Opt = namedtuple('Opt', ['model', 'data_type', 'reuse_copy_attn', "gpu"])

opt = Opt("/Users/esalesky/projects/models/lstm_clean_acc_51.81_ppl_15.58_e13.pt", "text", False, 0)
fields, model, model_opt = onmt.model_builder.load_test_model(opt, {"reuse_copy_attn": False})

# Test data
data = onmt.inputters.build_dataset(fields, "text", None, use_filter_pred=False, src_path='/Users/esalesky/projects/data/fisher/test.es')
data_iter = onmt.inputters.OrderedIterator(
    dataset=data, device='cuda',
    batch_size=1, train=False, sort=False,
    sort_within_batch=True, shuffle=False)

# Translator
translator = onmt.translate.Translator(model, fields,
                                       beam_size=5,
                                       n_best=1,
                                       global_scorer=onmt.translate.GNMTGlobalScorer(0, 0, "none", "none"),
                                       gpu=True)

builder = onmt.translate.TranslationBuilder(
    data, translator.fields,
    1, False, None)

for j, batch in enumerate(data_iter):
    batch_data = translator.translate_batch(batch, data)
    translations = builder.from_batch(batch_data)
    print("src:", " ".join(translations[0].src_raw))
    print("tgt:", " ".join(translations[0].pred_sents[0]))
    print("idx:", str(j))
    print("-----")
```
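
For the visualization step itself, the loop only prints tokens. A minimal sketch, assuming each translation carries an attention matrix with one row per predicted token and one column per source token (some OpenNMT-py versions expose this as `translations[0].attns[0]`, a tensor you could convert with `.tolist()`; check your version's `Translation` class). The helper names `attention_table` and `most_attended` are hypothetical, and the weights in the demo are made-up toy numbers, not model output. It stays with plain text so it runs with only the standard library; matplotlib's `imshow` could plot the same matrix as a heatmap.

```python
from typing import List, Tuple


def attention_table(src: List[str], tgt: List[str], attn: List[List[float]]) -> str:
    """Format an attention matrix as a plain-text table.

    attn[i][j] is the weight that target token i puts on source token j,
    so there is one row per target token and one column per source token.
    The largest weight in each row is marked with '*'.
    """
    if len(attn) != len(tgt) or any(len(row) != len(src) for row in attn):
        raise ValueError("attn must have shape len(tgt) x len(src)")
    col = max(len(tok) for tok in src + ["0.00*"])  # width of each column
    label = max((len(tok) for tok in tgt), default=0)  # width of row labels
    lines = [" " * label + " | " + " ".join(tok.rjust(col) for tok in src)]
    for tok, row in zip(tgt, attn):
        best = max(range(len(row)), key=row.__getitem__)
        cells = [(f"{w:.2f}" + ("*" if j == best else " ")).rjust(col)
                 for j, w in enumerate(row)]
        lines.append(tok.ljust(label) + " | " + " ".join(cells))
    return "\n".join(lines)


def most_attended(src: List[str], tgt: List[str],
                  attn: List[List[float]]) -> List[Tuple[str, str]]:
    """Pair each target token with the source token it attends to most."""
    return [(t, src[max(range(len(row)), key=row.__getitem__)])
            for t, row in zip(tgt, attn)]


if __name__ == "__main__":
    # Toy example with made-up weights (each row sums to 1).
    src = ["el", "gato", "negro"]
    tgt = ["the", "black", "cat"]
    attn = [[0.8, 0.1, 0.1], [0.1, 0.2, 0.7], [0.05, 0.9, 0.05]]
    print(attention_table(src, tgt, attn))
    print(most_attended(src, tgt, attn))
    # [('the', 'el'), ('black', 'negro'), ('cat', 'gato')]
```

Inside the loop, the inputs would presumably be `translations[0].src_raw`, `translations[0].pred_sents[0]` and the attention matrix. The matrix may have an extra row (for example for an end-of-sentence step), in which case the shape check raises `ValueError` and the rows need trimming first.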

Could you please post the full log, including the line of code that raises the error?

@Bachstelze @Nokeli I’m also getting the same error. Here is the full traceback:

```
Traceback (most recent call last):
  File "/dscrhome/ds448/anaconda3/envs/freshenv/lib/python3.6/site-packages/torch/serialization.py", line 191, in _check_seekable
    f.seek(f.tell())
AttributeError: 'dict' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "attention_v.py", line 13, in <module>
    fields, model, model_opt = onmt.model_builder.load_test_model(pos_opt,{"reuse_copy_attn":False})
  File "/hpchome/carin/ds448/smalldata_normal/onmt/model_builder.py", line 85, in load_test_model
    map_location=lambda storage, loc: storage)
  File "/dscrhome/ds448/anaconda3/envs/freshenv/lib/python3.6/site-packages/torch/serialization.py", line 387, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "/dscrhome/ds448/anaconda3/envs/freshenv/lib/python3.6/site-packages/torch/serialization.py", line 549, in _load
    _check_seekable(f)
  File "/dscrhome/ds448/anaconda3/envs/freshenv/lib/python3.6/site-packages/torch/serialization.py", line 194, in _check_seekable
    raise_err_msg(["seek", "tell"], e)
  File "/dscrhome/ds448/anaconda3/envs/freshenv/lib/python3.6/site-packages/torch/serialization.py", line 187, in raise_err_msg
    raise type(e)(msg)
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
```
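
The traceback shows the dict `{"reuse_copy_attn": False}` reaching `torch.load` from inside `load_test_model` (line 85 of `model_builder.py`). `torch.load` accepts either a path or a seekable file-like object, so anything that is not a path is treated as an open file and fails the `seek`/`tell` check. A likely explanation, which is an assumption to verify against the `def load_test_model(...)` line in your installed `onmt/model_builder.py`: in this version the second positional parameter is the model path, whereas older snippets passed a dummy options dict in that position. The sketch below reproduces the failure with hypothetical stand-ins (`fake_torch_load`, `load_test_model_sketch`, and an `Opt` with a `models` list); it uses only the standard library and is not OpenNMT-py's or PyTorch's actual code.

```python
import io
import os
import tempfile
from collections import namedtuple

# Hypothetical options object. The `models` field (a list of checkpoint paths)
# is an assumption modeled on recent OpenNMT-py, not taken from this thread.
Opt = namedtuple("Opt", ["models"])


def fake_torch_load(f):
    """Hypothetical, simplified stand-in for torch.load's input handling.

    A str/PathLike is opened as a file; anything else is treated as an
    already-open file object and must support seek() and tell().
    """
    if isinstance(f, (str, os.PathLike)):
        with open(f, "rb") as fh:
            return fh.read()
    try:
        f.seek(f.tell())  # `f.seek` is looked up first, so a dict fails here
    except AttributeError as e:
        raise AttributeError(
            f"{e}. You can only torch.load from a file that is seekable.") from e
    return f.read()


def load_test_model_sketch(opt, model_path=None):
    """Hypothetical simplification: the 2nd positional parameter is the path."""
    if model_path is None:
        model_path = opt.models[0]
    return fake_torch_load(model_path)


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"checkpoint bytes")
    opt = Opt(models=[tmp.name])

    # Mirrors the call in this thread: the dict lands in `model_path`.
    try:
        load_test_model_sketch(opt, {"reuse_copy_attn": False})
    except AttributeError as err:
        print("fails:", err)

    # Taking the path from opt, or passing it explicitly, loads fine.
    print(load_test_model_sketch(opt))
    print(load_test_model_sketch(opt, tmp.name))
    # An in-memory buffer is seekable, which is what the error message suggests.
    print(fake_torch_load(io.BytesIO(b"checkpoint bytes")))
    os.remove(tmp.name)
```

If the signature assumption holds, dropping the dict argument (and exposing the checkpoint path where your version reads it) should avoid this particular error. Note that the `Opt` namedtuple in the first post has a `model` field rather than a `models` list, and your version may read further fields from `opt`. Wrapping the argument in `io.BytesIO`, as the message suggests, only helps when it is a real but non-seekable stream; it does not apply to a dict.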