Python bindings
Tatiana Likhomanenko edited this page Dec 12, 2019 · 9 revisions
ASG loss is a PyTorch module (nn.Module) that supports CPU and CUDA backends. It can be defined as:
asg_loss = ASGLoss(ntokens, scale_mode).to(device)
where ntokens is the number of tokens predicted for each frame (the number of classes) and scale_mode selects how the loss is scaled. It can be one of:
NONE = 0, # no scaling
INPUT_SZ = 1, # scale to the input size
INPUT_SZ_SQRT = 2, # scale to the sqrt of the input size
TARGET_SZ = 3, # scale to the target size
TARGET_SZ_SQRT = 4, # scale to the sqrt of the target size
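To make the scale modes concrete, here is a minimal pure-Python sketch of how each mode could translate into a multiplier on the loss. It is an illustration rather than the library's code: `ScaleMode` and `scale_factor` are hypothetical names, and the sketch assumes that "scale to the size" means dividing the loss by the size (or by its square root) and that empty sequences are left unscaled.

```python
import math
from enum import IntEnum


class ScaleMode(IntEnum):
    """Hypothetical mirror of the scale modes listed above."""
    NONE = 0
    INPUT_SZ = 1
    INPUT_SZ_SQRT = 2
    TARGET_SZ = 3
    TARGET_SZ_SQRT = 4


def scale_factor(mode, input_size, target_size):
    """Return the multiplier applied to the loss.

    Assumption: scaling "to the size" means multiplying by 1/size,
    and "to the sqrt of the size" means multiplying by 1/sqrt(size).
    """
    if mode == ScaleMode.NONE:
        return 1.0
    uses_input = mode in (ScaleMode.INPUT_SZ, ScaleMode.INPUT_SZ_SQRT)
    size = input_size if uses_input else target_size
    if size <= 0:
        # Assumption: leave the loss unscaled for empty sequences.
        return 1.0
    if mode in (ScaleMode.INPUT_SZ, ScaleMode.TARGET_SZ):
        return 1.0 / size
    return 1.0 / math.sqrt(size)


# 100 input frames, 25 target tokens
for mode in ScaleMode:
    print(mode.name, scale_factor(mode, input_size=100, target_size=25))
```

Under these assumptions, input-based scaling keeps the loss comparable across utterances of different durations, while target-based scaling normalizes by transcript length instead.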
Example of how to define your own language model state:
class LMStateNew(LMState):
    some_helpful_var = 1

    def __init__(self, some_helpful_var):
        super().__init__()
        self.some_helpful_var = some_helpful_var
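As a rough illustration of why a custom state is useful, the sketch below carries an extra piece of data (a depth counter) from one state to the next. The import path `wav2letter.decoder` is an assumption, and the fallback `LMState` is only a stand-in so the example runs without the bindings installed. `CountingLMState` and `advance` are hypothetical names, not part of the library.

```python
try:
    # Assumed import path for the bindings; adjust to your installation.
    from wav2letter.decoder import LMState
except ImportError:
    class LMState:
        """Stand-in base class so the sketch runs without the bindings."""


class CountingLMState(LMState):
    """Hypothetical state that remembers how many tokens led to it."""

    def __init__(self, depth=0):
        super().__init__()
        self.depth = depth

    def advance(self):
        # Hypothetical helper: build the successor state one token deeper.
        return CountingLMState(self.depth + 1)


start = CountingLMState()
after_two = start.advance().advance()
print(after_two.depth)  # 2
```

Any per-state information a custom language model needs, such as a hidden-state index or a cached score, could be stored the same way.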