Hi, I am building a web service that processes one sentence per request.
My approach:
Write the sentence to a text file (a temporary file that will be deleted afterwards).
Reuse the code in translate.py (which takes a file path as input).
Delete the temporary file.
Output the resulting sentence.
Here is the code. Please review it and let me know how I can improve it, or whether you see any issues.
def _translate(self, text):
    """
    Translate one sentence with the already loaded model.

    The sentence is written to a temporary file because the dataset
    builder is given a file path; that file is removed again before this
    method returns, whether translation succeeded or failed.

    :param text: string
    :return: string of result from model (first prediction of the first
        sentence); if anything inside the translation pipeline raises,
        the exception message is returned instead
    :raises IndexError: if the pipeline finishes without producing any
        prediction, since the final lookup indexes an empty list
    """
    model = self.model
    fields = self.fields

    # set number of thread -> reduce CPU load (on shared server)
    if opt.gpu == -1:
        torch.set_num_threads(2)
        print("num threads ", torch.get_num_threads())

    # Bound before the try so the cleanup below can tell whether the tmp
    # file was actually created; otherwise a failure while creating it
    # would surface as an unbound-local error instead of the real one.
    src_file = None
    try:
        # create tmp file contained text
        src_file = self._create_tmp_file(text)

        # Test data
        data = onmt.io.build_dataset(fields, opt.data_type,
                                     src_file, None,
                                     use_filter_pred=False)

        # Sort batch by decreasing lengths of sentence required by pytorch.
        # sort=False means "Use dataset's sortkey instead of iterator's".
        data_iter = onmt.io.OrderedIterator(
            dataset=data, device=opt.gpu,
            batch_size=opt.batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        scorer = onmt.translate.GNMTGlobalScorer(opt.alpha,
                                                 opt.beta,
                                                 opt.coverage_penalty,
                                                 opt.length_penalty)
        onmt_translator = onmt.translate.Translator(
            model, fields,
            beam_size=opt.beam_size,
            n_best=opt.n_best,
            max_length=opt.max_length,
            cuda=opt.cuda,
            global_scorer=scorer,
            beam_trace=opt.dump_beam != "",
            min_length=opt.min_length,
            stepwise_penalty=opt.stepwise_penalty,
            block_ngram_repeat=opt.block_ngram_repeat,
            ignore_when_blocking=opt.ignore_when_blocking)
        builder = onmt.translate.TranslationBuilder(
            data, onmt_translator.fields,
            opt.n_best, opt.replace_unk, None)

        # One entry per translated sentence, each a list of n-best strings.
        output_pred = []
        for batch in data_iter:
            batch_data = onmt_translator.translate_batch(batch, data)
            translations = builder.from_batch(batch_data)
            for trans in translations:
                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:opt.n_best]]
                output_pred.append(n_best_preds)
    except Exception as e:
        # Errors are reported to the caller as the returned string.
        return str(e)
    finally:
        # delete tmp file (single cleanup point for success and failure)
        if src_file is not None:
            self._delete_tmp_file(src_file)

    return output_pred[0][0]
def _create_tmp_file(self, text):
    """
    Write text to a new, uniquely named .txt file.

    The file name is a UUID plus ".txt" and is not prefixed with any
    directory, so the file is created relative to the current working
    directory.

    :param text: string to store in the file
    :return: name of the created file
    """
    fname = str(uuid.uuid1()) + ".txt"
    # The context manager closes the handle even if write() raises.
    # UTF-8 is requested explicitly so that non-ASCII sentences do not
    # depend on the locale's default encoding.
    # NOTE(review): assumes the dataset reader decodes the file as UTF-8
    # -- confirm against the reader used by build_dataset.
    with open(fname, "w", encoding="utf-8") as tmp:
        tmp.write(text)
    return fname
def _delete_tmp_file(self, fname):
    """
    Best-effort removal of a temp file made by _create_tmp_file.

    A failure to remove the file is reported on stdout and otherwise
    ignored; no exception from the removal is propagated for OSError.

    :param fname: file path
    :return:
    """
    try:
        os.remove(fname)
    except OSError as err:
        # Keep a reference: the "as" name is cleared when the handler ends.
        failure = err
    else:
        return

    print("cannot remove tmp file")
    print(failure)