- class BaseModel(nn.Module):
- """
- Core trainable object in OpenNMT. Implements a trainable interface
- for a simple, generic encoder / decoder or decoder only model.
- """
def __init__(self, encoder, decoder):
    """Initialize the parent ``nn.Module`` state.

    Args:
        encoder: Encoder component handed to the model. Not stored or
            otherwise used by this method.
        decoder: Decoder component handed to the model. Not stored or
            otherwise used by this method.
    """
    # Only the base-class initializer runs here; ``encoder`` and ``decoder``
    # are deliberately left untouched. NOTE(review): presumably subclasses
    # assign them to attributes themselves -- confirm against the subclasses.
    super().__init__()
- def forward(self, src, tgt, lengths, bptt=False, with_align=False):
- """Forward propagate a `src` and `tgt` pair for training.
- Possible initialized with a beginning decoder state.
- Args:
- src (Tensor): A source sequence passed to encoder.
- typically for inputs this will be a padded `LongTensor`
- of size ``(len, batch, features)``. However, may be an
- image or other generic input depending on encoder.
- tgt (LongTensor): A target sequence passed to decoder.
- Size ``(tgt_len, batch, features)``.