
Python bindings

Tatiana Likhomanenko edited this page Dec 12, 2019 · 9 revisions

ASG Loss

The ASG loss is a PyTorch module (nn.Module) that supports both CPU and CUDA backends. It can be defined as

asg_loss = ASGLoss(ntokens, scale_mode).to(device)

where ntokens is the number of tokens (classes) predicted for each frame, and scale_mode selects how the loss is normalized. It can be one of:

NONE = 0, # no scaling
INPUT_SZ = 1, # scale to the input size
INPUT_SZ_SQRT = 2, # scale to the sqrt of the input size
TARGET_SZ = 3, # scale to the target size
TARGET_SZ_SQRT = 4, # scale to the sqrt of the target size
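
ASG (Auto Segmentation criterion) scores a target by comparing two log-sums: one over every token sequence of the input length, including learned transition scores between tokens, and one over only the sequences that spell out the target. The pure-Python sketch below illustrates this for a single utterance, together with one plausible reading of the scale modes. It is not the module's implementation: ScaleMode, compute_scale and asg_loss are hypothetical names, "scale to X" is assumed to mean multiplying the loss by 1/X or 1/sqrt(X), transitions[i][j] is assumed to score moving from token i to token j, and repeated letters are assumed to be already encoded in the target (ASG handles them with extra repetition tokens).

```python
import math
from enum import IntEnum
from typing import List, Sequence


class ScaleMode(IntEnum):
    # Hypothetical stand-in mirroring the scale_mode values listed above.
    NONE = 0
    INPUT_SZ = 1
    INPUT_SZ_SQRT = 2
    TARGET_SZ = 3
    TARGET_SZ_SQRT = 4


def compute_scale(mode: ScaleMode, T: int, L: int) -> float:
    # Assumption: "scale to X" multiplies the loss by 1/X or 1/sqrt(X),
    # where T = number of input frames and L = target length.
    if mode == ScaleMode.INPUT_SZ and T > 0:
        return 1.0 / T
    if mode == ScaleMode.INPUT_SZ_SQRT and T > 0:
        return math.sqrt(1.0 / T)
    if mode == ScaleMode.TARGET_SZ and L > 0:
        return 1.0 / L
    if mode == ScaleMode.TARGET_SZ_SQRT and L > 0:
        return math.sqrt(1.0 / L)
    return 1.0  # NONE (or an empty input/target)


def logsumexp(xs: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(x))) that tolerates -inf entries.
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def asg_loss(emissions: List[List[float]], transitions: List[List[float]],
             target: List[int], mode: ScaleMode = ScaleMode.NONE) -> float:
    # emissions[t][k]: score of token k at frame t (T x N, log-space)
    # transitions[i][j]: score of moving from token i to token j (N x N)
    # target: non-empty token indices without consecutive repeats
    T, N, L = len(emissions), len(transitions), len(target)

    # Full graph: log-sum of scores over every length-T token sequence.
    alpha = list(emissions[0])
    for t in range(1, T):
        alpha = [emissions[t][j]
                 + logsumexp([alpha[i] + transitions[i][j] for i in range(N)])
                 for j in range(N)]
    full = logsumexp(alpha)

    # Constrained graph: sequences that emit target[0], ..., target[L-1]
    # in order, each for at least one frame.
    beta = [emissions[0][target[0]]] + [-math.inf] * (L - 1)
    for t in range(1, T):
        new_beta = []
        for l in range(L):
            stay = beta[l] + transitions[target[l]][target[l]]
            move = (beta[l - 1] + transitions[target[l - 1]][target[l]]
                    if l > 0 else -math.inf)
            new_beta.append(emissions[t][target[l]] + logsumexp([stay, move]))
        beta = new_beta
    constrained = beta[L - 1]

    # Negative log-likelihood of the target, then the selected scaling.
    return (full - constrained) * compute_scale(mode, T, L)
```

As a worked check: with all-zero emissions and transitions, T = 4 frames, N = 3 tokens and target [0, 1], all 3^4 = 81 token sequences score 0 and exactly 3 of them (0111, 0011, 0001) spell the target, so the unscaled loss is ln 81 - ln 3 = 3 ln 3, and INPUT_SZ scaling divides it by 4.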

Beam-search decoder

Example of how to define your own language model state:

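from wav2letter.decoder import LMState


class LMStateNew(LMState):
    some_helpful_var = 1  # class-level default

    def __init__(self, some_helpful_var):
        super().__init__()  # initialize the underlying C++ LMState
        # per-instance value; shadows the class-level default
        self.some_helpful_var = some_helpful_var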
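
The compiled LMState cannot be exercised without the bindings installed, so the sketch below uses a hypothetical pure-Python stub (LMStateStub) to show why a custom state is useful: the subclass stores the extra information a language model needs to score the next token. BigramState, ToyBigramLM and its start/score/finish methods are illustrative names; the method names are modeled on the decoder's C++ LM interface, which is an assumption here rather than the documented Python API. States are cached per history so that identical histories map to the same object, on the assumption that the decoder compares states by identity.

```python
from typing import Dict, Optional, Tuple


class LMStateStub:
    # Hypothetical pure-Python stand-in for the compiled LMState base class,
    # so the sketch runs without the bindings installed.
    def __init__(self) -> None:
        pass


class BigramState(LMStateStub):
    # Like some_helpful_var above, the subclass carries extra data:
    # the last token seen (None before the first token).
    def __init__(self, last_token: Optional[int]) -> None:
        super().__init__()
        self.last_token = last_token


class ToyBigramLM:
    # Hypothetical LM; start/score/finish are modeled on (not taken from)
    # the decoder's LM interface. Scores are log-probabilities.
    def __init__(self, logprobs: Dict[Tuple[Optional[int], int], float],
                 eos: int, unk_logprob: float = -10.0) -> None:
        self.logprobs = logprobs
        self.eos = eos
        self.unk_logprob = unk_logprob
        self._states: Dict[Optional[int], BigramState] = {}

    def _get_state(self, last_token: Optional[int]) -> BigramState:
        # Cache one object per history so equal histories share a state.
        if last_token not in self._states:
            self._states[last_token] = BigramState(last_token)
        return self._states[last_token]

    def start(self) -> BigramState:
        return self._get_state(None)

    def score(self, state: BigramState, token: int) -> Tuple[BigramState, float]:
        # Unseen bigrams fall back to a fixed penalty.
        lp = self.logprobs.get((state.last_token, token), self.unk_logprob)
        return self._get_state(token), lp

    def finish(self, state: BigramState) -> Tuple[BigramState, float]:
        # Score the end-of-sentence token after the current history.
        return self.score(state, self.eos)


# Score the token sequence [0, 1] followed by end-of-sentence (token 2).
lm = ToyBigramLM({(None, 0): -1.0, (0, 1): -0.5, (1, 2): -0.25}, eos=2)
state = lm.start()
state, s0 = lm.score(state, 0)
state, s1 = lm.score(state, 1)
state, s_end = lm.finish(state)
total = s0 + s1 + s_end  # -1.75
```

The per-history cache is the main design choice: without it, two hypotheses with the same history would hold different state objects and, under identity comparison, could not be recognized as equivalent.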