After adding knowledge distillation, training stops after several steps

I added the knowledge distillation code to trainer.py; the main changes are the following:

                # Teacher forward pass for distillation (no gradients needed).
                if self.kd:
                    src_t = copy.deepcopy(src)
                    tgt_t = copy.deepcopy(tgt)
                    src_lengths_t = copy.deepcopy(src_lengths)
                    bptt_t = copy.deepcopy(bptt)
                    with torch.no_grad():
                        outputs_t, attns_t = self.teacher_model(
                            src_t, tgt_t, src_lengths_t, bptt=bptt_t)
                # 2. F-prop all but generator.
                if self.accum_count == 1:
                    self.optim.zero_grad()
                outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
                                            with_align=self.with_align)
                bptt = True
                         
                # 3. Compute loss.
                try:     
                    loss, batch_stats, probs = self.train_loss(
                        batch,
                        outputs,
                        attns,
                        normalization=normalization,
                        shard_size=self.shard_size,
                        trunc_start=j,
                        trunc_size=trunc_size)
                    kd_loss = None
                    if self.kd:
                        #outputs = outputs.float()
                        #outputs_t = outputs_t.float()
                        # Flatten decoder outputs to (batch * len, hidden)
                        # before feeding them to the generators.
                        outputs = outputs.view(-1, outputs.size(2))
                        outputs_t = outputs_t.view(-1, outputs_t.size(2))
                        #prob = self.model.generator(outputs)
                        with torch.no_grad():
                            # Teacher soft targets: the KL criterion expects
                            # probabilities as the target, so exponentiate the
                            # teacher log-probabilities in that case.
                            logprobs_t = self.teacher_model.generator(outputs_t)
                            if self.kd_loss == "kl":
                                probs_t = torch.exp(logprobs_t)
                            else:
                                probs_t = logprobs_t
                        kd_loss = self.kd_loss_func(probs, probs_t)
                         
                        # Attention distillation: compare attention maps
                        # summed over heads.
                        if self.kd_attn_loss_func is not None:
                            attn = torch.sum(attns['head'], dim=1).float()
                            attn_t = torch.sum(attns_t['head'], dim=1).float()
                            attn_loss = self.kd_attn_loss_func(
                                attn,
                                attn_t
                            )
                            #attn_loss = attn_loss.half()
                        else:
                            attn_loss = None
                        # Scale by T^2 so the soft-target loss keeps a gradient
                        # magnitude comparable to the hard-label loss.
                        kd_loss = self.kd_weight * pow(self.kd_temperature, 2) * kd_loss
                        if attn_loss is not None:
                            kd_loss += self.kd_attn_weight * pow(self.kd_attn_temperature, 2) * attn_loss
                         
                        #batch_stats.update_kd(kd_loss.clone())
                         
                    if loss is not None:
                        if kd_loss is not None:
                            loss += kd_loss
                        self.optim.backward(loss)
                    elif kd_loss is not None:
                        self.optim.backward(kd_loss)
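
For context, kd_loss_func and kd_attn_loss_func are constructed elsewhere in my changes; the snippet below is only a minimal sketch of what they are assumed to compute (a KL divergence between the student's log-probabilities and the teacher's probabilities, plus an MSE on the attention maps), with illustrative hyper-parameter values rather than my exact code:

import torch
import torch.nn as nn

# Illustrative values; the real ones come from the training options.
kd_weight, kd_temperature = 0.5, 2.0

# KLDivLoss expects its input to be log-probabilities and, by default, its
# target to be probabilities -- hence torch.exp(logprobs_t) above when
# kd_loss == "kl".
kd_loss_func = nn.KLDivLoss(reduction="batchmean")

# Attention distillation: a plain MSE between the summed attention maps.
kd_attn_loss_func = nn.MSELoss()

# Tiny usage example with fake student/teacher distributions.
student_logprobs = torch.log_softmax(torch.randn(8, 100), dim=-1)
teacher_probs = torch.softmax(torch.randn(8, 100), dim=-1)
kd_loss = kd_loss_func(student_logprobs, teacher_probs)

# Scale by T^2, as in the trainer code, so the soft-target loss keeps a
# gradient magnitude comparable to the hard-label loss.
kd_loss = kd_weight * kd_temperature ** 2 * kd_loss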

And the training stops after several steps, e.g. 20000 or 50000. After I reload the last checkpoint and continue training, it runs for another 20000 or 50000 steps and then stops again.
Is there something wrong with my code? Any reply is appreciated. Thanks!

What do you mean by “the training stops”? Does it freeze? Do you get an error? If so, what’s the trace?

Thank you for your reply. To be precise, the training freezes; it does not raise an error. The log looks like this:

[2021-09-04 05:02:35,546 INFO] Step 190100/500000; acc:  64.91; ppl:  4.92; xent: 1.59; kd_xent: 0.00; lr: 0.00020; 88622/73005 tok/s;  20509 sec
[2021-09-04 05:02:35,546 INFO] Step 190100; kd_xent: 0.00
[2021-09-04 05:04:13,589 INFO] Step 190200/500000; acc:  64.99; ppl:  4.90; xent: 1.59; kd_xent: 0.00; lr: 0.00020; 112595/92828 tok/s;  20607 sec
[2021-09-04 05:04:13,589 INFO] Step 190200; kd_xent: 0.00

After several hours there is no more log output, and the training program does not exit; it just stays stuck there.

Can you check the kernel log for termination errors? Run this in a terminal:

nano /var/log/kern.log

or the command:

dmesg
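
If the kernel log turns out to be unreachable, another option (just a sketch; it requires adding a couple of lines to your training entry point) is the standard-library faulthandler module, so the process periodically dumps the Python stack of every thread. The last dump written to stderr before the freeze then shows where it is stuck:

import faulthandler
import sys

# Dump the traceback of every thread to stderr every 30 minutes,
# repeatedly, until the process exits.
faulthandler.dump_traceback_later(1800, repeat=True, file=sys.stderr)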

Checking the kernel log will be a little difficult, as my task is not running locally but in a virtual environment on our platform, and I can't access that environment directly.