We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2b162e8 commit 617e1a3Copy full SHA for 617e1a3
algoperf/workloads/lm/lm_pytorch/workload.py
@@ -148,6 +148,7 @@ def _eval_batch(self,
148
if targets.dim() == 3: # one-hot
149
loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1))
150
else: # token IDs
151
+ # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch.
152
loss = torch.nn.functional.cross_entropy(
153
logits.view(-1, logits.size(-1)),
154
targets.view(-1),
0 commit comments