I added knowledge distillation code to trainer.py; the main changes are the following:
if self.kd:
    # Run the teacher on copies of the batch so its forward pass
    # cannot touch the tensors the student will use.
    src_t = copy.deepcopy(src)
    tgt_t = copy.deepcopy(tgt)
    src_lengths_t = copy.deepcopy(src_lengths)
    bptt_t = copy.deepcopy(bptt)
    with torch.no_grad():
        outputs_t, attns_t = self.teacher_model(src_t, tgt_t, src_lengths_t, bptt=bptt_t)

# 2. F-prop all but generator.
if self.accum_count == 1:
    self.optim.zero_grad()

outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
                            with_align=self.with_align)
bptt = True

# 3. Compute loss.
try:
    loss, batch_stats, probs = self.train_loss(
        batch,
        outputs,
        attns,
        normalization=normalization,
        shard_size=self.shard_size,
        trunc_start=j,
        trunc_size=trunc_size)

    kd_loss = None
    if self.kd:
        # Flatten decoder outputs to (batch * len, hidden) before the generators.
        #outputs = outputs.float()
        #outputs_t = outputs_t.float()
        outputs = outputs.view(-1, outputs.size(2))
        outputs_t = outputs_t.view(-1, outputs_t.size(2))
        #prob = self.model.generator(outputs)
        with torch.no_grad():
            logprobs_t = self.teacher_model.generator(outputs_t)
        if self.kd_loss == "kl":
            # KLDivLoss expects the target as probabilities, so
            # exponentiate the teacher's log-probabilities.
            probs_t = torch.exp(logprobs_t)
        else:
            probs_t = logprobs_t
        kd_loss = self.kd_loss_func(probs, probs_t)

        if self.kd_attn_loss_func is not None:
            # Attention distillation: sum over heads, then compare
            # the student's and teacher's attention maps.
            attn = torch.sum(attns['head'], dim=1).float()
            attn_t = torch.sum(attns_t['head'], dim=1).float()
            attn_loss = self.kd_attn_loss_func(attn, attn_t)
            #attn_loss = attn_loss.half()
        else:
            attn_loss = None

        # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
        kd_loss = self.kd_weight * pow(self.kd_temperature, 2) * kd_loss
        if attn_loss is not None:
            kd_loss += self.kd_attn_weight * pow(self.kd_attn_temperature, 2) * attn_loss
        #batch_stats.update_kd(kd_loss.clone())

    if loss is not None:
        if kd_loss is not None:
            loss += kd_loss
        self.optim.backward(loss)
    elif kd_loss is not None:
        self.optim.backward(kd_loss)
With this change, training stops after a number of steps, e.g. around 20,000 or 50,000. If I reload the last checkpoint and continue training, it runs for another 20,000 or 50,000 steps and then stops again.
Is there something wrong with my code? Any reply is appreciated. Thanks!