How to avoid CUDA out-of-memory errors during inference?

Hi, is there a way to avoid CUDA out-of-memory errors during inference? Every time I send a large batch, or a batch with very long sentences, I get this error. Other libraries can handle this by dynamically changing the batch size, for example Hugging Face's Accelerate:

from accelerate import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=starting_batch_size)
def inference_loop(batch_size):
    # Accelerate calls this with batch_size as the first argument,
    # halving it and retrying whenever a CUDA out-of-memory error is raised
    ...

Am I doing something wrong, or is this intended? What is the best way to use these models for inference when you receive large batches of data?

I am using it like this:

translations = self._model.translate_batch(
    source_sents_subworded, batch_type="tokens", max_batch_size=2048,
    beam_size=5
)

Hi,

You added the #opennmt-py tag, but it looks like you are using #ctranslate2, right?

What version of CTranslate2 are you using?

Yes, sorry. I also have this problem with OpenNMT-py, versions 1.1.0 and 1.2.0. For CTranslate2 I am using the latest version; I just pip installed ctranslate2 yesterday or the day before.

How big is your model and how much GPU memory is available?

With CTranslate2 I’m unsure how you can get OOM errors with this configuration since long sentences are truncated after 1024 tokens by default.
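
For reference, I believe that truncation is controlled by the max_input_length option of translate_batch, which defaults to 1024. A minimal sketch (translator here stands in for your loaded ctranslate2.Translator):

translations = translator.translate_batch(
    source_sents_subworded,
    max_input_length=1024,  # tokens beyond this limit are truncated before translation
)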

I am using NLLB 3.3B quantized to 8-bit, which is ~4 GB. I was testing by adding segments until I got an error, which happened at 110 segments. My device is a 12 GB GPU.

I did something like this to work around that annoying error (I know it is a pretty awful implementation, but it works):

import logging

beam_size = 5
batch_size = 4096

# Translate the source sentences, halving the token batch size on OOM
while batch_size >= 64:
    try:
        translations = self._model.translate_batch(
            source_sents_subworded, batch_type="tokens",
            max_batch_size=batch_size, beam_size=beam_size,
            target_prefix=target_prefix
        )
        break  # translation succeeded

    except RuntimeError as e:
        print(e)
        msg = f"Tokens batch size is too large, reducing it from {batch_size} to {batch_size // 2}."
        logging.info(msg)
        with open("batch_size_problems.log", "a") as f:
            f.write(msg + "\n")
            f.write("MAX CHARACTERS SENTENCE: " + str(len(max(source_sents_subworded, key=len))) + "\n")
            f.write("LEN BATCH TO TRANSLATE: " + str(len(source_sents_subworded)) + "\n")
            f.write(str(source_sents_subworded) + "\n")
        batch_size = batch_size // 2
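
If it helps, here is a slightly cleaner version of the same halve-and-retry idea as a reusable helper. The function name, the backoff policy, and the "out of memory" string check are my own choices, not CTranslate2 API; only translate_batch and its arguments come from the library:

import logging

def translate_with_backoff(translator, tokens, *, beam_size=5,
                           start_batch_size=4096, min_batch_size=64,
                           target_prefix=None):
    # Retry translate_batch, halving the token batch size on CUDA OOM.
    batch_size = start_batch_size
    while batch_size >= min_batch_size:
        try:
            return translator.translate_batch(
                tokens, batch_type="tokens", max_batch_size=batch_size,
                beam_size=beam_size, target_prefix=target_prefix,
            )
        except RuntimeError as e:
            # CUDA OOM surfaces as a RuntimeError; re-raise anything else
            if "out of memory" not in str(e).lower():
                raise
            logging.warning("OOM at max_batch_size=%d, halving", batch_size)
            batch_size //= 2
    raise RuntimeError(f"Still OOM at max_batch_size={min_batch_size}")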