Skip to content

Commit 617e1a3

Browse files
committed
add todo for pytorch _eval_batch cleanup
1 parent 2b162e8 commit 617e1a3

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def _eval_batch(self,
148148
if targets.dim() == 3: # one-hot
149149
loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1))
150150
else: # token IDs
151+
# TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch.
151152
loss = torch.nn.functional.cross_entropy(
152153
logits.view(-1, logits.size(-1)),
153154
targets.view(-1),

0 commit comments

Comments
 (0)