This was asked in a previous question from 6 days ago, “How to vizualize attention weights”, but I haven’t been able to find an answer anywhere. When I try to get the attention weights, I get the following error before getting very far:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
Here is the code.
import onmt
import onmt.inputters
import onmt.translate
import onmt.model_builder
from collections import namedtuple
Opt = namedtuple("Opt", ["models", "data_type", "reuse_copy_attn", "gpu"])
opt = Opt("/home/shankrenn/Desktop/hidden-att/model/hidden-2/seed-0/LSTMlang1_step_400.pt", "text", False, 0)
fields, model, model_opt= onmt.model_builder.load_test_model(opt,{"reuse_copy_attn":False})
And here is the full traceback:
Traceback (most recent call last):
  File "<ipython-input-51-94c1f45c429f>", line 1, in <module>
    runfile('/home/shankrenn/Desktop/hidden-att/graph_hidden_exp.py', wdir='/home/shankrenn/Desktop/hidden-att')
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 786, in runfile
    execfile(filename, namespace)
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/home/shankrenn/Desktop/hidden-att/graph_hidden_exp.py", line 46, in <module>
    fields, model, model_opt= onmt.model_builder.load_test_model(opt,{"reuse_copy_attn":False})
  File "../../Documents/NMT/OpenNMT-py/onmt/model_builder.py", line 85, in load_test_model
    map_location=lambda storage, loc: storage)
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 387, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 549, in _load
    _check_seekable(f)
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 194, in _check_seekable
    raise_err_msg(["seek", "tell"], e)
  File "/home/shankrenn/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 187, in raise_err_msg
    raise type(e)(msg)
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
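The traceback shows that the object reaching torch.load inside load_test_model is a dict, not a file path. One plausible explanation, assuming the installed OpenNMT-py defines load_test_model(opt, model_path=None) and falls back to opt.models[0] (this signature is an assumption; older releases took an extra dummy_opt argument instead), is that the second positional argument {"reuse_copy_attn": False} is being bound to model_path. The sketch below uses a hypothetical stand-in, resolve_checkpoint_path, that mimics only that assumed path-selection step, so the three cases can be compared without torch or OpenNMT-py installed.

```python
from collections import namedtuple

# Same namedtuple the question builds by hand
Opt = namedtuple("Opt", ["models", "data_type", "reuse_copy_attn", "gpu"])


def resolve_checkpoint_path(opt, model_path=None):
    """Hypothetical stand-in for the path-selection step assumed to open
    load_test_model(opt, model_path=None). It returns whatever would be
    handed to torch.load."""
    if model_path is None:
        model_path = opt.models[0]
    return model_path


path = "/home/shankrenn/Desktop/hidden-att/model/hidden-2/seed-0/LSTMlang1_step_400.pt"

# The original call: the dict is bound to model_path and would reach torch.load
print(resolve_checkpoint_path(Opt(path, "text", False, 0), {"reuse_copy_attn": False}))
# -> {'reuse_copy_attn': False}

# Dropping the dict but keeping models as a plain string: indexing gives "/"
print(resolve_checkpoint_path(Opt(path, "text", False, 0)))
# -> /

# models as a list: the full checkpoint path comes through intact
print(resolve_checkpoint_path(Opt([path], "text", False, 0)))
```

If the installed version really has that signature, the corresponding real call would pass models as a one-element list and omit the dict, e.g. load_test_model(opt). Check the def line in onmt/model_builder.py of your checkout before relying on this.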