Changes from all commits (63 commits)
1d81455
Merge pull request #847 from mlcommons/dev
priyakasimbeg Feb 27, 2025
da5f85a
first LM commit
Niccolo-Ajroldi Mar 11, 2025
a12a364
lm data pipeline
Niccolo-Ajroldi Mar 12, 2025
ca83ab8
testing
Niccolo-Ajroldi Mar 14, 2025
e3e78dc
LM workload tested torch pipeline
Niccolo-Ajroldi Mar 17, 2025
e619495
LM workload - fix torch tests
Niccolo-Ajroldi Mar 17, 2025
d8e9c56
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
6b4ff12
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
3c5c847
Stop tracking .gitignore
Niccolo-Ajroldi Mar 18, 2025
20d841b
Remove dev/ from repo, keep locally
Niccolo-Ajroldi Mar 18, 2025
f3ba059
fix comments
Niccolo-Ajroldi Mar 18, 2025
381451f
add class specifications
Niccolo-Ajroldi Mar 18, 2025
f111d2e
add workload LM info
Niccolo-Ajroldi Mar 18, 2025
808d398
restore data_utils.py tree map
Niccolo-Ajroldi Mar 18, 2025
35f8f89
fixed NFS bug
Niccolo-Ajroldi Mar 18, 2025
cbb6ee6
train/val split before concat
Niccolo-Ajroldi Mar 18, 2025
868987c
renamed datasets to avoid conflict with HF
Niccolo-Ajroldi Mar 19, 2025
8191f6d
Merge remote-tracking branch 'upstream/lm_workload' into lm_workload
Niccolo-Ajroldi Mar 19, 2025
dd59ded
renamed datasets to dataset
Niccolo-Ajroldi Mar 19, 2025
496b9c3
fix style
Niccolo-Ajroldi Mar 20, 2025
50989eb
fix formatting
Niccolo-Ajroldi Mar 20, 2025
5af0fdc
fix style
Niccolo-Ajroldi Mar 20, 2025
2683099
fix style
Niccolo-Ajroldi Mar 20, 2025
6b7ee29
fix yapf
Niccolo-Ajroldi Mar 20, 2025
46b645b
fix style
Niccolo-Ajroldi Mar 20, 2025
b3ae647
HF datasets pipeline
rka97 Mar 27, 2025
f095d4b
Testing with linear model
rka97 Mar 27, 2025
4189ae0
Merge branch 'jit_switch' into lm_workload
rka97 Mar 27, 2025
0c22f3d
lm workload with linear model
rka97 Apr 3, 2025
99c7b9b
add nanodo model
rka97 Apr 3, 2025
706d9f7
torch model
rka97 Apr 3, 2025
c335e34
lm workload dataset integration in jax
rka97 May 29, 2025
2d54365
lm workload dataset integration in jax
rka97 May 29, 2025
af8cce4
set package versions for transformers and datasets
priyakasimbeg Jun 5, 2025
d68c54e
use train_test_split method to shuffle and split fineweb-edu dataset
priyakasimbeg Jun 5, 2025
9737367
modifications to fwedu datasetup
priyakasimbeg Jun 9, 2025
1bf0750
rename fwedu data dir
priyakasimbeg Jun 9, 2025
a333391
fix
priyakasimbeg Jun 9, 2025
05dc4dd
add back batch mapping in tokenization for fwedu
priyakasimbeg Jun 9, 2025
b374cf8
debugging
priyakasimbeg Jun 10, 2025
c0c1e3c
debugging
priyakasimbeg Jun 10, 2025
f76dc39
debugging
priyakasimbeg Jun 10, 2025
e805fa7
use tfds to shuffle and split dataset
priyakasimbeg Jun 10, 2025
362cbda
Merge remote-tracking branch 'origin/dev' into lm_workload
rka97 Sep 11, 2025
c9e9abc
add command for fineweb-edu
priyakasimbeg Oct 2, 2025
e4323de
fix
priyakasimbeg Oct 2, 2025
f0c6e75
update calls to sharing utils
priyakasimbeg Oct 3, 2025
f4ffbe7
Fix torch sharding issue, update input pipeline and workload classes …
rka97 Oct 6, 2025
5c85c7e
test working, lm workload training not working (debugging)
rka97 Oct 6, 2025
a59dfda
updates to input_pipeline and model spec
priyakasimbeg Oct 6, 2025
1c3cb66
add defaults for lm workload
priyakasimbeg Oct 6, 2025
af91b12
refactor eval pipeline and loss fn for lm
priyakasimbeg Oct 7, 2025
6b55adf
refactor evaluation pipeline for lm
priyakasimbeg Oct 7, 2025
210d671
remove temporary flag for hlo dumps
priyakasimbeg Oct 7, 2025
0ad7788
fix in workload target condition check
priyakasimbeg Oct 7, 2025
01921d5
fix in mlp for glu
priyakasimbeg Oct 8, 2025
e420450
Fix OOM error in weighted cross entropy calculation
rka97 Oct 10, 2025
3b31ad5
fix issue with checkpointing bool
rka97 Oct 10, 2025
bbc114f
increase buffer size
priyakasimbeg Oct 10, 2025
f531b35
Merge branch 'lm_workload_priya' of github.com:mlcommons/algorithmic-…
priyakasimbeg Oct 10, 2025
2b162e8
remove _eval_batch from jax workload
priyakasimbeg Oct 10, 2025
617e1a3
add todo for pytorch _eval_batch cleanup
priyakasimbeg Oct 10, 2025
bebc80a
Merge pull request #891 from mlcommons/lm_workload_priya
rka97 Oct 15, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,4 +25,4 @@ scoring/plots/
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

algoperf/_version.py
algoperf/_version.py
47 changes: 45 additions & 2 deletions algoperf/checkpoint_utils.py
@@ -5,7 +5,7 @@
"""

import os
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Optional

import numpy as np
import torch
@@ -14,7 +14,8 @@
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
from tensorflow.io import gfile # pytype: disable=import-error

import orbax.checkpoint as ocp
from orbax.checkpoint.type_handlers import NumpyHandler
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

@@ -29,6 +30,48 @@
int,
]

class BoolHandler(NumpyHandler):
"""
An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler.
It works by treating the scalar as a 0-dimensional array.
"""

def typestr(self) -> str:
"""Unique string identifier for this handler."""
return 'np.bool_'

async def serialize(
self,
values: Sequence[np.bool_],
infos: Sequence,
args: Optional[Sequence[ocp.SaveArgs]] = None,
):
"""
Serializes a sequence of np.bool_ scalars by first converting them
to 0-dim numpy arrays and then calling the parent NumpyHandler.
"""
# Convert each scalar np.bool_ to a 0-dimensional np.ndarray
array_values = [np.asarray(v, dtype=np.bool_) for v in values]
# Use the parent class's robust serialization logic
return await super().serialize(array_values, infos, args)

async def deserialize(
self,
infos: Sequence,
args: Optional[Sequence[ocp.RestoreArgs]] = None,
) -> Sequence[np.bool_]:
"""
Deserializes into a sequence of np.bool_ scalars by calling the
parent handler and then converting the resulting 0-dim arrays.
"""
# Parent deserialize will return a sequence of 0-dimensional np.ndarray
results = await super().deserialize(infos, args)

# Convert each 0-d array back to an np.bool_ scalar using .item()
scalar_results = [np.bool_(r.item()) for r in results]
return scalar_results

ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True)

def maybe_restore_checkpoint(
framework: str,
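Note: the BoolHandler above is registered globally at import time, so any np.bool_ leaf in a checkpointed PyTree round-trips through Orbax. A minimal sketch of that round trip, assuming a scratch directory and a toy state dict rather than the repo's actual checkpointing call path:

import numpy as np
import orbax.checkpoint as ocp

from algoperf import checkpoint_utils  # noqa: F401 -- importing registers BoolHandler

# Toy state for illustration; the key names here are made up.
state = {'global_step': np.int64(42), 'training_complete': np.bool_(True)}

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/algoperf_bool_ckpt', state)
restored = checkpointer.restore('/tmp/algoperf_bool_ckpt')
assert bool(restored['training_complete']) is True  # np.bool_ survives the round trip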
2 changes: 2 additions & 0 deletions algoperf/param_utils.py
@@ -44,6 +44,8 @@ def pytorch_param_types(
param_types[name] = spec.ParameterType.ATTENTION_BIAS
elif 'in_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'qkv' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'kv_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_KV
elif 'k_proj' in name or 'key' in name:
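For context, pytorch_param_types assigns a spec.ParameterType to each parameter by substring-matching its name; the new 'qkv' branch lets fused query/key/value projection weights (e.g. a parameter named '...attn.qkv.weight') be classified as ATTENTION_QKV. A condensed sketch of that matching logic, not the full function:

from algoperf import spec

def classify_attention_param(name: str):
  """Toy version of the substring matching in pytorch_param_types."""
  if 'in_proj' in name or 'qkv' in name:  # fused query/key/value projection
    return spec.ParameterType.ATTENTION_QKV
  if 'kv_proj' in name:
    return spec.ParameterType.ATTENTION_KV
  return None  # the real function continues with further branches

assert classify_attention_param('blocks.0.attn.qkv.weight') == spec.ParameterType.ATTENTION_QKV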
Empty file.
138 changes: 138 additions & 0 deletions algoperf/workloads/lm/input_pipeline.py
@@ -0,0 +1,138 @@
"""Input pipeline for a LM dataset."""

import functools
import os
from typing import Optional

import jax
import numpy as np
import tensorflow as tf

from algoperf import data_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE
PAD_ID = tf.constant(-1, dtype=tf.int64)

TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'}

SEQUENCE_LENGTH = 1024
MAX_CORPUS_CHARS = 1_000_000_000
SHUFFLE_BUFFER_SIZE = 100_000
VOCAB_SIZE = 50_257


def batch_with_padding(
dataset: tf.data.Dataset,
batch_size,
padded_shapes=None,
padding_id=PAD_ID,
):
"""Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size.

Args:
dataset: tf.data.Dataset
batch_size: batch size of resulting batched dataset
padded_shapes: shapes of the padded batches
padding_id: value for padding, for elements in new batch

Returns:
  tf.data.Dataset with the final partial batch padded up to batch_size.
"""
batched_dataset = dataset.batch(batch_size, drop_remainder=False)

# tf.data.Dataset.padded.batch pads elements in the batch so we call it
# again with batch_size=1 to pad each element in original batch.
padded_batched_dataset = batched_dataset.padded_batch(
1, padded_shapes=padded_shapes, padding_values=padding_id
)

# Remove extra dimension resulting from the batch_size=1.
padded_batched_dataset = padded_batched_dataset.unbatch()

return padded_batched_dataset


def get_data_iter(data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,):

ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches)

it = map(
functools.partial(
data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
),
ds,
)

return iter(it)

def get_lm_dataset(
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,
):
"""Load preprocessed TF dataset."""
if split not in TFDS_SPLIT_NAME:
raise NotImplementedError

shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1)

data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split])
tokens_ds = tf.data.Dataset.load(data_dir)

# tokens
tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices)

# sequences
sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)

# get inputs and outputs
sequences_ds = sequences_ds.map(
lambda x: {
'inputs': x['input_ids'][:SEQUENCE_LENGTH],
'targets': x['input_ids'][1:],
},
num_parallel_calls=AUTOTUNE,
)

# batch
if split == 'train':
shuffled_sequences_ds = sequences_ds.shuffle(
SHUFFLE_BUFFER_SIZE, seed=shuffle_seed
)
repeated_sequences_dataset = shuffled_sequences_ds.repeat()
ds = repeated_sequences_dataset.batch(
global_batch_size, drop_remainder=False
).prefetch(tf.data.experimental.AUTOTUNE)
elif split == 'eval_train':
ds = batch_with_padding(
sequences_ds,
global_batch_size,
padded_shapes={
'inputs': (global_batch_size, None),
'targets': (global_batch_size, None),
},
)
ds = ds.map(lambda x: {'inputs': x['inputs'],
'targets': x['targets'],
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation
elif split == 'validation':
ds = batch_with_padding(
sequences_ds,
global_batch_size,
padded_shapes={
'inputs': (global_batch_size, None),
'targets': (global_batch_size, None),
},
)
ds = ds.map(lambda x: {'inputs': x['inputs'],
'targets': x['targets'],
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size

return ds
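A usage sketch for the pipeline above: get_data_iter wraps get_lm_dataset with data_utils.shard_and_maybe_pad_np, so it yields numpy batches already reshaped for the local devices. The data directory below is a placeholder for wherever the preprocessed FineWeb-Edu dataset was saved, and the exact leading dimensions depend on the number of visible accelerators:

import jax

from algoperf.workloads.lm import input_pipeline

rng = jax.random.PRNGKey(0)
train_iter = input_pipeline.get_data_iter(
    data_rng=rng,
    split='train',
    data_dir='/data/finewebedu_tokenized',  # placeholder path
    global_batch_size=8,
)

batch = next(train_iter)
# Each entry is roughly shaped
# (num_local_devices, global_batch_size // num_local_devices, SEQUENCE_LENGTH).
print(batch['inputs'].shape, batch['targets'].shape)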
Empty file.
19 changes: 19 additions & 0 deletions algoperf/workloads/lm/lm_jax/models.py
@@ -0,0 +1,19 @@
from flax import linen as nn
import jax.numpy as jnp

class LinearModel(nn.Module):
vocab_size: int

@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(
10,
kernel_init=nn.initializers.normal(0.02),
bias_init=nn.initializers.zeros
)(inputs)
return nn.Dense(
self.vocab_size,
kernel_init=nn.initializers.normal(0.02),
bias_init=nn.initializers.zeros,
name="output"
)(x)
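A quick sanity-check sketch for the Flax model above. LinearModel stacks two Dense layers on the trailing feature dimension, so this sketch assumes the inputs are already one-hot (float) encoded over the vocabulary; the shapes are illustrative only:

import jax
import jax.numpy as jnp

from algoperf.workloads.lm.lm_jax.models import LinearModel

vocab_size, seq_len = 50_257, 8  # tiny seq_len, just for the shape check
model = LinearModel(vocab_size=vocab_size)

rng = jax.random.PRNGKey(0)
dummy_inputs = jnp.zeros((2, seq_len, vocab_size))  # assumed one-hot style inputs
params = model.init(rng, dummy_inputs)

logits = model.apply(params, dummy_inputs)
print(logits.shape)  # (2, 8, 50257)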