Skip to content

Commit f531b35

Browse files
committed
Merge branch 'lm_workload_priya' of github.com:mlcommons/algorithmic-efficiency into lm_workload_priya
2 parents bbc114f + 3b31ad5 commit f531b35

File tree

3 files changed

+71
-22
lines changed

3 files changed

+71
-22
lines changed

algoperf/checkpoint_utils.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import os
8-
from typing import Sequence, Tuple
8+
from typing import Sequence, Tuple, Optional
99

1010
import numpy as np
1111
import torch
@@ -14,7 +14,8 @@
1414
from flax.training import checkpoints as flax_checkpoints
1515
from flax.training.checkpoints import latest_checkpoint
1616
from tensorflow.io import gfile # pytype: disable=import-error
17-
17+
import orbax.checkpoint as ocp
18+
from orbax.checkpoint.type_handlers import NumpyHandler
1819
from algoperf import spec
1920
from algoperf.pytorch_utils import pytorch_setup
2021

@@ -29,6 +30,48 @@
2930
int,
3031
]
3132

33+
class BoolHandler(NumpyHandler):
34+
"""
35+
An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler.
36+
It works by treating the scalar as a 0-dimensional array.
37+
"""
38+
39+
def typestr(self) -> str:
40+
"""Unique string identifier for this handler."""
41+
return 'np.bool_'
42+
43+
async def serialize(
44+
self,
45+
values: Sequence[np.bool_],
46+
infos: Sequence,
47+
args: Optional[Sequence[ocp.SaveArgs]] = None,
48+
):
49+
"""
50+
Serializes a sequence of np.bool_ scalars by first converting them
51+
to 0-dim numpy arrays and then calling the parent NumpyHandler.
52+
"""
53+
# Convert each scalar np.bool_ to a 0-dimensional np.ndarray
54+
array_values = [np.asarray(v, dtype=np.bool_) for v in values]
55+
# Use the parent class's robust serialization logic
56+
return await super().serialize(array_values, infos, args)
57+
58+
async def deserialize(
59+
self,
60+
infos: Sequence,
61+
args: Optional[Sequence[ocp.RestoreArgs]] = None,
62+
) -> Sequence[np.bool_]:
63+
"""
64+
Deserializes into a sequence of np.bool_ scalars by calling the
65+
parent handler and then converting the resulting 0-dim arrays.
66+
"""
67+
# Parent deserialize will return a sequence of 0-dimensional np.ndarray
68+
results = await super().deserialize(infos, args)
69+
70+
# Convert each 0-d array back to an np.bool_ scalar using .item()
71+
scalar_results = [np.bool_(r.item()) for r in results]
72+
return scalar_results
73+
74+
ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True)
3275

3376
def maybe_restore_checkpoint(
3477
framework: str,

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,19 @@ def model_fn(
8484

8585

8686
def compute_weighted_cross_entropy(
87-
self,
88-
logits: spec.Tensor,
89-
targets: spec.Tensor,
90-
weights: Optional[spec.Tensor] = None,
91-
label_smoothing: float = 0.1,
92-
) -> Dict[str, spec.Tensor]: # differentiable
87+
self,
88+
logits: spec.Tensor,
89+
targets: spec.Tensor,
90+
weights: Optional[spec.Tensor] = None,
91+
label_smoothing: float = 0.1,
92+
) -> Dict[str, spec.Tensor]: # differentiable
9393
"""Compute weighted cross entropy and entropy for log probs and targets.
94-
9594
Args:
9695
logits: [batch, length, num_classes] float array.
9796
targets: categorical targets [batch, length] int array.
9897
weights: array of shape [batch, length].
9998
label_smoothing: label smoothing constant, used to determine the on and off
10099
values.
101-
102100
Returns:
103101
{'summed': scalar summed loss, 'n_valid_examples': scalar number of
104102
valid examples in batch, 'per_example': 1-d array of per-example losses}
@@ -108,18 +106,26 @@ def compute_weighted_cross_entropy(
108106
f'Incorrect shapes. Got shape {logits.shape} logits and '
109107
f'{targets.shape} targets.'
110108
)
111-
smoothed_targets = optax.smooth_labels(
112-
common_utils.onehot(targets, self._vocab_size), label_smoothing
113-
)
114-
115-
per_example_losses = -jnp.sum(
116-
smoothed_targets * jax.nn.log_softmax(logits), axis=-1
117-
)
118-
if weights is None:
119-
weights = jnp.ones_like(targets)
120-
per_example_losses = jnp.where(weights, per_example_losses, 0.0)
109+
# Compute log probabilities
110+
log_probs = jax.nn.log_softmax(logits, axis=-1)
111+
# Extract log probability of the target class
112+
# Shape: [batch, length]
113+
target_log_probs = jnp.take_along_axis(
114+
log_probs,
115+
targets[..., None],
116+
axis=-1
117+
).squeeze(-1)
118+
# Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p)
119+
# The above formula is easy to derive from the definition of label smoothing and cross-entropy loss.
120+
confidence = 1.0 - label_smoothing
121+
smoothing_term = label_smoothing / self._vocab_size
122+
per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1))
123+
if weights is not None:
124+
per_example_losses = jnp.where(weights, per_example_losses, 0.0)
125+
n_valid_examples = weights.sum()
126+
else:
127+
n_valid_examples = targets.shape[0] * targets.shape[1]
121128
summed_loss = per_example_losses.sum()
122-
n_valid_examples = weights.sum()
123129
return {
124130
'summed': summed_loss,
125131
'n_valid_examples': n_valid_examples,

algoperf/workloads/lm/lm_pytorch/plainlm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ModelConfig:
1616
n_layers: int
1717
n_heads: int
1818
rmsnorm_eps: float = 1e-6
19-
tie_embeddings: bool = False
19+
tie_embeddings: bool = True
2020

2121

2222
class MLP(nn.Module):

0 commit comments

Comments
 (0)