Adding a 2-gram conv_sub_layer to Transformer encoder/decoder layers causes a CUDA OOM error

I'm trying to add a 2-gram conv_sub_layer to the Transformer encoder/decoder layers. The number of parameters increases only slightly (one conv kernel and one linear layer in each encoder/decoder layer).
However, I ran into a serious CUDA out-of-memory (OOM) problem. It looks like some intermediate variables unnecessarily stay in CUDA memory after the data passes through each encoder/decoder layer. As training proceeds batch by batch, the used CUDA memory keeps increasing until it runs out, instead of fluctuating around a stable level as it normally does.
To resolve the problem, I carefully checked my code for any detail I might have missed and made some revisions, but the situation did not improve. The revisions included:
- `del`-ing every intermediate variable in the forward function as soon as it is no longer used, to reduce memory usage;
- revising the model to reduce the number of parameters.
I have run out of ideas for resolving this, so I am asking for help on the forums.
Here is the code:

def forward(self, input_embed, mask=None, future_mask=None):
    """
    Compute a 2-gram convolutional summary vector for every input position.

    For each position ``i``, ``two_gram_conv_input_generator`` builds the
    2-gram inputs. Those inputs are then:

    * scaled by ``self.conv_kernal``;
    * gated with a GLU;
    * scored per kernel and masked;
    * normalised with ``self.softmax``;
    * pooled over the length axis;
    * mixed across kernels by ``self.final_linear``.

    Args:
       input_embed (FloatTensor): embedded input ``(batch, input_len, dim)``.
       mask: mask where non-zero / True marks keys that receive (near) zero
           attention, because they are filled with -1e18 before the softmax.
           Its shape depends on ``future_mask``:
           ``(batch, 1, key_len)`` when ``future_mask`` is None (encoder);
           ``(batch, key_len, key_len)`` when ``future_mask`` is set (decoder).
       future_mask: only compared against None. Any non-None value selects
           the decoder-style per-step mask row ``mask[:, i, :]``.
    Returns:
       (FloatTensor):
       * 2-gram token vectors ``(batch, input_len, dim)``

    NOTE(review) on memory: every loop iteration materialises a
    ``(batch, input_len, conv_kernals, 2 x dim)`` tensor. Autograd keeps the
    GLU input and the matmul operand alive until ``backward()``, so peak
    memory grows as O(input_len ** 2). Calling ``del`` on the Python name
    cannot release a tensor that autograd has saved. The lasting fix is to
    build the 2-gram inputs for all positions as one batched tensor instead
    of looping over the time dimension.

    NOTE(review): ``self.softmax`` is defined outside this view. It must
    normalise over the length axis, which is dim -2 of the unsqueezed input.
    If it is ``nn.Softmax(dim=-1)``, it normalises over the size-1 axis and
    every weight becomes 1 -- confirm.
    """
    input_len = input_embed.size(1)
    if input_len == 0:
        # Nothing to convolve: return an empty (batch, 0, dim) slice instead
        # of failing on an unbound accumulator.
        return input_embed[:, :0]

    # The encoder mask is the same for every step, so transpose it once:
    # (batch, 1, key_len) -> (batch, key_len, 1), which broadcasts over kernels.
    enc_mask = None
    if mask is not None and future_mask is None:
        enc_mask = mask.transpose(1, 2)

    step_outputs = []
    for i in range(input_len):
        # step 1: 2-gram input for position i -> [batch, len, 1, 2 x dim]
        to_conv = two_gram_conv_input_generator(input_embed, i).unsqueeze(2)

        # step 2: apply kernels -> [batch, len, conv_kernals, 2 x dim]
        conved_input = to_conv * self.conv_kernal

        # step 3: GLU halves the last axis -> [batch, len, conv_kernals, dim]
        conved_input = nn.functional.glu(conved_input)

        # step 4: score each kernel at each position -> [batch, len, conv_kernals]
        scores = torch.sum(conved_input, dim=-1)

        # step 4.5: mask. In both branches the mask is (batch, key_len, 1).
        if enc_mask is not None:
            scores = scores.masked_fill(enc_mask, -1e18)
        elif mask is not None:
            scores = scores.masked_fill(mask[:, i, :].unsqueeze(-1), -1e18)

        # step 5: transpose + softmax -> [batch, conv_kernals, len]
        # Squeeze only the helper axis. A bare .squeeze() would also drop the
        # batch, kernel or length axis when it has size 1, breaking the matmul.
        scores = self.softmax(scores.transpose(1, 2).unsqueeze(-1)) \
            .to(input_embed.dtype).squeeze(-1)

        # step 6: attention-weighted sum over length
        # [batch, k, 1, len] @ [batch, k, len, dim] -> [batch, k, 1, dim] -> [batch, k, dim]
        drop_attn = self.dropout(scores)
        input_sum = torch.matmul(drop_attn.unsqueeze(2),
                                 conved_input.transpose(1, 2)).squeeze(2)

        # step 7: linear(conv_kernals, 1) mixes the kernels -> [batch, 1, dim]
        step_outputs.append(
            self.final_linear(input_sum.transpose(1, 2)).transpose(1, 2))

    # Concatenate once at the end. Concatenating inside the loop re-copied
    # the growing prefix on every step, which is O(input_len ** 2) copying.
    return torch.cat(step_outputs, dim=1)

And here is part of the error traceback:

Traceback (most recent call last):
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/bin/train.py", line 148, in run
    single_main(opt, device_id, batch_queue, semaphore)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/train_single.py", line 143, in main
    valid_steps=opt.valid_steps)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/trainer.py", line 266, in train
    report_stats)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/trainer.py", line 389, in _gradient_accumulation
    with_align=self.with_align)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/models/model.py", line 51, in forward
    with_align=with_align)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/decoders/transformer.py", line 329, in forward
    with_align=with_align)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/decoders/transformer.py", line 95, in forward
    output, attns = self._forward(*args, **kwargs)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/decoders/transformer.py", line 155, in _forward
    conv_norm = self.Multi2GramConv(input_norm, mask=dec_mask, future_mask=True)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/caiy/miniconda3/envs/LLQ/lib/python3.6/site-packages/onmt/modules/multi_headed_attn.py", line 307, in forward
    conved_input = to_conv * self.conv_kernal
RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 1; 10.76 GiB total capacity; 8.19 GiB already allocated; 10.12 MiB free; 9.06 GiB reserved in total by PyTorch)

This seems better suited to the PyTorch forum, since you are not using any OpenNMT-py-specific features.

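At first glance, I would suggest avoiding the loop over the time dimension. Each iteration builds a full `[batch, length, conv_kernals, 2 x dim]` tensor. Autograd keeps that tensor alive until the backward pass, so memory grows quadratically with sequence length. `del` cannot free tensors that autograd has saved. Can you find a way to compute all timesteps at once?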