def forward(enc_inputs, dec_inputs):
    """Forward pass: run the model with teacher forcing and return the loss.

    The decoder is fed ``dec_inputs[:, :-1]`` and is scored against the same
    sequence shifted left by one position, ``dec_inputs[:, 1:]``.

    Relies on the module-level names ``model``, ``loss_fn``, ``src_pad_idx``
    and ``trg_pad_idx``, which are defined elsewhere in this file.

    Args:
        enc_inputs: source token ids, [batch_size, src_len].
        dec_inputs: target token ids, [batch_size, trg_len].

    Returns:
        The value of ``loss_fn(logits, targets)`` on the flattened logits and
        targets (presumably a scalar loss tensor — confirm against loss_fn).
    """
    # Only the logits are used; the model's other three outputs are discarded.
    logits, _, _, _ = model(enc_inputs, dec_inputs[:, :-1], src_pad_idx, trg_pad_idx)
    # Collapse to [N, vocab] so each row lines up with one flattened target.
    # This is a no-op if the model already returns 2-D logits.
    # NOTE(review): assumes the class/vocab dimension is last — confirm against model.
    logits = logits.reshape(-1, logits.size(-1))
    # dec_inputs[:, 1:] is a non-contiguous column slice when batch_size > 1,
    # so .view(-1) raises; .reshape(-1) copies only when it has to.
    targets = dec_inputs[:, 1:].reshape(-1)
    return loss_fn(logits, targets)
def forward(enc_inputs, dec_inputs):
"""前向网络
enc_inputs: [batch_size, src_len]
dec_inputs: [batch_size, trg_len]
"""
logits, _, _, _ = model(enc_inputs, dec_inputs[:, :-1], src_pad_idx, trg_pad_idx)