From da5f85a7c878f0399c7b8a5d2fcfb9d729e567ea Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 11 Mar 2025 15:46:49 +0100 Subject: [PATCH 01/63] first LM commit --- algoperf/workloads/lm/__init__.py | 0 algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++++++ algoperf/workloads/lm/input_pipeline.py | 82 ++++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/__init__.py | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 36 +++++++++ algoperf/workloads/lm/test_01.py | 22 ++++++ algoperf/workloads/lm/test_input_pipeline.py | 68 ++++++++++++++++ algoperf/workloads/lm/workload.py | 66 ++++++++++++++++ 8 files changed, 316 insertions(+) create mode 100644 algoperf/workloads/lm/__init__.py create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/input_pipeline.py create mode 100644 algoperf/workloads/lm/lm_pytorch/__init__.py create mode 100644 algoperf/workloads/lm/lm_pytorch/workload.py create mode 100644 algoperf/workloads/lm/test_01.py create mode 100644 algoperf/workloads/lm/test_input_pipeline.py create mode 100644 algoperf/workloads/lm/workload.py diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + +trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..7424dd6d5 --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,82 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os + +from datasets import Dataset, load_from_disk +from typing import Dict, List, Optional, Union + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). 
+AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_lm_dataset(data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + } + + # Create a TensorFlow dataset from the generator function + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), + } + ) + + # Avoid creating too many threads when using PyTorch DDP. + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, ensuring the last batch is dropped if not full during training + ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds \ No newline at end of file diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..904657b1d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,36 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from typing import Any, Dict, Optional, Tuple + +from absl import logging +import jax +import tensorflow as tf +import torch +import torch.distributed as dist +from torch.nn import DataParallel as DP +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn(): + pass + + def model_fn(): + pass + + def _build_input_queue(): + pass + + def eval_step(): + pass diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py new file mode 100644 index 000000000..e33ddf3e7 --- /dev/null +++ b/algoperf/workloads/lm/test_01.py @@ -0,0 +1,22 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = 
"/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + +tf_seed = SEED + +# Load the dataset +ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, +) diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/test_input_pipeline.py new file mode 100644 index 000000000..47c11969f --- /dev/null +++ b/algoperf/workloads/lm/test_input_pipeline.py @@ -0,0 +1,68 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + + +def test_tf_dataset(): + """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" + + print(f"Loading dataset from: {DATASET_PATH}") + + tf_seed = SEED + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + + print("Testing TensorFlow Dataset Output...") + for batch in ds.take(2): # Take two batches to test + print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection + print("Targets:", batch["targets"].numpy()) + +def test_pytorch_dataloader(): + """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" + + # Use the same TensorFlow-compatible seed + tf_seed = tf.constant(SEED, dtype=tf.int64) + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, + global_batch_size=BATCH_SIZE, + ) + + def _input_queue_generator(): + """Generator that converts TF dataset batches to PyTorch tensors.""" + for batch in iter(ds): + batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors + yield batch + + dataloader = _input_queue_generator() + + print("\nTesting PyTorch DataLoader Output...") + for _ in range(2): # Take two batches + batch = next(dataloader) + print("Inputs:", batch["inputs"]) + print("Targets:", batch["targets"]) + +# Run tests +if __name__ == "__main__": + test_tf_dataset() + test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..d070cabec --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,66 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Optional, Tuple + +import jax +import numpy as np +import torch + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """A LM workload.""" + + _vocab_size: int = 32000 + + def __init__(self) -> None: + super().__init__() + self._tokenizer = None + + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + is_training = split == 'train' + ds, self._tokenizer = input_pipeline.get_lm_dataset( 
+ data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + + for batch in iter(ds): + yield batch + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the loss function at (label_batch, logits_batch).""" + pass \ No newline at end of file From a12a36404ce907c8e50e67c8e4a5eb25baa9a2f3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 12 Mar 2025 15:49:04 +0100 Subject: [PATCH 02/63] lm data pipeline --- algoperf/workloads/lm/input_pipeline.py | 11 +-- algoperf/workloads/lm/test_01.py | 96 +++++++++++++++++++++---- datasets/dataset_setup.py | 96 +++++++++++++++++++++++++ datasets/lm_preprocess.py | 0 4 files changed, 185 insertions(+), 18 deletions(-) create mode 100644 datasets/lm_preprocess.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7424dd6d5..a14cebeda 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from datasets import Dataset, load_from_disk from typing import Dict, List, Optional, Union +import jax import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -17,7 +18,7 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None -def get_lm_dataset(data_rng, +def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, is_training: bool, @@ -37,11 +38,12 @@ def get_lm_dataset(data_rng, def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: + input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), } - + # Create a TensorFlow dataset from the generator function ds = tf.data.Dataset.from_generator( tf_generator, @@ -58,7 +60,6 @@ def tf_generator(): ds = ds.with_options(options) if shuffle: - print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) if is_training: diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py index e33ddf3e7..977fae11a 100644 --- a/algoperf/workloads/lm/test_01.py +++ b/algoperf/workloads/lm/test_01.py @@ -1,22 +1,92 @@ + import os +import numpy as np import tensorflow as tf import torch + from datasets import load_from_disk +from absl import app +from absl import flags +from absl import logging + +from algoperf.profiler import PassThroughProfiler +from algoperf import random_utils as prng +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +tf.config.set_visible_devices([], 'GPU') + +# 
Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +# (nico) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +flags.DEFINE_enum( + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.') + +FLAGS = flags.FLAGS +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - -tf_seed = SEED - -# Load the dataset -ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, -) +RNG_SEED = 1996 # Fixed random seed for reproducibility + + +def main(_): + profiler = PassThroughProfiler() + if FLAGS.framework == 'pytorch': + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + + rng = prng.PRNGKey(RNG_SEED) + data_rng, _, _, _ = prng.split(rng, 4) + + print(f"data_rng = {data_rng}") + + # Load the dataset + ds = get_lm_dataset( + data_rng=data_rng, + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + # Check if `ds` acts as a generator + if hasattr(ds, '__iter__'): + print("Dataset is an iterable/generator.") + + # Fetch first batch + try: + first_batch = next(iter(ds)) + print(f"Successfully retrieved first batch.") + except Exception as e: + print(f"Error retrieving first batch: {e}") + return + + # Print structure of a batch + print(f"First batch keys: {first_batch.keys()}") + print(f"First batch shapes:") + for key, value in first_batch.items(): + print(f" - {key}: {value.shape} (dtype: {value.dtype})") + + # Validate batch dimensions + assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" + assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" + assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
+ + print(f"Dataset is correctly batched and structured.") + print(f"Test completed successfully.") + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + app.run(main) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..14dd24545 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,13 +76,21 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer +from datasets import lm_preprocess +import datasets as hf_datasets +# from datasets import load_dataset, Dataset +from transformers import AutoTokenizer + +import math import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -126,6 +134,9 @@ flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -699,6 +710,86 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) +def download_finewebedu(data_dir, tmp_dir): + """Download FineWebEdu-10B.""" + + # data_dir = "/fast/najroldi/data" + + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") + data_dir = os.path.join(data_dir, 'finewebedu') + + _maybe_mkdir(tmp_dir) + _maybe_mkdir(data_dir) + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + # cache_dir=tmp_dir + ) + + ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + + seq_len = 2048 + max_seq_length = seq_len+1 + map_setup = dict(batched=True, batch_size=1024, num_proc=8) + + # Tokenize + tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + return tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False + ) + + tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + tokenized_dataset = ds.map( + tokenize, + remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', + 'language_score', 'token_count', 'score', 'int_score'], + **map_setup + ) + tokenizer.model_max_length = seq_len + + # Concat in chunks of max_seq_len + def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Concatenate text and generate chunks of max_seq_length""" + concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + if total_length >= max_seq_length: + total_length = (total_length // max_seq_length) * max_seq_length + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + lm_dataset = tokenized_dataset.map( + concat_chunck, + **map_setup + ) + + n_tokens = len(lm_dataset) * max_seq_length + logging.info(f"Number of tokens in dataset: {n_tokens:_}") + + # Split dataset into training and 
validation sets + # TODO: avoid (single doc) contamination between train and val + VAL_TOKENS = 10_000_000 + val_samples = VAL_TOKENS // max_seq_length + 1 + val_dataset = lm_dataset.select(range(val_samples)) + train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + + # Save datasets + train_dataset.save_to_disk(os.path.join(data_dir, f"train")) + val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -781,6 +872,11 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + if not FLAGS.skip_download: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/lm_preprocess.py b/datasets/lm_preprocess.py new file mode 100644 index 000000000..e69de29bb From ca83ab8954a9e164dc538cb4749847812ee0e032 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 14 Mar 2025 11:31:08 +0100 Subject: [PATCH 03/63] testing --- algoperf/workloads/lm/{ => dev}/test_01.py | 0 .../lm/{ => dev}/test_input_pipeline.py | 0 algoperf/workloads/lm/input_pipeline.py | 37 +++++---- .../workloads/lm/lm_jax/__init__.py | 0 algoperf/workloads/lm/lm_jax/workload.py | 20 +++++ algoperf/workloads/lm/lm_pytorch/workload.py | 56 ++++++++++++- algoperf/workloads/lm/test.py | 37 +++++++++ algoperf/workloads/lm/workload.py | 80 ++++++++++++++----- datasets/dataset_setup.py | 25 ++++-- 9 files changed, 211 insertions(+), 44 deletions(-) rename algoperf/workloads/lm/{ => dev}/test_01.py (100%) rename algoperf/workloads/lm/{ => dev}/test_input_pipeline.py (100%) rename datasets/lm_preprocess.py => algoperf/workloads/lm/lm_jax/__init__.py (100%) create mode 100644 algoperf/workloads/lm/lm_jax/workload.py create mode 100644 algoperf/workloads/lm/test.py diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/dev/test_01.py similarity index 100% rename from algoperf/workloads/lm/test_01.py rename to algoperf/workloads/lm/dev/test_01.py diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/test_input_pipeline.py rename to algoperf/workloads/lm/dev/test_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index a14cebeda..f0024e4a6 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -15,6 +15,10 @@ RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal +# number of elements to prefetch or parallelize for dataset operations, improving performance. 
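As a quick illustration of the setting described in the comment above (a standalone sketch, not part of the pipeline; any tf.data.Dataset works the same way):

    import tensorflow as tf
    ds = tf.data.Dataset.range(8).batch(2)
    ds = ds.prefetch(tf.data.AUTOTUNE)  # tf.data picks the prefetch buffer size at runtime

The pipeline below passes this module-level AUTOTUNE value to its own prefetch call.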
AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -30,34 +34,36 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + dataset = load_from_disk(dataset_path) is_training = split == "train" shuffle = split in ['train', 'eval_train'] + dataset.set_format("tensorflow") # tf.int64 + def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: - input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } - # Create a TensorFlow dataset from the generator function + # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + } + ) # Avoid creating too many threads when using PyTorch DDP. - if RANK != 0: + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) + options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread + ds = ds.with_options(options) # apply threading restrictions if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -66,6 +72,9 @@ def tf_generator(): ds = ds.repeat() # Batch the dataset, ensuring the last batch is dropped if not full during training + # i.e. it groups consecutive elements into fixed-size chunks. 
+ # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), + # improving efficiency and parallelism in training ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) diff --git a/datasets/lm_preprocess.py b/algoperf/workloads/lm/lm_jax/__init__.py similarity index 100% rename from datasets/lm_preprocess.py rename to algoperf/workloads/lm/lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..4cdb42409 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,20 @@ +"""LM workload implemented in Jax.""" + +import functools +from typing import Dict, Optional, Tuple + +from flax import jax_utils +import jax +import jax.numpy as jnp +import numpy as np + +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + + @property + def eval_batch_size(self) -> int: + return 131_072 diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 904657b1d..9ee21ccb6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -29,8 +29,58 @@ def init_model_fn(): def model_fn(): pass - def _build_input_queue(): - pass - + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + per_device_batch_size = int(global_batch_size / N_GPUS) + + # Only create and iterate over tf input pipeline in one Python process to + # avoid creating too many threads. + if RANK == 0: + np_iter = super()._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + while True: + if RANK == 0: + batch = next(np_iter) + inputs = torch.as_tensor( + batch['inputs'], dtype=torch.float32, device=DEVICE) + targets = torch.as_tensor( + batch['targets'], dtype=torch.float32, device=DEVICE) + # Send batch to other devices when using DDP. + if USE_PYTORCH_DDP: + dist.broadcast(inputs, src=0) + inputs = inputs[0] # TODO: check + dist.broadcast(targets, src=0) + targets = targets[0] # TODO: check + else: + batch = {} + inputs = torch.empty((N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(inputs, src=0) + inputs = inputs[RANK] + targets = torch.empty((N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(targets, src=0) + targets = targets[RANK] + + batch = { + 'inputs': inputs, + 'targets': targets, + # 'weights': weights, + } + yield batch + + def eval_step(): pass diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/test.py new file mode 100644 index 000000000..7e693d0af --- /dev/null +++ b/algoperf/workloads/lm/test.py @@ -0,0 +1,37 @@ +""" +Test data pipaline in JAX and PyTorch. + +Instantiate a workload and loops over the input queue. 
+""" + +import jax +import numpy as np +import torch + +import algoperf.workloads.lm.lm_jax.workload as lm_jax +# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch + + +data_rng = jax.random.PRNGKey(0) +split = 'train' +data_dir = "/fast/najroldi/data/finewebedu" +global_batch_size = 8 +num_batches = 10 +repeat_final_dataset = False + +# ------------------------------------------------------------------------------ +# JAX +# ------------------------------------------------------------------------------ + +# 1 GPU +workload = lm_jax.LmWorkload() + +input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + +next(input_queue) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index d070cabec..63d2c707e 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -32,7 +32,7 @@ def _build_input_queue(self, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): is_training = split == 'train' - ds, self._tokenizer = input_pipeline.get_lm_dataset( + ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, @@ -41,26 +41,66 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) - + for batch in iter(ds): yield batch - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. 
- logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the loss function at (label_batch, logits_batch).""" + def _eval_model_on_split(): + pass + + def eval_period_time_sec(): + pass + + def has_reached_test_target(): + pass + + def has_reached_validation_target(): + pass + + def init_model_fn(): + pass + + def is_output_params(): + pass + + def loss_fn(): + pass + + def loss_type(): + pass + + def max_allowed_runtime_sec(): + pass + + def model_fn(): + pass + + def num_eval_train_examples(): + pass + + def num_test_examples(): + pass + + def num_train_examples(): + pass + + def num_validation_examples(): + pass + + def step_hint(): + pass + + def test_target_value(): + pass + + def train_mean(): + pass + + def train_stddev(): + pass + + def validation_target_value(): + pass + + def target_metric_name(): pass \ No newline at end of file diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 14dd24545..aab793832 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,10 +76,8 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer -from datasets import lm_preprocess import datasets as hf_datasets -# from datasets import load_dataset, Dataset from transformers import AutoTokenizer import math @@ -721,6 +719,9 @@ def download_finewebedu(data_dir, tmp_dir): _maybe_mkdir(tmp_dir) _maybe_mkdir(data_dir) + # Use local disk instead of NFS for temp storage + os.environ["TMPDIR"] = tmp_dir + ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', @@ -745,7 +746,6 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_special_tokens_mask=False, return_attention_mask=False ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization tokenized_dataset = ds.map( tokenize, @@ -754,8 +754,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: **map_setup ) tokenizer.model_max_length = seq_len + + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + from datasets import load_from_disk + tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len + # TODO: this might take to much memory + # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 + # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples + # NOTE: the current approach leads to data loss at batch boundaries, + # but concatenation *cannot* happen if batched=False, + # because concat_chunck relies on processing multiple examples at once. + # (2) loss happening because of nproc>1: potentially losing the last tokens in each process + # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -767,13 +780,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck, + concat_chunck,\ **map_setup ) - - n_tokens = len(lm_dataset) * max_seq_length + n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets From e3e78dc6443c5485af64bfe986951f72d9754f99 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:18:41 +0100 Subject: [PATCH 04/63] LM workload tested torch pipeline --- algoperf/data_utils.py | 2 +- .../lm/dev/test_build_input_queue_torch.py | 80 +++++++++++++++++++ .../workloads/lm/{test.py => dev/test_jax.py} | 19 ++++- algoperf/workloads/lm/input_pipeline.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 5 +- algoperf/workloads/lm/lm_pytorch/workload.py | 68 +++++++++------- algoperf/workloads/lm/workload.py | 7 +- submission_runner.py | 2 +- 8 files changed, 146 insertions(+), 40 deletions(-) create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py rename algoperf/workloads/lm/{test.py => dev/test_jax.py} (63%) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..068c21c03 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree.map(_prepare, batch) + return jax.tree_util.tree_map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py new file mode 100644 index 000000000..86b1ca6b7 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -0,0 +1,80 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + +n_gpus = max(N_GPUS, jax.local_device_count()) + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + # batch = next(input_queue) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # Start test. 
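Before the loop below, a minimal illustration of the next-token layout the pipeline is expected to produce (token ids are borrowed from the example sequence quoted in the dev notes; this is a standalone sketch, not part of the test):

    import torch
    inputs  = torch.tensor([[5166, 20, 1639, 275, 253]])
    targets = torch.tensor([[  20, 1639, 275, 253, 19992]])
    assert torch.equal(inputs[:, 1:], targets[:, :-1])  # targets are inputs shifted by one token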
+ for _ in range(100): + + batch = next(input_queue) + assert type(batch) == dict + + assert 'inputs' in batch + assert 'targets' in batch + + assert type(batch['inputs']) == torch.Tensor + assert type(batch['targets']) == torch.Tensor + + assert batch['inputs'].dtype == dtype + assert batch['targets'].dtype == dtype + + assert batch['inputs'].shape == (local_batch_size, seq_len) + assert batch['targets'].shape == (local_batch_size, seq_len) + + sync_ddp() + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/dev/test_jax.py similarity index 63% rename from algoperf/workloads/lm/test.py rename to algoperf/workloads/lm/dev/test_jax.py index 7e693d0af..4ba3de631 100644 --- a/algoperf/workloads/lm/test.py +++ b/algoperf/workloads/lm/dev/test_jax.py @@ -15,6 +15,7 @@ data_rng = jax.random.PRNGKey(0) split = 'train' data_dir = "/fast/najroldi/data/finewebedu" +seq_len = 2048 global_batch_size = 8 num_batches = 10 repeat_final_dataset = False @@ -34,4 +35,20 @@ num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) -next(input_queue) +batch = next(input_queue) +assert type(batch) == dict + +assert 'inputs' in batch +assert 'targets' in batch + +assert type(batch['inputs']) == np.ndarray +assert type(batch['targets']) == np.ndarray + +assert batch['inputs'].dtype == np.int64 +assert batch['targets'].dtype == np.int64 + +assert batch['inputs'].shape == (1, global_batch_size, seq_len) +assert batch['targets'].shape == (1, global_batch_size, seq_len) + +print(f"JAX devices = {jax.devices()}") +print("1") diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index f0024e4a6..e74490a16 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -25,7 +25,6 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - is_training: bool, vocab_size: int, global_batch_size: int, num_batches: Optional[int] = None, @@ -39,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 + dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? 
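A short note on the set_format call above, with an illustrative check (the path is a placeholder): with the "tensorflow" format, indexing the dataset should return tf tensors rather than plain Python lists, which is what lets tf_generator below slice input_ids without an explicit conversion.

    from datasets import load_from_disk
    ds = load_from_disk("/path/to/finewebedu/train")   # placeholder path
    ds.set_format("tensorflow")
    seq = ds[0]["input_ids"]             # tf.Tensor, dtype int64
    inputs, targets = seq[:-1], seq[1:]  # same slicing as in tf_generator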
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 4cdb42409..773f8c54c 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -14,7 +14,4 @@ class LmWorkload(BaseLmWorkload): - - @property - def eval_batch_size(self) -> int: - return 131_072 + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9ee21ccb6..0ff7884c7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,7 +1,7 @@ """LM workload implemented in PyTorch.""" import contextlib -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Iterator, Optional, Tuple from absl import logging import jax @@ -22,12 +22,6 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - - def init_model_fn(): - pass - - def model_fn(): - pass def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -35,8 +29,12 @@ def _build_input_queue(self, data_dir: str, global_batch_size: int, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) + + seq_len = 2048 # TODO: define it somewehere else + DTYPE = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. @@ -48,36 +46,50 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) + weights = None + while True: + # Only iterate over tf input pipeline in one Python process to + # avoid creating too many threads. if RANK == 0: - batch = next(np_iter) - inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) - targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch = next(np_iter) # pylint: disable=stop-iteration-return + inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: - dist.broadcast(inputs, src=0) - inputs = inputs[0] # TODO: check - dist.broadcast(targets, src=0) - targets = targets[0] # TODO: check + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + # We don't broadcast the shard for RANK 0. + dist.broadcast(inputs[1:], src=0) + dist.broadcast(targets[1:], src=0) + + # RANK 0 extracts his shard. If not DDP, this just flattens. + inputs, targets = inputs[0], targets[0] + else: - batch = {} - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + # Receive batch from rank 0. + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + + # N_GPUS - 1 since we don't broadcast the shard for RANK 0. 
+ inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) dist.broadcast(inputs, src=0) - inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) dist.broadcast(targets, src=0) - targets = targets[RANK] - + # RANK - 1 since we don't broadcast the shard for RANK 0. + inputs, targets = inputs[RANK-1], targets[RANK-1] + + if weights is None: + weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { 'inputs': inputs, 'targets': targets, - # 'weights': weights, + 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 63d2c707e..7b1313dd7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -31,12 +31,10 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - is_training = split == 'train' ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, - is_training=is_training, vocab_size=self._vocab_size, global_batch_size=global_batch_size, num_batches=num_batches, @@ -103,4 +101,7 @@ def validation_target_value(): pass def target_metric_name(): - pass \ No newline at end of file + pass + + def eval_batch_size(): + pass diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..6fac50d99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ From e6194950fc524793906127f09b330a8329ad079f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:34:10 +0100 Subject: [PATCH 05/63] LM workload - fix torch tests --- .../lm/dev/test_build_input_queue_torch.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py index 86b1ca6b7..66205d091 100644 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -41,30 +41,33 @@ def test_dataloader_torch(): data_dir=data_dir, global_batch_size=global_batch_size) - # batch = next(input_queue) - print(f"RANK {RANK} of {N_GPUS}") sync_ddp() # Start test. 
for _ in range(100): - + batch = next(input_queue) - assert type(batch) == dict + assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch - assert type(batch['inputs']) == torch.Tensor - assert type(batch['targets']) == torch.Tensor + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype - assert batch['inputs'].dtype == dtype - assert batch['targets'].dtype == dtype + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) - assert batch['inputs'].shape == (local_batch_size, seq_len) - assert batch['targets'].shape == (local_batch_size, seq_len) - - sync_ddp() + assert torch.equal(inputs[:,1:], targets[:,:-1]) print(f"=== ALL TEST PASSED ===") From d8e9c56738de817e561e79cffee638ab7197eaed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:36 +0100 Subject: [PATCH 06/63] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 --------- algoperf/workloads/lm/dev/test_01.py | 92 ------------------- .../lm/dev/test_build_input_queue_torch.py | 83 ----------------- .../workloads/lm/dev/test_input_pipeline.py | 68 -------------- algoperf/workloads/lm/dev/test_jax.py | 54 ----------- 5 files changed, 339 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_01.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/dev/test_input_pipeline.py delete mode 100644 algoperf/workloads/lm/dev/test_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. 
- -""" - diff --git a/algoperf/workloads/lm/dev/test_01.py b/algoperf/workloads/lm/dev/test_01.py deleted file mode 100644 index 977fae11a..000000000 --- a/algoperf/workloads/lm/dev/test_01.py +++ /dev/null @@ -1,92 +0,0 @@ - -import os -import numpy as np -import tensorflow as tf -import torch - -from datasets import load_from_disk - -from absl import app -from absl import flags -from absl import logging - -from algoperf.profiler import PassThroughProfiler -from algoperf import random_utils as prng -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - - -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -# (nico) -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - -flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') - -FLAGS = flags.FLAGS -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -RNG_SEED = 1996 # Fixed random seed for reproducibility - - -def main(_): - profiler = PassThroughProfiler() - if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - - rng = prng.PRNGKey(RNG_SEED) - data_rng, _, _, _ = prng.split(rng, 4) - - print(f"data_rng = {data_rng}") - - # Load the dataset - ds = get_lm_dataset( - data_rng=data_rng, - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - # Check if `ds` acts as a generator - if hasattr(ds, '__iter__'): - print("Dataset is an iterable/generator.") - - # Fetch first batch - try: - first_batch = next(iter(ds)) - print(f"Successfully retrieved first batch.") - except Exception as e: - print(f"Error retrieving first batch: {e}") - return - - # Print structure of a batch - print(f"First batch keys: {first_batch.keys()}") - print(f"First batch shapes:") - for key, value in first_batch.items(): - print(f" - {key}: {value.shape} (dtype: {value.dtype})") - - # Validate batch dimensions - assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" - assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" - assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
- - print(f"Dataset is correctly batched and structured.") - print(f"Test completed successfully.") - -if __name__ == '__main__': - flags.mark_flag_as_required('framework') - app.run(main) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py deleted file mode 100644 index 66205d091..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ /dev/null @@ -1,83 +0,0 @@ - -import jax -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - -n_gpus = max(N_GPUS, jax.local_device_count()) - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # Start test. - for _ in range(100): - - batch = next(input_queue) - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() - diff --git a/algoperf/workloads/lm/dev/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py deleted file mode 100644 index 47c11969f..000000000 --- a/algoperf/workloads/lm/dev/test_input_pipeline.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import tensorflow as tf -import torch -from datasets import load_from_disk - -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - - -def test_tf_dataset(): - """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" - - print(f"Loading dataset from: {DATASET_PATH}") - - tf_seed = SEED - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - - print("Testing TensorFlow Dataset Output...") - for batch in ds.take(2): # Take two batches to test - print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection - print("Targets:", batch["targets"].numpy()) - -def 
test_pytorch_dataloader(): - """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" - - # Use the same TensorFlow-compatible seed - tf_seed = tf.constant(SEED, dtype=tf.int64) - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, - global_batch_size=BATCH_SIZE, - ) - - def _input_queue_generator(): - """Generator that converts TF dataset batches to PyTorch tensors.""" - for batch in iter(ds): - batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors - yield batch - - dataloader = _input_queue_generator() - - print("\nTesting PyTorch DataLoader Output...") - for _ in range(2): # Take two batches - batch = next(dataloader) - print("Inputs:", batch["inputs"]) - print("Targets:", batch["targets"]) - -# Run tests -if __name__ == "__main__": - test_tf_dataset() - test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/dev/test_jax.py b/algoperf/workloads/lm/dev/test_jax.py deleted file mode 100644 index 4ba3de631..000000000 --- a/algoperf/workloads/lm/dev/test_jax.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Test data pipaline in JAX and PyTorch. - -Instantiate a workload and loops over the input queue. -""" - -import jax -import numpy as np -import torch - -import algoperf.workloads.lm.lm_jax.workload as lm_jax -# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch - - -data_rng = jax.random.PRNGKey(0) -split = 'train' -data_dir = "/fast/najroldi/data/finewebedu" -seq_len = 2048 -global_batch_size = 8 -num_batches = 10 -repeat_final_dataset = False - -# ------------------------------------------------------------------------------ -# JAX -# ------------------------------------------------------------------------------ - -# 1 GPU -workload = lm_jax.LmWorkload() - -input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - -batch = next(input_queue) -assert type(batch) == dict - -assert 'inputs' in batch -assert 'targets' in batch - -assert type(batch['inputs']) == np.ndarray -assert type(batch['targets']) == np.ndarray - -assert batch['inputs'].dtype == np.int64 -assert batch['targets'].dtype == np.int64 - -assert batch['inputs'].shape == (1, global_batch_size, seq_len) -assert batch['targets'].shape == (1, global_batch_size, seq_len) - -print(f"JAX devices = {jax.devices()}") -print("1") From 6b4ff12356c5f41b01ce703801b556a11079d354 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:58 +0100 Subject: [PATCH 07/63] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++ .../lm/dev/test_build_input_queue_jax.py | 127 ++++++++++++++++++ .../lm/tests/test_build_input_queue_torch.py | 87 ++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + 
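For reference, a fuller sketch of the DataLoader usage this dev note outlines further down (all paths and settings here are placeholders; the actual values are the module constants below):

    # uses the imports above; assumes fixed-length 'input_ids' rows on disk
    train_set = load_from_disk("/path/to/train")
    train_set.set_format("torch")        # rows come back as torch tensors
    loader = DataLoader(train_set, batch_size=8, shuffle=True,
                        num_workers=4, pin_memory=True)
    batch = next(iter(loader))           # {'input_ids': LongTensor of shape (8, seq_len)}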
+trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py new file mode 100644 index 000000000..08354be74 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py @@ -0,0 +1,127 @@ + +# TODO: redo with pmap!! + +import os +import jax +import tensorflow as tf +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_jax.workload import LmWorkload + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + + +N_GPUS = jax.local_device_count() + +print(f"jax.local_devices() = {jax.local_devices()}") +print(f"jax.local_device_count() = {jax.local_device_count()}") + +print(f"N_GPUS = {N_GPUS}") + +def check_batch(batch): + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + +def process_shard(batch): + inputs, targets = batch['inputs'], batch['targets'] + jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) + jax.debug.print("inputs {inputs}", inputs=inputs) + jax.debug.callback(check_batch, batch) + return inputs, targets + +# Apply process_batch across devices, sharding batch across devices +pmap_process = jax.pmap(process_shard, axis_name='batch') + + +def test_dataloader_jax(): + # Test config. 
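  # NOTE (illustration, not part of the patched test): the values below mirror
  # the PyTorch test in tests/test_build_input_queue_torch.py (same rng_seed,
  # data_dir, global_batch_size and seq_len), so both frameworks can be checked
  # against identical batches. With global_batch_size = 8 on, say, 4 local
  # devices, local_batch_size = 8 // 4 = 2, which is the per-device shard size
  # that check_batch() above asserts on; the device count here is hypothetical.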
+ rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = np.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + batch = next(input_queue) + + inputs, targets = batch['inputs'], batch['targets'] + print(f"Processing on GPU with inputs: {inputs.shape}") + + inputs, targets = pmap_process(batch) + print(f"Processing on GPU with inputs: {inputs.shape}") + print(f"Processing on GPU with inputs: {inputs}") + + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs[0]: {inputs[0]}") + # print(f"inputs[1]: {inputs[1]}") + + # for device_id in range(2): + # # Access the sharded data for each GPU + # print(inputs.shape) + # device_inputs = inputs[device_id] + # print(f" GPU {device_id} Inputs: {device_inputs.shape}") + + # @jax.pmap + # def process_batch(batch): + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + + # return inputs, targets + + # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) + # print(f"inputs: {inputs[0]}") + + + +def main(): + test_dataloader_jax() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..83a18ec15 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,87 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
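  # NOTE (illustration, not part of the patched test): the loop below checks the
  # next-token layout produced by the input pipeline, where targets are the
  # inputs shifted left by one token. For a hypothetical packed sequence
  # input_ids = [5, 7, 11, 13] with seq_len = 3:
  #   inputs  = input_ids[:-1] = [5, 7, 11]
  #   targets = input_ids[1:]  = [7, 11, 13]
  # so inputs[:, 1:] and targets[:, :-1] must match element-wise.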
+ for _ in range(100): + + batch = next(input_queue) + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + From 3c5c847eb1489fa11a65c98c0f3327bd3c23c088 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:45:41 +0100 Subject: [PATCH 08/63] Stop tracking .gitignore --- .gitignore | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7d35f0ccc..000000000 --- a/.gitignore +++ /dev/null @@ -1,28 +0,0 @@ -__pycache__/* -__pycache__ -*egg-info -*eggs -.vscode/ -env/ -venv/ -workdir/ -makefile -*.out -*.sh -*.swp -*/data/ -*events.out.tfevents* -algoperf/workloads/librispeech_conformer/data_dir -algoperf/workloads/librispeech_conformer/work_dir -*.flac -*.npy -*.csv -*.vocab -wandb/ -*.txt -scoring/plots/ - -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv - -algoperf/_version.py From 20d841b1932408bc905051dc2e188f3a43e0d749 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:47:55 +0100 Subject: [PATCH 09/63] Remove dev/ from repo, keep locally --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ------ .../lm/dev/test_build_input_queue_jax.py | 127 ------------------ 2 files changed, 169 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. - -""" - diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py deleted file mode 100644 index 08354be74..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py +++ /dev/null @@ -1,127 +0,0 @@ - -# TODO: redo with pmap!! 
- -import os -import jax -import tensorflow as tf -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_jax.workload import LmWorkload - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - - -N_GPUS = jax.local_device_count() - -print(f"jax.local_devices() = {jax.local_devices()}") -print(f"jax.local_device_count() = {jax.local_device_count()}") - -print(f"N_GPUS = {N_GPUS}") - -def check_batch(batch): - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - -def process_shard(batch): - inputs, targets = batch['inputs'], batch['targets'] - jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) - jax.debug.print("inputs {inputs}", inputs=inputs) - jax.debug.callback(check_batch, batch) - return inputs, targets - -# Apply process_batch across devices, sharding batch across devices -pmap_process = jax.pmap(process_shard, axis_name='batch') - - -def test_dataloader_jax(): - # Test config. 
- rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = np.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - batch = next(input_queue) - - inputs, targets = batch['inputs'], batch['targets'] - print(f"Processing on GPU with inputs: {inputs.shape}") - - inputs, targets = pmap_process(batch) - print(f"Processing on GPU with inputs: {inputs.shape}") - print(f"Processing on GPU with inputs: {inputs}") - - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs[0]: {inputs[0]}") - # print(f"inputs[1]: {inputs[1]}") - - # for device_id in range(2): - # # Access the sharded data for each GPU - # print(inputs.shape) - # device_inputs = inputs[device_id] - # print(f" GPU {device_id} Inputs: {device_inputs.shape}") - - # @jax.pmap - # def process_batch(batch): - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - - # return inputs, targets - - # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) - # print(f"inputs: {inputs[0]}") - - - -def main(): - test_dataloader_jax() - - -if __name__ == '__main__': - main() - From f3ba0593d955c657b6da8a07eede425509dbc6b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:00:44 +0100 Subject: [PATCH 10/63] fix comments --- algoperf/workloads/lm/input_pipeline.py | 2 +- datasets/dataset_setup.py | 27 +++++++------------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e74490a16..bae1f5e45 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -38,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? 
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index aab793832..8299133c1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,8 +711,6 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - # data_dir = "/fast/najroldi/data" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') @@ -726,7 +724,7 @@ def download_finewebedu(data_dir, tmp_dir): 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - # cache_dir=tmp_dir + cache_dir=tmp_dir ) ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size @@ -756,19 +754,11 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - from datasets import load_from_disk - tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len - # TODO: this might take to much memory - # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 - # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples - # NOTE: the current approach leads to data loss at batch boundaries, - # but concatenation *cannot* happen if batched=False, - # because concat_chunck relies on processing multiple examples at once. - # (2) loss happening because of nproc>1: potentially losing the last tokens in each process - # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM + # TODO (nico): this might take to much memory + # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -780,15 +770,12 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck,\ - **map_setup - ) - n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 + lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) + n_tokens = len(lm_dataset) * max_seq_length logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets - # TODO: avoid (single doc) contamination between train and val + # TODO (nico): avoid (single doc) contamination, by splitting before concatenation VAL_TOKENS = 10_000_000 val_samples = VAL_TOKENS // max_seq_length + 1 val_dataset = lm_dataset.select(range(val_samples)) From 381451f04a34e4a78a5256f92e1e7c092e0eadeb Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:46:45 +0100 Subject: [PATCH 11/63] add class specifications --- algoperf/workloads/lm/lm_jax/workload.py | 36 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 26 ++- algoperf/workloads/lm/workload.py | 201 +++++++++++++------ datasets/dataset_setup.py | 6 +- 4 files changed, 199 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 773f8c54c..84377b4bc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,17 +1,47 @@ """LM workload implemented in Jax.""" import functools -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple +from absl import logging from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np +import optax from algoperf import param_utils +from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload - class LmWorkload(BaseLmWorkload): - pass + """LM JAX workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0ff7884c7..404dc2532 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ 
b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -23,6 +23,24 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -93,6 +111,10 @@ def _build_input_queue(self, } yield batch - - def eval_step(): + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" pass diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 7b1313dd7..e36d54625 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -5,6 +5,9 @@ import os from typing import Any, Dict, Optional, Tuple +from absl import flags +import torch.distributed as dist + import jax import numpy as np import torch @@ -12,17 +15,98 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +FLAGS = flags.FLAGS + USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): - """A LM workload.""" + """LM workload.""" _vocab_size: int = 32000 def __init__(self) -> None: super().__init__() - self._tokenizer = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result['test/ppl'] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + pass + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. 
steps the baseline can do in the allowed runtime budget.""" + pass + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -43,65 +127,58 @@ def _build_input_queue(self, for batch in iter(ds): yield batch - def _eval_model_on_split(): - pass - - def eval_period_time_sec(): - pass - - def has_reached_test_target(): - pass - - def has_reached_validation_target(): - pass - - def init_model_fn(): - pass - - def is_output_params(): - pass - - def loss_fn(): - pass - - def loss_type(): - pass - - def max_allowed_runtime_sec(): - pass - - def model_fn(): - pass - - def num_eval_train_examples(): - pass - - def num_test_examples(): - pass - - def num_train_examples(): - pass - - def num_validation_examples(): - pass - - def step_hint(): - pass - - def test_target_value(): - pass - - def train_mean(): - pass - - def train_stddev(): - pass - - def validation_target_value(): - pass - - def target_metric_name(): - pass + @abc.abstractmethod + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True) + + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {'loss': mean_loss} - def eval_batch_size(): + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). 
+ """ pass + + + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8299133c1..fb8701f4d 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,11 +711,11 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') - - _maybe_mkdir(tmp_dir) + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ + else os.path.expanduser("~/.cache/huggingface/datasets") _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir From f111d2e8baada7af619504a87974fa78f3e34d55 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:29:37 +0100 Subject: [PATCH 12/63] add workload LM info --- algoperf/workloads/workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..6b99a25a6 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -114,6 +114,7 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, @@ -150,6 +151,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'lm', 'ogbg', 'wmt' ] From 808d398ee2cf78e92cea29e2d0696eb6ce592929 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:32:48 +0100 Subject: [PATCH 13/63] restore data_utils.py tree map --- algoperf/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 068c21c03..37d1bd20f 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. 
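  # Example (illustrative values): with global_batch_size = 8 and
  # local_device_count = 4, an `inputs` array of shape (8, 2048) is reshaped
  # below to (4, 2, 2048), i.e. one leading entry per local device, each
  # holding a per-device batch of 2.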
return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_util.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, From 35f8f8942cb993628f1b20c3d29346e4d7b40e95 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 14:38:41 +0100 Subject: [PATCH 14/63] fixed NFS bug --- datasets/dataset_setup.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index fb8701f4d..a68da3ff5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -708,26 +708,28 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir): +def download_finewebedu(data_dir, tmp_dir=None): """Download FineWebEdu-10B.""" data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ - else os.path.expanduser("~/.cache/huggingface/datasets") + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) - # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - cache_dir=tmp_dir + cache_dir=cache_dir ) - ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + # Shuffle so that multiproc has shards of similar size. + ds = ds.shuffle(seed=1996) seq_len = 2048 max_seq_length = seq_len+1 @@ -754,11 +756,8 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - + # Concat in chunks of max_seq_len - # TODO (nico): this might take to much memory - # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} From cbb6ee67c6eb4828b574987d45fde508e5f1db67 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 15:02:27 +0100 Subject: [PATCH 15/63] train/val split before concat --- datasets/dataset_setup.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index a68da3ff5..5e27211e8 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -756,8 +756,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - - # Concat in chunks of max_seq_len + + # Find how many entries to take from dataset to have VAL_TOKENS in validation set. + VAL_TOKENS = 10_000_000 + tokens_accumulated, num_examples_for_val = 0, 0 + for example in tokenized_dataset: + tokens_accumulated += len(example['input_ids']) + num_examples_for_val += 1 + if tokens_accumulated >= VAL_TOKENS: + break + # Split in train and valid. 
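  # NOTE (illustration, not part of this change): splitting here, on whole
  # tokenized documents and before concat_chunck, keeps each document entirely
  # in either train or validation, which is the single-document contamination
  # concern noted in the TODO removed further down in this diff. With
  # val_tokens = 10_000_000 and max_seq_length = 2049, the validation split
  # later concatenates into roughly 10_000_000 // 2049, i.e. about 4_880 chunks.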
+ val_dataset = tokenized_dataset.select(range(num_examples_for_val)) + train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + + # Concat in chunks of max_seq_len. + # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -769,18 +782,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) - n_tokens = len(lm_dataset) * max_seq_length - logging.info(f"Number of tokens in dataset: {n_tokens:_}") - - # Split dataset into training and validation sets - # TODO (nico): avoid (single doc) contamination, by splitting before concatenation - VAL_TOKENS = 10_000_000 - val_samples = VAL_TOKENS // max_seq_length + 1 - val_dataset = lm_dataset.select(range(val_samples)) - train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + # Concat text in validation and train sets. + val_dataset = val_dataset.map(concat_chunck, **map_setup) + train_dataset = train_dataset.map(concat_chunck, **map_setup) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 868987c2fd72ced8107048e20de44a7e303074e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:41:05 +0100 Subject: [PATCH 16/63] renamed datasets to avoid conflict with HF --- {datasets => datasets_algoperf}/README.md | 0 .../dataset_setup.py | 17 ++++++++++------- .../librispeech_preprocess.py | 2 +- .../librispeech_tokenizer.py | 0 4 files changed, 11 insertions(+), 8 deletions(-) rename {datasets => datasets_algoperf}/README.md (100%) rename {datasets => datasets_algoperf}/dataset_setup.py (98%) rename {datasets => datasets_algoperf}/librispeech_preprocess.py (98%) rename {datasets => datasets_algoperf}/librispeech_tokenizer.py (100%) diff --git a/datasets/README.md b/datasets_algoperf/README.md similarity index 100% rename from datasets/README.md rename to datasets_algoperf/README.md diff --git a/datasets/dataset_setup.py b/datasets_algoperf/dataset_setup.py similarity index 98% rename from datasets/dataset_setup.py rename to datasets_algoperf/dataset_setup.py index 5e27211e8..21811e729 100644 --- a/datasets/dataset_setup.py +++ b/datasets_algoperf/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 datasets_algoperf/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -126,15 +126,15 @@ flags.DEFINE_boolean('fastmri', False, 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('imagenet', False, 'If --all=false, whether or not to download Imagenet.') flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') 
-flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -727,6 +727,8 @@ def download_finewebedu(data_dir, tmp_dir=None): split='train', cache_dir=cache_dir ) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading + # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) @@ -747,6 +749,7 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_attention_mask=False ) tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info(f"Tokenizing...") tokenized_dataset = ds.map( tokenize, remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', @@ -783,6 +786,7 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: } return result # Concat text in validation and train sets. + logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") @@ -876,9 +880,8 @@ def main(_): download_wmt(data_dir) if FLAGS.all or FLAGS.finewebedu: - if not FLAGS.skip_download: - logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir) + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir) # pylint: enable=logging-format-interpolation diff --git a/datasets/librispeech_preprocess.py b/datasets_algoperf/librispeech_preprocess.py similarity index 98% rename from datasets/librispeech_preprocess.py rename to datasets_algoperf/librispeech_preprocess.py index a8c5cae1d..cd291e5b3 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets_algoperf/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets import librispeech_tokenizer +from datasets_algoperf import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/datasets_algoperf/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to datasets_algoperf/librispeech_tokenizer.py From dd59dedc97f99e994221775b1e980d845bfb908c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:55:11 +0100 Subject: [PATCH 17/63] renamed datasets to dataset --- {datasets_algoperf => dataset}/README.md | 0 {datasets_algoperf => dataset}/dataset_setup.py | 6 +++--- {datasets_algoperf => dataset}/librispeech_preprocess.py | 2 +- {datasets_algoperf => dataset}/librispeech_tokenizer.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {datasets_algoperf => dataset}/README.md (100%) rename {datasets_algoperf => dataset}/dataset_setup.py (99%) rename {datasets_algoperf => dataset}/librispeech_preprocess.py (98%) rename {datasets_algoperf => dataset}/librispeech_tokenizer.py (100%) diff --git a/datasets_algoperf/README.md b/dataset/README.md similarity index 100% rename from datasets_algoperf/README.md rename to dataset/README.md diff --git a/datasets_algoperf/dataset_setup.py b/dataset/dataset_setup.py similarity index 99% rename from datasets_algoperf/dataset_setup.py rename to dataset/dataset_setup.py index 21811e729..0c7f33de6 100644 --- a/datasets_algoperf/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: 
-python3 datasets_algoperf/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -74,8 +74,8 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer import datasets as hf_datasets from transformers import AutoTokenizer diff --git a/datasets_algoperf/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 98% rename from datasets_algoperf/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index cd291e5b3..b96881332 100644 --- a/datasets_algoperf/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets_algoperf import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets_algoperf/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets_algoperf/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py From 496b9c31f0bdd9a50e18a6907146969fd98e73cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:52:54 +0100 Subject: [PATCH 18/63] fix style --- .gitignore | 28 +++++++++++ algoperf/workloads/lm/input_pipeline.py | 50 ++++++++----------- algoperf/workloads/lm/lm_jax/workload.py | 15 +----- algoperf/workloads/lm/lm_pytorch/workload.py | 46 +++++++++-------- .../lm/tests/test_build_input_queue_torch.py | 18 +++---- algoperf/workloads/lm/workload.py | 12 ++--- 6 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..916a29ff4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +__pycache__/* +__pycache__ +*egg-info +*eggs +.vscode/ +env/ +venv/ +workdir/ +makefile +*.out +*.sh +*.swp +*/data/ +*events.out.tfevents* +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir +*.flac +*.npy +*.csv +*.vocab +wandb/ +*.txt +scoring/plots/ + +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index bae1f5e45..53fe79276 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,24 +1,22 @@ """Input pipeline for a LM dataset.""" import functools import os +from typing import Optional -from datasets import Dataset, load_from_disk -from typing import Dict, List, Optional, Union - +from datasets import load_from_disk import jax -import numpy as np import tensorflow as tf -import tensorflow_datasets as tfds from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# This ensures that only the primary process (RANK == 0) uses TensorFlow's # automatic optimization (AUTOTUNE), while other processes disable it (None). 
-# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal -# number of elements to prefetch or parallelize for dataset operations, improving performance. +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -44,25 +42,24 @@ def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + }) # Avoid creating too many threads when using PyTorch DDP. # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread - ds = ds.with_options(options) # apply threading restrictions + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -70,10 +67,7 @@ def tf_generator(): if is_training: ds = ds.repeat() - # Batch the dataset, ensuring the last batch is dropped if not full during training - # i.e. it groups consecutive elements into fixed-size chunks. - # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), - # improving efficiency and parallelism in training + # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
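  # Illustrative example (hypothetical sizes): drop_remainder is only enabled
  # for the training split, so 10 examples with a batch size of 4 yield 2 full
  # training batches (the trailing 2 examples are dropped), while an eval split
  # keeps the final partial batch of 2 and leaves any padding to
  # shard_and_maybe_pad_np below.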
ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) @@ -83,9 +77,9 @@ def tf_generator(): # Shard the dataset across multiple GPUs/TPUs if necessary ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) - return ds \ No newline at end of file + return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 84377b4bc..64d538dda 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,22 +1,11 @@ """LM workload implemented in Jax.""" -import functools -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np -import optax - -from algoperf import param_utils -from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 404dc2532..e57d26390 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -3,16 +3,10 @@ import contextlib from typing import Dict, Iterator, Optional, Tuple -from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP -import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload @@ -41,16 +35,17 @@ def model_fn( update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: pass - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - + seq_len = 2048 # TODO: define it somewehere else DTYPE = torch.int32 # TODO: decide between int32 and int64. @@ -65,20 +60,25 @@ def _build_input_queue(self, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
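      # Summary of the branch below (illustrative, device count hypothetical):
      # rank 0 is the only process that reads the tf.data pipeline; it keeps
      # shard 0 along the leading N_GPUS axis and broadcasts shards
      # 1..N_GPUS - 1. Every other rank allocates empty
      # (N_GPUS - 1, per_device_batch_size, seq_len) buffers, receives the
      # broadcast, and keeps the shard at index RANK - 1; with N_GPUS = 4,
      # rank 2 ends up with inputs[1].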
if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + inputs = torch.as_tensor( + batch['inputs'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor( + batch['targets'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=DTYPE, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -95,12 +95,16 @@ def _build_input_queue(self, dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) - targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) + targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) # RANK - 1 since we don't broadcast the shard for RANK 0. - inputs, targets = inputs[RANK-1], targets[RANK-1] + inputs, targets = inputs[RANK - 1], targets[RANK - 1] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 83a18ec15..639e71491 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -1,11 +1,6 @@ - import jax import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec + from algoperf.profiler import PassThroughProfiler from algoperf.pytorch_utils import pytorch_init from algoperf.pytorch_utils import pytorch_setup @@ -29,20 +24,20 @@ def test_dataloader_torch(): seq_len = 2048 local_batch_size = global_batch_size // N_GPUS - + workload = LmWorkload() data_rng = jax.random.PRNGKey(rng_seed) - + input_queue = workload._build_input_queue( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - + print(f"RANK {RANK} of {N_GPUS}") sync_ddp() - + # batch = next(input_queue) # inputs, targets = batch['inputs'], batch['targets'] # print(f"inputs.shape: {inputs.shape}") @@ -71,7 +66,7 @@ def test_dataloader_torch(): assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) - assert torch.equal(inputs[:,1:], targets[:,:-1]) + assert torch.equal(inputs[:, 1:], targets[:, :-1]) print(f"=== ALL TEST PASSED ===") @@ -84,4 +79,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e36d54625..3d04be3c5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -3,14 +3,11 @@ import abc import math import os -from typing import Any, Dict, 
Optional, Tuple +from typing import Dict, Optional from absl import flags -import torch.distributed as dist - import jax -import numpy as np -import torch +import torch.distributed as dist from algoperf import spec from algoperf.workloads.lm import input_pipeline @@ -155,7 +152,7 @@ def _eval_model_on_split(self, global_batch_size, num_batches, repeat_final_dataset=True) - + for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) loss += self._eval_batch(params, eval_batch) @@ -179,6 +176,3 @@ def loss_fn( (not synced across devices). """ pass - - - From 50989eb6a8a54c43225a4243f770a4419d431a81 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:57:06 +0100 Subject: [PATCH 19/63] fix formatting --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 - submission_runner.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e57d26390..be6c94c46 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,6 +1,5 @@ """LM workload implemented in PyTorch.""" -import contextlib from typing import Dict, Iterator, Optional, Tuple import jax diff --git a/submission_runner.py b/submission_runner.py index d7df006bb..f8a66452d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From 5af0fdc1437d924e2e162de5100e66782d01a7e5 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:02:22 +0100 Subject: [PATCH 20/63] fix style --- algoperf/workloads/lm/lm_pytorch/workload.py | 16 ++++++++-------- algoperf/workloads/lm/workload.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index be6c94c46..606f16ad7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -45,8 +45,8 @@ def _build_input_queue( not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - seq_len = 2048 # TODO: define it somewehere else - DTYPE = torch.int32 # TODO: decide between int32 and int64. + seq_len = self._seq_len # TODO: define it somewehere else? + dtype = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
@@ -66,10 +66,10 @@ def _build_input_queue( if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=DTYPE, + batch['inputs'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) targets = torch.as_tensor( - batch['targets'], dtype=DTYPE, + batch['targets'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. @@ -77,7 +77,7 @@ def _build_input_queue( if not_train: # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(targets[0]), dtype=DTYPE, device=DEVICE) + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -90,15 +90,15 @@ def _build_input_queue( # Receive batch from rank 0. if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 3d04be3c5..aa6d188b3 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,6 +21,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 32000 + _seq_len: int = 2048 def __init__(self) -> None: super().__init__() From 26830999b92d26c729171cae141ee7abb3409463 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:32:47 +0100 Subject: [PATCH 21/63] fix style --- algoperf/workloads/lm/workload.py | 2 +- dataset/dataset_setup.py | 91 +++++++++++++++++++------------ 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index aa6d188b3..4eb6c74a5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -24,7 +24,7 @@ class BaseLmWorkload(spec.Workload): _seq_len: int = 2048 def __init__(self) -> None: - super().__init__() + pass @property def target_metric_name(self) -> str: diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 0c7f33de6..8f0b09ab7 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -80,7 +80,6 @@ import datasets as hf_datasets from transformers import AutoTokenizer -import math import functools import itertools import os @@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None): data_dir = os.path.join(data_dir, 'finewebedu') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) @@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None): os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - 
split='train', - cache_dir=cache_dir - ) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) seq_len = 2048 - max_seq_length = seq_len+1 + max_seq_length = seq_len + 1 map_setup = dict(batched=True, batch_size=1024, num_proc=8) # Tokenize - tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False - ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization logging.info(f"Tokenizing...") tokenized_dataset = ds.map( - tokenize, - remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', - 'language_score', 'token_count', 'score', 'int_score'], - **map_setup - ) - tokenizer.model_max_length = seq_len - + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ], + **map_setup) + lm_tokenizer.model_max_length = seq_len + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have VAL_TOKENS in validation set. - VAL_TOKENS = 10_000_000 + # Find how many entries to take from dataset to have val_tokens in validation set. + val_tokens = 10_000_000 # TODO: decide this value. tokens_accumulated, num_examples_for_val = 0, 0 for example in tokenized_dataset: tokens_accumulated += len(example['input_ids']) num_examples_for_val += 1 - if tokens_accumulated >= VAL_TOKENS: - break + if tokens_accumulated >= val_tokens: + break # Split in train and valid. val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + train_dataset = tokenized_dataset.select( + range(num_examples_for_val, len(tokenized_dataset))) # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
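  # Illustrative example (the numbers are made up): with batched=True and
  # batch_size=1024, each mapped batch truncates independently to a multiple of
  # max_seq_length. If one batch of 1024 documents concatenates to 2_100_000
  # tokens and max_seq_length is 2049, it yields 2_100_000 // 2049 = 1024
  # chunks and drops the remaining 2_100_000 - 1024 * 2049 = 1_824 tokens.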
def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + concatenated_examples = { + k: list(itertools.chain(*examples[k])) for k in examples.keys() + } total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length + total_length = (total_length // max_seq_length) * max_seq_length result = { - k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] - for k, t in concatenated_examples.items() + k: [ + t[i:i + max_seq_length] + for i in range(0, total_length, max_seq_length) + ] for k, t in concatenated_examples.items() } return result + # Concat text in validation and train sets. logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" + ) # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 6b7ee29684ee9bf1f9564032f65c09373212c4a4 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:36:27 +0100 Subject: [PATCH 22/63] fix yapf --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index f8a66452d..468a04c7c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: From 46b645b2ac4a4f4b93fe4ee6324b07f412fb81b3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:38:40 +0100 Subject: [PATCH 23/63] fix style --- dataset/dataset_setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 8f0b09ab7..6587f1439 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -797,7 +797,8 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: k: [ t[i:i + max_seq_length] for i in range(0, total_length, max_seq_length) - ] for k, t in concatenated_examples.items() + ] for k, + t in concatenated_examples.items() } return result From b3ae6474be93f07c578f885bae484773b8a65515 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 15:56:25 +0000 Subject: [PATCH 24/63] HF datasets pipeline --- algoperf/workloads/lm/input_pipeline.py | 75 ++++++++++- .../lm/tests/test_hf_input_pipeline.py | 116 ++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 53fe79276..ea4cb9d63 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -3,12 +3,17 @@ import os from typing import Optional -from datasets import load_from_disk import jax +import jax.numpy as jnp import tensorflow as tf +import torch +import torch.nn.functional as F +from transformers import GPT2Tokenizer from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup +from datasets import load_dataset +from datasets import load_from_disk RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). @@ -20,6 +25,74 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None +def get_hf_dataloader(cache_dir: str, + data_rng: jax.random.PRNGKey, + batch_size: int = 8, + seq_len: int = 32, + framework: str = "torch", + split="train"): + """ + Create a data loader from HuggingFace's FineWeb dataset. 
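+    Each iteration yields an (inputs, targets) pair of one-hot tensors of shape
+    (batch_size, seq_len, vocab_size), with targets shifted one token ahead of inputs.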
+ + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + devices = jax.devices("gpu") + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield inputs, targets + + return batch_iterator() + + def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected 
inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() From f095d4b167dabc0e1aeb925b871f32f427fc22c8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 17:03:05 +0000 Subject: [PATCH 25/63] Testing with linear model --- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 18 +++++++++ algoperf/workloads/lm/lm_jax/workload.py | 26 +++++++++++-- algoperf/workloads/lm/lm_pytorch/models.py | 18 +++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 32 +++++++++++++-- .../workloads/lm/tests/test_linear_model.py | 39 +++++++++++++++++++ algoperf/workloads/lm/workload.py | 17 ++------ 7 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/models.py create mode 100644 algoperf/workloads/lm/lm_pytorch/models.py create mode 100644 
algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ea4cb9d63..cc658501e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -86,7 +86,6 @@ def batch_iterator(): token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] - devices = jax.devices("gpu") inputs, targets = jax.device_put(inputs), jax.device_put(targets) yield inputs, targets diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..edfc102fa --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,18 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 512, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 64d538dda..30b0c7867 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,8 +2,12 @@ from typing import Dict, Optional, Tuple +import jax.numpy as jnp +from flax import jax_utils +from algoperf import param_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel class LmWorkload(BaseLmWorkload): @@ -14,18 +18,32 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + model = LinearModel(vocab_size=self._vocab_size) + input_shape = (1, self._seq_len, self._vocab_size) + variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) + model_state, params = variables.pop('params') + + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + model_state = jax_utils.replicate(model_state) + params = jax_utils.replicate(params) + + return params, model_state def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del mode, rng, update_batch_norm # Not used for linear model + inputs = batch['inputs'] + logits = self._model.apply({'params': params, **model_state}, inputs) + return logits, model_state def _eval_batch(self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class LinearLayer(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.bottleneck = nn.Linear(vocab_size, 512) + self.output = nn.Linear(512, vocab_size) + self.reset_parameters() + + def 
reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 606f16ad7..3395aa08f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -5,10 +5,13 @@ import jax import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.models import LinearLayer USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -21,18 +24,39 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + if hasattr(self, '_model'): + self._model.reset_parameters() + return self._model, None + + torch.manual_seed(rng[0]) + self._model = LinearLayer(vocab_size=self._vocab_size) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del model_state, rng, update_batch_norm # Not used for linear model + model = params + inputs = batch['inputs'].float() # Convert one-hot to float + logits = model(inputs) + return logits, None def _build_input_queue( self, diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + 
test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 4eb6c74a5..a06b17fdc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -20,8 +20,8 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" - _vocab_size: int = 32000 - _seq_len: int = 2048 + _vocab_size: int = 50257 + _seq_len: int = 512 def __init__(self) -> None: pass @@ -106,6 +106,7 @@ def activation(self) -> str: def glu(self) -> bool: return True + @abc.abstractmethod def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -113,17 +114,7 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - ds = input_pipeline.get_lm_dataset( - data_rng, - split, - data_dir, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - - for batch in iter(ds): - yield batch + """Build an input queue for the given split.""" @abc.abstractmethod def _eval_batch(self, From 0c22f3df420968cf820cbcc826f84a61751f95f5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:28:05 -0400 Subject: [PATCH 26/63] lm workload with linear model --- .../workloads/cifar/cifar_jax/workload.py | 11 -- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/lm_jax/models.py | 5 +- algoperf/workloads/lm/lm_jax/workload.py | 82 +++++++++-- algoperf/workloads/lm/lm_pytorch/workload.py | 129 ++++++++++-------- algoperf/workloads/lm/workload.py | 59 ++++---- pyproject.toml | 3 +- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 9 files changed, 187 insertions(+), 118 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..440de64c1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,7 +87,7 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets + yield {'inputs': inputs, 'targets': targets} return batch_iterator() diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..72ee5bd83 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -7,12 +7,13 @@ class LinearModel(nn.Module): @nn.compact def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: x = nn.Dense( - 512, + 10, kernel_init=nn.initializers.normal(0.02), bias_init=nn.initializers.zeros )(inputs) return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..7cb50302f 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,33 +2,57 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - model = LinearModel(vocab_size=self._vocab_size) + self._model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + print(params_rng) + # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None 
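+    # shard_replicated lays the parameters out across devices; the Flax model defines no
+    # mutable collections (e.g. batch statistics), so there is no model state to carry.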
return params, model_state def model_fn( @@ -40,15 +64,51 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model + del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) + else: + # Dense labels + loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..0d0281690 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,68 +66,38 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - - while True: - # Only iterate over tf input pipeline in one Python process to - # avoid creating too many threads. 
- if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor( - batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor( - batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - - # Send batch to other devices when using DDP. - if USE_PYTORCH_DDP: - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - # We don't broadcast the shard for RANK 0. - dist.broadcast(inputs[1:], src=0) - dist.broadcast(targets[1:], src=0) - - # RANK 0 extracts his shard. If not DDP, this just flattens. - inputs, targets = inputs[0], targets[0] - - else: - # Receive batch from rank 0. - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) + + dtype = torch.long + is_train = split == 'train' + + for batch in loader: + inputs, targets = batch + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - - # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) - targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) + + # Broadcast to all devices dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) - # RANK - 1 since we don't broadcast the shard for RANK 0. 
- inputs, targets = inputs[RANK - 1], targets[RANK - 1] + + if weights is None: + weights = torch.ones(inputs.shape[0], device=DEVICE) if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) @@ -138,10 +108,51 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..c10bf13e8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,6 +11,7 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS @@ -21,10 +22,13 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 512 + _seq_len: int = 5 + warmup_factor: float = 0.1 def __init__(self) -> None: - pass + super().__init__() + self._param_shapes = None + self._param_types = None @property def target_metric_name(self) -> str: @@ -36,14 +40,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - pass + return 20.0 # Target perplexity - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/ppl'] <= self.test_target_value @property def test_target_value(self) -> float: - pass + return 20.0 # Target perplexity @property def loss_type(self) -> spec.LossType: @@ -51,23 +55,23 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - pass + return 1000000 # Example size @property def num_eval_train_examples(self) -> int: - pass + return 10000 # Subset for evaluation @property def num_validation_examples(self) -> int: - pass + return 50000 @property def num_test_examples(self) -> int: - pass + return 50000 @property def eval_batch_size(self) -> int: - pass + return 8 
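The validation and test targets defined above are perplexities, while the shared eval loop below accumulates a summed cross-entropy and reports a mean loss. A minimal sketch of the conversion this implies, assuming the reported "ppl" metrics are the exponentiated mean per-token cross-entropy (the helper name is hypothetical, not part of the workload):

import math

def loss_to_ppl(mean_loss_nats: float) -> float:
  # Perplexity is the exponential of the mean cross-entropy (in nats), so the
  # target of 20.0 corresponds to roughly log(20) ~= 3.0 nats per token.
  return math.exp(mean_loss_nats)

loss_to_ppl(3.0)  # ~= 20.09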
@property def train_mean(self): @@ -79,16 +83,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - pass + return 3600 * 4 # 4 hours @property def eval_period_time_sec(self) -> int: - pass + return 600 # 10 minutes @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - pass + return 100000 @property def pre_ln(self) -> bool: @@ -116,13 +120,22 @@ def _build_input_queue(self, repeat_final_dataset: bool = False): """Build an input queue for the given split.""" - @abc.abstractmethod def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + + loss_dict = self.loss_fn(batch['targets'], logits) + return loss_dict['summed'] def _eval_model_on_split(self, split: str, @@ -145,9 +158,10 @@ def _eval_model_on_split(self, num_batches, repeat_final_dataset=True) + loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) + loss += self._eval_batch(params, eval_batch, model_state, rng) if USE_PYTORCH_DDP: dist.all_reduce(loss) mean_loss = loss.item() / num_examples @@ -155,16 +169,11 @@ def _eval_model_on_split(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. + @abc.abstractmethod def loss_fn( self, - label_batch: spec.Tensor, # Dense or one-hot labels. + label_batch: spec.Tensor, logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/pyproject.toml b/pyproject.toml index f4ebdaee3..745c6c680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algoperf/_version.py" [project.optional-dependencies] # All workloads full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", ] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] @@ -96,6 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +lm = ["transformers", "datasets"] # Frameworks jax_core_deps = [ diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From 99c7b9b70a374a25d6ac29c4f9a0f7c95e57c1aa Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:46:53 -0400 Subject: [PATCH 27/63] add nanodo model --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 345 ++++++++++++++++++ algoperf/workloads/lm/lm_jax/workload.py | 56 ++- .../paper_baselines/adamw/jax/submission.py | 4 +- 3 files changed, 386 insertions(+), 19 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/nanodo_model.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + 
embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = 
jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. 
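+    Decoding is greedy: at each of the k steps the full sequence is re-scored and the
+    argmax token is appended to the context.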
+ + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git 
a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 7cb50302f..9fdfe6f60 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -10,7 +10,8 @@ from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) from algoperf.workloads.lm.input_pipeline import get_hf_dataloader @@ -42,12 +43,22 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - self._model = LinearModel(vocab_size=self._vocab_size) - input_shape = (1, self._seq_len, self._vocab_size) + # Initialize NanoDO transformer model + cfg = DoConfig( + D=512, # model dim + H=8, # num heads + L=self._seq_len, + N=6, # num layers + V=self._vocab_size, + F=2048, # feedforward dim + dtype=jnp.float32 + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + params_rng, init_rng = jax.random.split(rng) - print(params_rng) - # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) - variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.int32)) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -66,6 +77,11 @@ def model_fn( del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] + + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) return logits, None @@ -76,23 +92,29 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" - vocab_size = logits_batch.shape[-1] + # Convert one-hot labels to token IDs if needed + if len(label_batch.shape) == len(logits_batch.shape): # one-hot + label_batch = jnp.argmax(label_batch, axis=-1) - if len(label_batch.shape) == len(logits_batch.shape): - # One-hot labels - loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) - else: - # Dense labels - loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + # Reshape for sequence modeling + logits = logits_batch.reshape(-1, logits_batch.shape[-1]) + labels = label_batch.reshape(-1) + + # Compute cross-entropy loss + loss = -jnp.sum( + jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) if mask_batch is not None: - loss = loss * mask_batch + mask = mask_batch.reshape(-1) + loss = loss * mask + n_valid = mask.sum() + else: + n_valid = labels.shape[0] - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] return { - 'summed': loss.sum(), + 'summed': loss, 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss / n_valid # Return per-token loss } def is_output_params(self, param_name: str) -> bool: diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 6c6d19ef8..dca9a6b95 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py 
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,7 +75,6 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) - jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -163,7 +162,6 @@ def update_params( replicated, # loss replicated # grad_norm )) - # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -229,6 +227,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 706d9f74046a0f1c90256ae584b45e30a38e4349 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 13:26:15 -0400 Subject: [PATCH 28/63] torch model --- algoperf/param_utils.py | 2 + .../workloads/lm/lm_pytorch/plainlm_model.py | 298 ++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 57 ++-- .../adamw/pytorch/submission.py | 2 + 4 files changed, 341 insertions(+), 18 deletions(-) create mode 100644 algoperf/workloads/lm/lm_pytorch/plainlm_model.py diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..24f981546 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -43,6 +43,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], 
dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies 
for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights 
+ ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0d0281690..45ad0828f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.models import LinearLayer +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -26,11 +26,23 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: if hasattr(self, '_model'): - self._model.reset_parameters() + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() return self._model, None torch.manual_seed(rng[0]) - self._model = LinearLayer(vocab_size=self._vocab_size) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -46,15 +58,20 @@ def init_model_fn( def model_fn( self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm # Not used for linear model + del model_state, rng, update_batch_norm model = params - inputs = batch['inputs'].float() # Convert one-hot to float + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + logits = model(inputs) return logits, None @@ -83,13 +100,14 @@ def _build_input_queue( is_train = split == 'train' for batch in loader: - inputs, targets = batch + inputs = batch['inputs'] + targets = batch['targets'] if USE_PYTORCH_DDP: if not is_train: # During eval, the batch size of the remainder might be different per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) + targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # Broadcast to all devices @@ -97,10 +115,8 @@ def _build_input_queue( dist.broadcast(targets, src=0) if weights is None: - weights = torch.ones(inputs.shape[0], device=DEVICE) - - if weights is None: - weights = torch.ones(per_device_batch_size, device=DEVICE) + batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() + weights = torch.ones((batch_size, seq_len), device=DEVICE) batch = { 'inputs': inputs, 'targets': 
targets, @@ -110,7 +126,7 @@ def _build_input_queue( def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return 'output.weight' in param_name or 'output.bias' in param_name + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name def _eval_batch(self, params: spec.ParameterContainer, @@ -121,11 +137,17 @@ def _eval_batch(self, model = params logits, _ = self.model_fn( model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(targets * log_probs) + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) return loss def loss_fn( self, @@ -146,7 +168,6 @@ def loss_fn( logits_batch, label_batch, reduction='none') - if mask_batch is not None: loss = loss * mask_batch diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..bdeaaf95b 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -173,6 +173,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From c335e341913dc6b1a747f2d3407e71a8d8e66ab6 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 29 May 2025 14:22:50 +0000 Subject: [PATCH 29/63] lm workload dataset integration in jax --- .../workloads/cifar/cifar_jax/workload.py | 11 - algoperf/workloads/lm/input_pipeline.py | 12 +- algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 68 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 49 +-- algoperf/workloads/lm/workload.py | 313 +++++++++--------- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 8 files changed, 261 insertions(+), 209 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..8f68fcb55 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,19 +87,19 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets - + batch = { + "inputs": inputs, + "targets": targets, + } + yield batch return batch_iterator() def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - vocab_size: int, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): + num_batches: Optional[int] = None): """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..7913f2c67 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -14,5 +14,6 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..6ad0e7d3d 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,16 +2,36 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + return loader def init_model_fn( self, @@ -21,14 +41,15 @@ def init_model_fn( model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None + self._model = model return params, 
model_state def model_fn( @@ -40,15 +61,40 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model - inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + del mode, rng, update_batch_norm, model_state + inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, # One-hot labels. + logits_batch: spec.Tensor, # Dense logits. + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]: + del mask_batch, label_smoothing + logits_flat = logits_batch.reshape(-1, self._vocab_size) + targets = jax.nn.one_hot(label_batch, self._vocab_size, axis=-1) + targets_flat = targets.reshape(-1, self._vocab_size) + # Cross-entropy loss + loss = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1)) + n_valid_examples = logits_flat.shape[0] + return {'summed': loss, 'n_valid_examples': n_valid_examples} + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..2c6862160 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,35 +66,30 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
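The queue below follows the usual DDP fan-out: rank 0 materializes each batch and the remaining ranks receive it into pre-allocated buffers via torch.distributed broadcast. A minimal, self-contained sketch of that pattern, assuming the default process group is already initialized and each rank owns one CUDA device (the shapes and the randint stand-in are illustrative only, not part of the workload):

import torch
import torch.distributed as dist

rank = dist.get_rank()
batch_size, seq_len = 8, 2048
if rank == 0:
  # Rank 0 holds the real data (faked here with random token IDs).
  inputs = torch.randint(0, 50257, (batch_size, seq_len),
                         dtype=torch.int32, device='cuda')
else:
  # Every other rank allocates an empty buffer of identical shape and dtype.
  inputs = torch.empty((batch_size, seq_len), dtype=torch.int32, device='cuda')
dist.broadcast(inputs, src=0)  # in-place copy of rank 0's values to all ranks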
if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return + batch = next(dataset_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) targets = torch.as_tensor( batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: @@ -138,10 +133,22 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..e6b33e3e4 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,160 +11,171 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ class BaseLmWorkload(spec.Workload): - """LM workload.""" - - _vocab_size: int = 50257 - _seq_len: int = 512 - - def __init__(self) -> None: - pass - - @property - def target_metric_name(self) -> str: - """The name of the target metric (useful for scoring/processing code).""" - return 'ppl' - - def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value - - @property - def validation_target_value(self) -> float: - pass - - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value - - @property - def test_target_value(self) -> float: - pass - - @property - def loss_type(self) -> spec.LossType: - return spec.LossType.SOFTMAX_CROSS_ENTROPY - - @property - def num_train_examples(self) -> int: - pass - - @property - def num_eval_train_examples(self) -> int: - pass - - @property - def num_validation_examples(self) -> int: - pass - - @property - def num_test_examples(self) -> int: - pass - - @property - def eval_batch_size(self) -> int: - pass - - @property - def train_mean(self): - raise NotImplementedError - - @property - def train_stddev(self): - raise NotImplementedError - - @property - def max_allowed_runtime_sec(self) -> int: - pass - - @property - def eval_period_time_sec(self) -> int: - pass - - @property - def step_hint(self) -> int: - """Approx. 
steps the baseline can do in the allowed runtime budget.""" - pass - - @property - def pre_ln(self) -> bool: - return True - - @property - def attention_temp(self) -> float: - return 1.0 - - @property - def activation(self) -> str: - return 'silu' - - @property - def glu(self) -> bool: - return True - - @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue for the given split.""" - - @abc.abstractmethod - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - num_batches = int(math.ceil(num_examples / global_batch_size)) - if split not in self._eval_iters: - # These iterators will repeat indefinitely. - self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) - - for _ in range(num_batches): - eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 512 + + def __init__(self) -> None: + pass + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return "ppl" + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result["validation/ppl"] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result["test/ppl"] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + return 8 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + # FIXME: should replace this with a real value later. + return 10000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return "silu" + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) + + loss = 0.0 + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch, model_state, rng) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {"loss": mean_loss} + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. 
+ logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + pass diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From af8cce4d61e7f79916d7293127121ebaa4a4d7ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 03:20:46 +0000 Subject: [PATCH 30/63] set package versions for transformers and datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 745c6c680..5e9c21f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] -lm = ["transformers", "datasets"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ From d68c54e0aa023570abc94cea97f5757bfb0baca8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 04:02:41 +0000 Subject: [PATCH 31/63] use train_test_split method to shuffle and split fineweb-edu dataset --- dataset/dataset_setup.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 6587f1439..7a83a03f6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -770,18 +770,10 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have val_tokens in validation set. - val_tokens = 10_000_000 # TODO: decide this value. - tokens_accumulated, num_examples_for_val = 0, 0 - for example in tokenized_dataset: - tokens_accumulated += len(example['input_ids']) - num_examples_for_val += 1 - if tokens_accumulated >= val_tokens: - break # Split in train and valid. 
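The hunk below swaps the manual token-counting split for datasets.Dataset.train_test_split, which shuffles with the given seed and then splits. A toy sketch of that call, using a made-up in-memory dataset:

from datasets import Dataset

toy = Dataset.from_dict({'input_ids': [[i] for i in range(100)]})
splits = toy.train_test_split(test_size=0.1, seed=42)  # shuffles, then splits
train_ds, val_ds = splits['train'], splits['test']
assert len(train_ds) == 90 and len(val_ds) == 10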
- val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select( - range(num_examples_for_val, len(tokenized_dataset))) + dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset = dataset_split_dict['train'] + val_dataset = dataset_split_dict['test'] # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. From 9737367473f35b206333edc46f9c193ec8dda821 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:45:32 +0000 Subject: [PATCH 32/63] modifications to fwedu datasetup --- dataset/dataset_setup.py | 164 +++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 91 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 7a83a03f6..584189c4a 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -191,6 +191,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -707,106 +708,87 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir=None): +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): """Download FineWebEdu-10B.""" - data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') - - _maybe_mkdir(data_dir) - _maybe_mkdir(tmp_dir) - _maybe_mkdir(cache_dir) - - os.environ["TMPDIR"] = tmp_dir - - ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading - # and allow re-chunking with different seq_len? - - # Shuffle so that multiproc has shards of similar size. 
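Both the old and the new tokenize helpers in this hunk append GPT-2's end-of-text token to every document before encoding, so document boundaries survive later concatenation. A minimal sketch of that behaviour, for reference:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
text = 'hello world' + tok.eos_token      # '<|endoftext|>' appended
ids = tok(text)['input_ids']
assert ids[-1] == tok.eos_token_id        # 50256 for GPT-2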
- ds = ds.shuffle(seed=1996) - - seq_len = 2048 - max_seq_length = seq_len + 1 - map_setup = dict(batched=True, batch_size=1024, num_proc=8) - - # Tokenize - lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") - - def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq - add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) - - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info(f"Tokenizing...") - tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - **map_setup) - lm_tokenizer.model_max_length = seq_len - - tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + if not skip_download: + data_dir = os.path.join(data_dir, 'finewebedu') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ],) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] - # Concat in chunks of max_seq_len. - # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
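For intuition about the NOTE above: each mapped batch drops at most max_seq_length - 1 tokens, the remainder left over after filling whole chunks. A small worked example with illustrative numbers:

# Assuming max_seq_length = 2049:
total_tokens = 10_000                    # tokens in one mapped batch after concatenation
kept = (total_tokens // 2049) * 2049     # 8_196 tokens -> 4 full chunks
dropped = total_tokens - kept            # 1_804 tokens discarded for this batch
assert dropped < 2049                    # loss is bounded by one chunk per mapped batch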
- def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = { - k: list(itertools.chain(*examples[k])) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - result = { - k: [ - t[i:i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] for k, - t in concatenated_examples.items() - } - return result - - # Concat text in validation and train sets. - logging.info(f"Concatenating and chunking...") - val_dataset = val_dataset.map(concat_chunck, **map_setup) - train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info( - f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info( - f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" - ) + # Convert to tensorflow_datasets.Dataset objects + train_dataset = train_dataset.to_tf_dataset() + val_dataset = train_dataset.to_tf_dataset() # Save datasets - train_dataset.save_to_disk(os.path.join(data_dir, f"train")) - val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + train_dataset.Save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return def main(_): @@ -893,7 +875,7 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir) + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) # pylint: enable=logging-format-interpolation From 1bf0750e094a695176e8e3bc45ffd979abe9e237 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:46:26 +0000 Subject: [PATCH 33/63] rename fwedu data dir --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 584189c4a..ae27aab18 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -715,7 +715,7 @@ def download_finewebedu(data_dir, """Download FineWebEdu-10B.""" if not skip_download: - data_dir = os.path.join(data_dir, 'finewebedu') + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser( From a33339117b4c79d5fa946f4f7ed029087ab5a630 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 20:46:21 +0000 Subject: [PATCH 34/63] fix --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index ae27aab18..289a1faa6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -734,7 +734,7 @@ def download_finewebedu(data_dir, cache_dir=cache_dir) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: - ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) if not skip_tokenization: # Tokenize From 05dc4dd7102670cebb8ac3a8875b34387d57b9b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 21:22:57 +0000 Subject: [PATCH 35/63] add back batch mapping in tokenization for fwedu --- dataset/dataset_setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 289a1faa6..f50274615 100644 --- 
a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -769,7 +769,10 @@ def add_eos_batched(seqs): 'token_count', 'score', 'int_score' - ],) + ], + batched=True, + batch_size=1024, + num_proc=8) tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: From b374cf8db62e99e1594dea90b46a7f69a5bb04c6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:12:24 +0000 Subject: [PATCH 36/63] debugging --- dataset/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index f50274615..2c46f4ebc 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -779,9 +779,11 @@ def add_eos_batched(seqs): tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. + print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] + print(type(train_dataset)) # Convert to tensorflow_datasets.Dataset objects train_dataset = train_dataset.to_tf_dataset() From c0c1e3c32c46d65cb7511891b32429aeeb05f90c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:13:48 +0000 Subject: [PATCH 37/63] debugging --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 2c46f4ebc..c18e72ea4 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -776,7 +776,7 @@ def add_eos_batched(seqs): tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: - tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. print(type(tokenized_dataset)) From f76dc392fa83a1da25194d401aa03a9dd6dc9c6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:23:24 +0000 Subject: [PATCH 38/63] debugging --- dataset/dataset_setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index c18e72ea4..414b78609 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,6 +778,7 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset.to_tf_dataset() # Split in train and valid. print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) From e805fa7997daae83deea4e5336801af195270c1a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:45:07 +0000 Subject: [PATCH 39/63] use tfds to shuffle and split dataset --- dataset/dataset_setup.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 414b78609..747d06d27 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,20 +778,18 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) - tokenized_dataset.to_tf_dataset() - # Split in train and valid. 
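The new code in this hunk switches to a tf.data-side split: shuffle the full dataset once with a fixed seed, then carve off the first 90% with take and keep the rest with skip. A toy sketch of that pattern, pinning reshuffle_each_iteration so both derived pipelines observe the same order:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.shuffle(10, seed=0, reshuffle_each_iteration=False)
train = ds.take(9)    # first 90% of the shuffled order
val = ds.skip(9)      # remaining 10%
assert train.cardinality().numpy() == 9
assert val.cardinality().numpy() == 1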
- print(type(tokenized_dataset)) - dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) - train_dataset = dataset_split_dict['train'] - val_dataset = dataset_split_dict['test'] - print(type(train_dataset)) - # Convert to tensorflow_datasets.Dataset objects - train_dataset = train_dataset.to_tf_dataset() - val_dataset = train_dataset.to_tf_dataset() + tokenized_dataset = tokenized_dataset.to_tf_dataset() - # Save datasets - train_dataset.Save(os.path.join(data_dir, "train")) + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, "train")) val_dataset.save(os.path.join(data_dir, "val")) return From c9e9abcdf0cc9c817c1683f7a40d94a9372752f3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:40:29 +0000 Subject: [PATCH 40/63] add command for fineweb-edu --- dataset/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dataset/README.md b/dataset/README.md index 1aeb83239..50ca11985 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--fineweb_edu +``` \ No newline at end of file From e4323deca83a86ad1d703f056157dfcb0e0b1650 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:42:16 +0000 Subject: [PATCH 41/63] fix --- dataset/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/README.md b/dataset/README.md index 50ca11985..1bfd9bf73 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -458,7 +458,7 @@ python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_v From `algorithmic-efficiency` run: ```bash -python3 python3 datasets/dataset_setup.py \ +python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --fineweb_edu From f0c6e75ad70cb2c4242014c1522abb3b3bf9aa2e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 3 Oct 2025 06:23:26 +0000 Subject: [PATCH 42/63] update calls to sharing utils --- algoperf/workloads/lm/lm_jax/workload.py | 4 ++-- algoperf/workloads/lm/workload.py | 2 +- .../baselines/external_tuning/jax_nadamw_full_budget.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index e73a5bfaf..81dde95fc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -7,7 +7,7 @@ import optax from flax import jax_utils from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel @@ -79,7 +79,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params 
= sharding_utils.shard_replicated(params) + params = jax_sharding_utils.replicate(params) model_state = None return params, model_state diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 6b71c7952..2a9777354 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -92,7 +92,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 100000 + return 7000 @property def pre_ln(self) -> bool: diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..6e40cdab1 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,6 +394,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'lm': + return 128 elif workload_name == 'mnist': return 16 else: From f4ffbe709f6a867ea95ae55f4b47032caee98c4a Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:09:11 +0000 Subject: [PATCH 43/63] Fix torch sharding issue, update input pipeline and workload classes to use int32 for tensor types and add dropout rate parameter --- algoperf/workloads/lm/input_pipeline.py | 4 +- algoperf/workloads/lm/lm_jax/workload.py | 5 ++- algoperf/workloads/lm/lm_pytorch/workload.py | 37 ++++++++++--------- .../lm/tests/test_build_input_queue_torch.py | 15 +++++--- algoperf/workloads/lm/workload.py | 3 +- 5 files changed, 37 insertions(+), 27 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index db345700e..c010b32af 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -119,8 +119,8 @@ def tf_generator(): ds = tf.data.Dataset.from_generator( tf_generator, output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), }) # Avoid creating too many threads when using PyTorch DDP. 
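The dtype change above is safe: GPT-2 token IDs are bounded by the vocabulary size (50,257), far below the int32 limit, so narrowing from int64 halves the batch's memory traffic with no risk of overflow. A one-line check:

import numpy as np

assert 50_257 < np.iinfo(np.int32).max   # 2_147_483_647, ample headroom for token IDs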
diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 81dde95fc..1f6b3c2b2 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -90,8 +90,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm, model_state + update_batch_norm: bool, + dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 36e441e7e..e5dafdd3c 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -6,7 +6,8 @@ import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP - +from itertools import islice +from algoperf import data_utils from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec @@ -84,19 +85,22 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_hf_dataloader - - loader = get_hf_dataloader( - cache_dir=data_dir, + from algoperf.workloads.lm.input_pipeline import get_lm_dataset + local_batch_size = global_batch_size // N_GPUS + + loader = get_lm_dataset( data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="torch", - split=split) + split=split, + data_dir=data_dir, + global_batch_size=local_batch_size, + num_batches=num_batches + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) seq_len = self._seq_len weights = None - dtype = torch.long + dtype = torch.int32 is_train = split == 'train' for batch in loader: @@ -109,17 +113,16 @@ def _build_input_queue( per_device_batch_size = torch.tensor( targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - + local_batch_size = per_device_batch_size.item() # Broadcast to all devices - dist.broadcast(inputs, src=0) - dist.broadcast(targets, src=0) + #dist.broadcast(inputs, src=0) + #dist.broadcast(targets, src=0) if weights is None: - batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() - weights = torch.ones((batch_size, seq_len), device=DEVICE) + weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': inputs, - 'targets': targets, + 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), + 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 639e71491..827272037 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -17,9 +17,9 @@ def sync_ddp(): def test_dataloader_torch(): # Test config. 
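The workload change above shards the stream across DDP ranks with itertools.islice: each rank keeps every N_GPUS-th batch, starting at its own rank index. A toy illustration of that round-robin slicing:

from itertools import islice

batches = iter(range(8))                     # stand-in for the batch iterator
rank, n_gpus = 1, 4
mine = list(islice(batches, rank, None, n_gpus))
assert mine == [1, 5]                        # rank 1 of 4 sees every 4th batch, starting at 1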
rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' + data_dir = '/home/ak4605/data/finewebedu/' split = 'train' - global_batch_size = 8 + global_batch_size = 64 dtype = torch.int32 seq_len = 2048 @@ -44,35 +44,40 @@ def test_dataloader_torch(): # print(f"inputs: {inputs}") # Start test. - for _ in range(100): + for _ in range(1): batch = next(input_queue) + print(f"RANK {RANK} got batch") assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch inputs, targets = batch['inputs'], batch['targets'] - + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") assert type(inputs) == torch.Tensor assert type(targets) == torch.Tensor assert inputs.device == DEVICE assert targets.device == DEVICE - assert inputs.dtype == dtype assert targets.dtype == dtype + print(local_batch_size, seq_len) assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) assert torch.equal(inputs[:, 1:], targets[:, :-1]) + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") print(f"=== ALL TEST PASSED ===") def main(): profiler = PassThroughProfiler() + print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) pytorch_init(USE_PYTORCH_DDP, RANK, profiler) test_dataloader_torch() diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 2a9777354..986a98297 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -132,7 +132,8 @@ def _eval_batch(self, model_state, spec.ForwardPassMode.EVAL, rng, - update_batch_norm=False) + update_batch_norm=False, + dropout_rate=None) loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] From 5c85c7e278ffa540d65b1d49f0bd1d0cad732052 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:39:35 +0000 Subject: [PATCH 44/63] test working, lm workload training not working (debugging) --- algoperf/workloads/lm/lm_jax/workload.py | 3 +- .../lm/tests/test_build_input_queue_jax.py | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 1f6b3c2b2..5401ad240 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -33,9 +33,10 @@ def _build_input_queue(self, split=split, data_dir=data_dir, global_batch_size=global_batch_size) + loader = map(jax_sharding_utils.shard_along_batch_dim, loader) return loader - def _build_input_queue(self, + def _build_hf_input_queue(self, data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py new file mode 100644 index 000000000..b9adc70d2 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp + +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads.lm.lm_jax.workload import LmWorkload +import os + +RANK = os.environ.get('RANK', 0) + +def test_dataloader_jax(): + # Test config. 
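The JAX queue above maps each host batch through jax_sharding_utils.shard_along_batch_dim before yielding it. That helper is not reproduced in this patch; purely as an illustration of the underlying idea (and not of algoperf's actual implementation), splitting an array's leading axis across local devices looks roughly like:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
x = jnp.zeros((jax.device_count() * 2, 16), dtype=jnp.int32)
x = jax.device_put(x, sharding)   # axis 0 is split across all available devices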
+ rng_seed = 1996 + data_dir = '/home/ak4605/data/finewebedu/' + split = 'train' + global_batch_size = 64 + dtype = jnp.int32 + seq_len = 2048 + + workload = LmWorkload() + data_rng = jax.random.PRNGKey(rng_seed) + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + for _ in range(1): + + batch = next(input_queue) + print(f"RANK {RANK} got batch") + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") + + jax.debug.inspect_array_sharding(inputs, callback=print) + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (global_batch_size, seq_len) + assert targets.shape == (global_batch_size, seq_len) + + assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + test_dataloader_jax() + + +if __name__ == '__main__': + main() From a59dfda3a7ce87b5cad550f2332aaf049f59c8f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 18:33:29 +0000 Subject: [PATCH 45/63] updates to input_pipeline and model spec --- algoperf/workloads/lm/input_pipeline.py | 257 +++++++++---------- algoperf/workloads/lm/lm_jax/nanodo_model.py | 2 +- algoperf/workloads/lm/lm_jax/workload.py | 36 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 5 +- algoperf/workloads/lm/workload.py | 98 +++---- 5 files changed, 187 insertions(+), 211 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index c010b32af..e674170e4 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,154 +1,129 @@ """Input pipeline for a LM dataset.""" + import functools import os from typing import Optional import jax -import jax.numpy as jnp import tensorflow as tf -import torch -import torch.nn.functional as F -from transformers import GPT2Tokenizer from algoperf import data_utils -from algoperf.pytorch_utils import pytorch_setup -from datasets import load_dataset -from datasets import load_from_disk - -RANK = pytorch_setup()[1] -# Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's -# automatic optimization (AUTOTUNE), while other processes disable it (None). -# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine -# the optimal number of elements to prefetch or parallelize for dataset -# operations, improving performance. -AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None - - -def get_hf_dataloader(cache_dir: str, - data_rng: jax.random.PRNGKey, - batch_size: int = 8, - seq_len: int = 32, - framework: str = "torch", - split="train"): + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = -1 + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 2048 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1_000_000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. 
+ + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: """ - Create a data loader from HuggingFace's FineWeb dataset. - - Args: - cache_dir: Directory to cache the dataset - batch_size: Number of sequences per batch - seq_len: Length of each sequence - framework: Either "torch" or "jax" to specify output tensor type - split: Dataset split to load - """ - # Initialize tokenizer and get vocab size - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - vocab_size = tokenizer.vocab_size - # Load the FineWeb dataset in streaming mode - fw = load_dataset( - "HuggingFaceFW/fineweb-edu", - name="sample-10BT", - split=split, - streaming=True, - cache_dir=cache_dir) - fw = fw.batch(batch_size=batch_size, drop_last_batch=True) - if split in ['train', 'eval_train']: - fw = fw.shuffle(seed=int(data_rng[-1])) - - def _tokenize(x): - """Tokenize and pad text to seq_len+1 tokens.""" - if framework == "torch": - tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) - elif framework == "jax": - tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = jnp.pad( - tokens, - pad_length, - mode="constant", - constant_values=tokenizer.pad_token_id) - return tokens[:seq_len + 1] - - def batch_iterator(): - for doc in fw: - if framework == "torch": - token_ids = torch.stack([_tokenize(x) for x in doc['text']]) - # Take first seq_len+1 tokens and convert to one-hot - tokens = F.one_hot(token_ids, num_classes=vocab_size).float() - # Split into input/target - inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] - inputs, targets = inputs.to("cuda"), targets.to("cuda") - elif framework == "jax": - token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) - tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) - inputs, targets = tokens[:, :-1], tokens[:, 1:] - inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield {'inputs': inputs, 'targets': targets} - - return batch_iterator() - - -def get_lm_dataset(data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None): + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. 
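+  # (padded_batch(1, ...) above treats each original batch as a single
+  # element, so a final short batch is right-padded with `padding_id` up to
+  # `padded_shapes`; unbatch() then drops the singleton leading dimension
+  # introduced by that batch_size=1 call.)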
+ padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter(data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None,): + + ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) + + return iter(it) + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, +): """Load HF dataset and return a TF dataset.""" - - dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) - - is_training = split == "train" - shuffle = split in ['train', 'eval_train'] - - dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? - - def tf_generator(): - """Generates data in a TensorFlow-friendly format.""" - for example in dataset: - yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], - } - - # Create a TensorFlow dataset - ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - }) - - # Avoid creating too many threads when using PyTorch DDP. - # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: - options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) - - if shuffle: - ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) - - if is_training: - ds = ds.repeat() - - # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
- ds = ds.batch(global_batch_size, drop_remainder=is_training) - ds = ds.prefetch(AUTOTUNE) - - # Limit the dataset to a fixed number of batches if `num_batches` is specified - if num_batches: - ds = ds.take(num_batches) - - # Shard the dataset across multiple GPUs/TPUs if necessary - ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + + # batch + if split == 'train': + shuffled_sequences_ds = sequences_ds.shuffle( + SHUFFLE_BUFFER_SIZE, seed=shuffle_seed + ) + repeated_sequences_dataset = shuffled_sequences_ds.repeat() + ds = repeated_sequences_dataset.batch( + global_batch_size, drop_remainder=False + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index d21fd5090..ed469e1bd 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -3,9 +3,9 @@ import dataclasses from functools import partial -from flax import linen as nn import jax import jax.numpy as jnp +from flax import linen as nn # =========== Transformer Decoder-only Model ========== diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 5401ad240..49547fcef 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -4,16 +4,14 @@ import jax import jax.numpy as jnp -import optax -from flax import jax_utils -from algoperf import param_utils -from algoperf import jax_sharding_utils -from algoperf import spec -from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_jax.nanodo_model import ( - TransformerDo, DoConfig, init_rope, apply_rope) + DoConfig, + TransformerDo, +) +from algoperf.workloads.lm.workload import BaseLmWorkload class LmWorkload(BaseLmWorkload): @@ -28,7 +26,7 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_lm_dataset( + loader = 
get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -46,14 +44,8 @@ def _build_hf_input_queue(self, """Build an input queue using HuggingFace FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_hf_dataloader( - cache_dir=data_dir, - data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="jax", - split=split) - return loader + iter = get_data_iter(data_rng, split, data_dir, global_batch_size) + return iter def init_model_fn( self, @@ -63,10 +55,10 @@ def init_model_fn( # Initialize NanoDO transformer model cfg = DoConfig( - D=512, # model dim - H=8, # num heads + D=2048, # model dim + H=16, # num heads L=self._seq_len, - N=6, # num layers + N=12, # num layers V=self._vocab_size, F=2048, # feedforward dim dtype=jnp.float32 @@ -92,7 +84,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e5dafdd3c..5797de654 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -63,9 +63,10 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm + del model_state, rng, update_batch_norm, dropout_rate model = params # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 986a98297..8f17553ff 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,21 +1,20 @@ """LM workload parent class.""" import abc +from absl import logging import math import os from typing import Dict, Optional -from absl import flags import jax import torch.distributed as dist +from absl import flags from algoperf import spec -from algoperf.workloads.lm import input_pipeline -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): @@ -63,7 +62,7 @@ def num_eval_train_examples(self) -> int: @property def num_validation_examples(self) -> int: - return 50000 + return 50000 @property def num_test_examples(self) -> int: @@ -111,53 +110,60 @@ def glu(self) -> bool: return True @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): """Build an input queue for the given split.""" - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: + def _eval_batch( + self, + params: 
spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: """Evaluate the model on a single batch.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - dropout_rate=None) - + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) loss = 0.0 for _ in range(num_batches): @@ -168,13 +174,15 @@ def _eval_model_on_split(self, mean_loss = loss.item() / num_examples return {'loss': mean_loss} + # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @abc.abstractmethod def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling.""" From 1c3cb6649b26c87e4bd7afd9c83fac84af9372ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 22:15:38 +0000 Subject: [PATCH 46/63] add defaults for lm workload --- algoperf/workloads/lm/lm_jax/workload.py | 10 +++++----- algoperf/workloads/lm/workload.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 49547fcef..76739b590 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -54,13 +54,13 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig( - D=2048, # model dim - H=16, # num heads + cfg = DoConfig(u + D=self._emb_dim, # embedding dim + H=self._n_heads, # num heads L=self._seq_len, - N=12, # num layers + N=self._n_layers, # num layers V=self._vocab_size, - F=2048, # feedforward dim + F=self._mlp_dim, # feedforward dim dtype=jnp.float32 ) self._model = TransformerDo(cfg) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 8f17553ff..5cc783dba 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,7 +21,11 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 5 + _seq_len: int = 2048 + _emb_dim: 
int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 warmup_factor: float = 0.1 def __init__(self) -> None: From af91b120b2d5bd055f486aabdb3a881e28f3d231 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 01:03:33 +0000 Subject: [PATCH 47/63] refactor eval pipeline and loss fn for lm --- algoperf/workloads/lm/input_pipeline.py | 8 +- algoperf/workloads/lm/lm_jax/workload.py | 92 +++++++++++-------- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++-- algoperf/workloads/lm/workload.py | 52 +++++++---- .../external_tuning/jax_nadamw_full_budget.py | 4 +- submission_runner.py | 2 +- 6 files changed, 116 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e674170e4..91d6ae53c 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -10,13 +10,13 @@ from algoperf import data_utils AUTOTUNE = tf.data.experimental.AUTOTUNE -PAD_ID = -1 +PAD_ID = tf.constant(-1, dtype=tf.int64) TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} -SEQUENCE_LENGTH = 2048 +SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1_000_000 +SHUFFLE_BUFFER_SIZE = 1024 VOCAB_SIZE = 50_257 @@ -74,7 +74,7 @@ def get_lm_dataset( global_batch_size: int, num_batches: Optional[int] = None, ): - """Load HF dataset and return a TF dataset.""" + """Load preprocessed TF dataset.""" if split not in TFDS_SPLIT_NAME: raise NotImplementedError diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 76739b590..c3d84104b 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,9 +1,11 @@ """LM workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp +import optax +from flax.training import common_utils from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter @@ -54,7 +56,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig(u + cfg = DoConfig( D=self._emb_dim, # embedding dim H=self._n_heads, # num heads L=self._seq_len, @@ -84,7 +86,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed @@ -93,41 +95,58 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in JAX.""" - # Convert one-hot labels to token IDs if needed - if len(label_batch.shape) == len(logits_batch.shape): # one-hot - label_batch = jnp.argmax(label_batch, axis=-1) - - # Reshape for sequence modeling - logits = logits_batch.reshape(-1, logits_batch.shape[-1]) - labels = label_batch.reshape(-1) - - # Compute cross-entropy loss - loss = -jnp.sum( - jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) - - if 
mask_batch is not None: - mask = mask_batch.reshape(-1) - loss = loss * mask - n_valid = mask.sum() - else: - n_valid = labels.shape[0] + + def compute_weighted_cross_entropy( + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: array of shape [batch, length]. + label_smoothing: label smoothing constant, used to determine the on and off + values. + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + """ + if logits.ndim != targets.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) + smoothed_targets = optax.smooth_labels( + common_utils.onehot(targets, self._vocab_size), label_smoothing + ) + per_example_losses = -jnp.sum( + smoothed_targets * jax.nn.log_softmax(logits), axis=-1 + ) + if weights is None: + weights = jnp.ones_like(targets) + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + summed_loss = per_example_losses.sum() + n_valid_examples = weights.sum() return { - 'summed': loss, - 'n_valid_examples': n_valid, - 'per_example': loss / n_valid # Return per-token loss + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def is_output_params(self, param_name: str) -> bool: - """Return whether the given parameter is an output parameter.""" - return param_name.contains('output') + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) + def _eval_batch(self, params: spec.ParameterContainer, @@ -140,5 +159,6 @@ def _eval_batch(self, targets = batch['targets'] # Calculate cross-entropy loss - loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) - return loss + # TODO(kasimbeg): add weights? 
+ loss_metrics = self.compute_weighted_cross_entropy(logits, targets) + return loss_metrics diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 5797de654..ddf99204d 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,18 +1,19 @@ """LM workload implemented in PyTorch.""" -from typing import Dict, Iterator, Optional, Tuple +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple import jax import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from itertools import islice -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec + +from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + ModelConfig, + Transformer, +) from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -153,6 +154,7 @@ def _eval_batch(self, reduction='sum' ) return loss + def loss_fn( self, label_batch: spec.Tensor, @@ -181,3 +183,15 @@ def loss_fn( 'n_valid_examples': n_valid, 'per_example': loss } + +def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 5cc783dba..b1fa3d2a8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,13 +1,11 @@ """LM workload parent class.""" import abc -from absl import logging import math import os -from typing import Dict, Optional +from typing import Any, Dict, Optional import jax -import torch.distributed as dist from absl import flags from algoperf import spec @@ -21,7 +19,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 2048 + _seq_len: int = 1024 _emb_dim: int = 1024 _n_heads: int = 8 _n_layers: int = 12 @@ -169,24 +167,38 @@ def _eval_model_on_split( repeat_final_dataset=True, ) - loss = 0.0 + eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch, model_state, rng) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} + metrics = self._eval_batch(params, eval_batch) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + return eval_results - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. 
@abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling.""" + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + return self.compute_weighted_cross_entropy( + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing + ) + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') \ No newline at end of file diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 6e40cdab1..9b4192de2 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -11,7 +11,7 @@ Tuple, Union, ) - +from absl import logging # isort: on import chex import jax @@ -395,7 +395,7 @@ def get_batch_size(workload_name): elif workload_name == 'wmt': return 128 elif workload_name == 'lm': - return 128 + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..64a67e781 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. # disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' # TODO(znado): make a nicer registry of workloads that lookup in. 
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 6b55adf5a65184d09d62a734db8fd3b6c33fdce2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:41:09 +0000 Subject: [PATCH 48/63] refactor evaluation pipeline for lm --- algoperf/workloads/lm/input_pipeline.py | 15 ++++++++++--- algoperf/workloads/lm/lm_jax/workload.py | 28 +++++++----------------- algoperf/workloads/lm/workload.py | 26 ++++++++++++---------- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 91d6ae53c..3a2e46923 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from typing import Optional import jax +import numpy as np import tensorflow as tf from algoperf import data_utils @@ -106,7 +107,7 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + ).prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, @@ -115,7 +116,11 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -124,6 +129,10 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c3d84104b..bb19d6c30 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -28,26 +28,13 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_data_iter( + ds = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - loader = map(jax_sharding_utils.shard_along_batch_dim, loader) - return loader - - def _build_hf_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue using HuggingFace FineWeb dataset.""" - del num_batches - del repeat_final_dataset - iter = get_data_iter(data_rng, split, data_dir, global_batch_size) - return iter + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds def init_model_fn( self, @@ -156,9 +143,10 @@ def _eval_batch(self, """Evaluate the model on a single batch.""" logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss # TODO(kasimbeg): add weights? 
- loss_metrics = self.compute_weighted_cross_entropy(logits, targets) - return loss_metrics + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b1fa3d2a8..b8e1ea144 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,6 +2,7 @@ import abc import math +import numpy as np import os from typing import Any, Dict, Optional @@ -44,11 +45,11 @@ def validation_target_value(self) -> float: return 20.0 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return eval_result['test/ppl'] <= self.test_target_value + return True # No test targets @property def test_target_value(self) -> float: - return 20.0 # Target perplexity + return None # No test targets @property def loss_type(self) -> spec.LossType: @@ -60,19 +61,19 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 10000 # Subset for evaluation + return 500 # Subset for evaluation. # TODO(kasimbeg): update @property def num_validation_examples(self) -> int: - return 50000 + return 500 # TODO(kasimbeg update) @property def num_test_examples(self) -> int: - return 50000 + return 0 @property def eval_batch_size(self) -> int: - return 8 + return 32 @property def train_mean(self): @@ -84,7 +85,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 4 # 4 hours + return 3600 * 5 # 4 hours @property def eval_period_time_sec(self) -> int: @@ -93,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 7000 + return 54000 @property def pre_ln(self) -> bool: @@ -141,7 +142,7 @@ def _eval_batch( ) loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict['summed'] + return loss_dict def _eval_model_on_split( self, @@ -170,12 +171,15 @@ def _eval_model_on_split( eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - metrics = self._eval_batch(params, eval_batch) + metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']) + print(eval_results) return eval_results From 210d671fe7e78502cf321a52c0dfcafe6fa3580c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:43:42 +0000 Subject: [PATCH 49/63] remove temporary flag for hlo dumps --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 64a67e781..1c51ec58f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. 
# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 0ad7788302fdc8c5ea22379a0f15c047f75988af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:45:45 +0000 Subject: [PATCH 50/63] fix in workload target condition check --- algoperf/workloads/lm/workload.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b8e1ea144..374b91ce6 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,7 +38,7 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value + return eval_result['validation/ppl'] < self.validation_target_value @property def validation_target_value(self) -> float: @@ -178,9 +178,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) - print(eval_results) - + eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results @abc.abstractmethod From 01921d5f6d0068e1d92808ad224b50ab19b60b15 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 8 Oct 2025 23:36:28 +0000 Subject: [PATCH 51/63] fix in mlp for glu --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index ed469e1bd..bd7213620 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -44,6 +44,10 @@ def __call__(self, x_BxLxD: jax.Array): linear = partial( nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.F * 2 / 3 hidden_dim = cfg.multiple_of * ( (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of ) From e42045083c1d28aba5fa5dd15f6993d4a8312880 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:14:40 +0000 Subject: [PATCH 52/63] Fix OOM error in weighted cross entropy calculation --- algoperf/workloads/lm/lm_jax/workload.py | 44 +++++++++++-------- .../workloads/lm/lm_pytorch/plainlm_model.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index bb19d6c30..c052794c8 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -84,21 +84,19 @@ def model_fn( def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1, - ) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy 
and entropy for log probs and targets. - Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. - Returns: {'summed': scalar summed loss, 'n_valid_examples': scalar number of valid examples in batch, 'per_example': 1-d array of per-example losses} @@ -108,18 +106,26 @@ def compute_weighted_cross_entropy( f'Incorrect shapes. Got shape {logits.shape} logits and ' f'{targets.shape} targets.' ) - smoothed_targets = optax.smooth_labels( - common_utils.onehot(targets, self._vocab_size), label_smoothing - ) - - per_example_losses = -jnp.sum( - smoothed_targets * jax.nn.log_softmax(logits), axis=-1 - ) - if weights is None: - weights = jnp.ones_like(targets) - per_example_losses = jnp.where(weights, per_example_losses, 0.0) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, + targets[..., None], + axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)) + if weights is not None: + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + n_valid_examples = weights.sum() + else: + n_valid_examples = targets.shape[0] * targets.shape[1] summed_loss = per_example_losses.sum() - n_valid_examples = weights.sum() return { 'summed': summed_loss, 'n_valid_examples': n_valid_examples, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 627a0e16d..225b98767 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -16,7 +16,7 @@ class ModelConfig: n_layers: int n_heads: int rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = False + tie_embeddings: bool = True class MLP(nn.Module): From 3b31ad521d0037f80391de31582517cc291877be Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:15:27 +0000 Subject: [PATCH 53/63] fix issue with checkpointing bool --- algoperf/checkpoint_utils.py | 47 ++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..00f05ba5d 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,7 +5,7 @@ """ import os -from typing import Sequence, Tuple +from typing import Sequence, Tuple, Optional import numpy as np import torch @@ -14,7 +14,8 @@ from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint from tensorflow.io import gfile # pytype: disable=import-error - +import orbax.checkpoint as ocp +from orbax.checkpoint.type_handlers import NumpyHandler from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -29,6 +30,48 @@ int, ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. 
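+  It is registered for np.bool_ leaves via
+  ocp.type_handlers.register_type_handler(..., override=True) at the bottom of
+  this module, replacing the default handler.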
+ """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. + """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) def maybe_restore_checkpoint( framework: str, From bbc114fe730e351d3a721d78f6165f343e4c25cb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:33:15 +0000 Subject: [PATCH 54/63] increase buffer size --- algoperf/workloads/lm/input_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 3a2e46923..2fd27113a 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1024 +SHUFFLE_BUFFER_SIZE = 100_000 VOCAB_SIZE = 50_257 From 2b162e8d87603ad7ae2ac5020a26fd8c2bce974d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:42:19 +0000 Subject: [PATCH 55/63] remove _eval_batch from jax workload --- algoperf/workloads/lm/lm_jax/workload.py | 17 ----------- algoperf/workloads/lm/workload.py | 36 +++++++++++------------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c052794c8..801b1e0b4 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -139,20 +139,3 @@ def _normalize_eval_metrics( del num_examples eval_denominator = total_metrics.pop('denominator') return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) - - - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - # TODO(kasimbeg): add weights? 
- metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], - } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 374b91ce6..f5d2cda38 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -124,25 +124,6 @@ def _build_input_queue( ): """Build an input queue for the given split.""" - def _eval_batch( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - ) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - ) - - loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict def _eval_model_on_split( self, @@ -181,6 +162,23 @@ def _eval_model_on_split( eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] From 617e1a3f3810bb73f15d998c25e54fa79ef04315 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:45:44 +0000 Subject: [PATCH 56/63] add todo for pytorch _eval_batch cleanup --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index ddf99204d..71a8afd93 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -148,6 +148,7 @@ def _eval_batch(self, if targets.dim() == 3: # one-hot loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) else: # token IDs + # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
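+      # reduction='sum' returns the loss summed over all (batch, position)
+      # tokens rather than the mean.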
loss = torch.nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), From 64ea658c04a2d13db75ab0b8fd1204cfe43f8746 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:05 +0000 Subject: [PATCH 57/63] add target setting algorithm for fineweb edu lm workload --- .../jax_nadamw_target_setting.py | 427 ++++++++++++++++++ .../fineweb_edu_lm/tuning_search_space.json | 11 + 2 files changed, 438 insertions(+) create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..9fa6823d5 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
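+
+  A minimal usage sketch (illustrative; assumes the standard optax
+  GradientTransformation init/update protocol):
+
+    opt = nadamw(learning_rate=1e-3, weight_decay=1e-2)
+    opt_state = opt.init(params)
+    updates, opt_state = opt.update(grads, opt_state, params)
+    new_params = optax.apply_updates(params, updates)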
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + step_hint = 0.75 * step_hint + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + 
current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 1 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. 
+ if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..e6945d69a --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.0003955553491092581, + "one_minus_beta1": 0.06124602712, + "beta2": 0.9535169492059872, + "weight_decay": 0.03268700808664715, + "warmup_factor": 0.0375 + } +] \ No newline at end of file From b38ade083282348a5000220bf3ca11f79b5c9e9a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:49 +0000 Subject: [PATCH 58/63] update step hint for lm workload --- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/workload.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 2fd27113a..04bd90216 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 100_000 +SHUFFLE_BUFFER_SIZE = 1000 VOCAB_SIZE = 50_257 diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index f5d2cda38..b9610f919 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -57,7 +57,7 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - return 1000000 # Example size + return 8_749_870 # sequences of 1024 tokens each @property def num_eval_train_examples(self) -> int: @@ -94,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 54000 + return 72000 @property def pre_ln(self) -> bool: @@ -159,7 +159,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results From 65369f239a3110748890473cef415dcb087fe6c0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:36:42 +0000 Subject: [PATCH 59/63] update target --- algoperf/workloads/lm/workload.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b9610f919..0bed0b34d 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,11 +38,11 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] < self.validation_target_value + return eval_result['validation/ppl'] <= self.validation_target_value @property def validation_target_value(self) -> float: - return 20.0 # Target perplexity + return 25.5477 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 32 + return 64 @property def train_mean(self): @@ -85,16 +85,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours + return 3600 * 5 # 4 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes + return 600 # 10 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 72000 + return 72_000 @property def pre_ln(self) -> bool: From 6171b2d2fb6a0243993b10d03f0c284eb2c86801 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 23:04:56 +0000 Subject: [PATCH 60/63] update eval split sizes for lm workload and target setting point --- algoperf/workloads/lm/input_pipeline.py | 4 ++-- algoperf/workloads/lm/workload.py | 8 ++++---- .../fineweb_edu_lm/jax_nadamw_target_setting.py | 4 ++-- .../fineweb_edu_lm/tuning_search_space.json | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..79fdfbbcb 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -120,7 +120,7 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -133,6 +133,6 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 0bed0b34d..466769d96 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -61,11 +61,11 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 500 # Subset for evaluation. # TODO(kasimbeg): update + return 10_000 # Subset for evaluation. @property def num_validation_examples(self) -> int: - return 500 # TODO(kasimbeg update) + return 100_000 # sequences @property def num_test_examples(self) -> int: @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes TODO(kasimbeg): update + return 1200 # 20 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py index 9fa6823d5..1fef611ac 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -170,8 +170,8 @@ def init_optimizer_state( del rng def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. step_hint = 0.75 * step_hint + # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( init_value=0.0, @@ -343,7 +343,7 @@ def update_params( ) # Log loss, grad_norm. 
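  # Note: loss.item() and grad_norm.item() each force a blocking device-to-host
  # transfer, so metric logging is throttled to every 100 steps below.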
- if global_step % 1 == 0 and workload.metrics_logger is not None: + if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step ) diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json index e6945d69a..ce0f75623 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -1,11 +1,11 @@ [ { "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.0003955553491092581, - "one_minus_beta1": 0.06124602712, - "beta2": 0.9535169492059872, - "weight_decay": 0.03268700808664715, - "warmup_factor": 0.0375 + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 } ] \ No newline at end of file From d7a885cd7270dfbd8203f41276c3313ddbd63929 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 17 Oct 2025 04:01:11 +0000 Subject: [PATCH 61/63] Porting workload input pipeline to torch - Added `limit_tf_threads` parameter to `pytorch_init` to control TensorFlow threading based on workload type. Dataloader was going OOM otherwise. - Updated input pipeline to support "None" for weights (for memory). - Modified Transformer model's `forward` method to optionally return loss during training. Should be better to fuse the loss later. - Adjusted torch LM workload configuration for model dimensions and parameters to match jax. - Updated transformers version in `pyproject.toml`, older version seems unavailable. --- algoperf/pytorch_utils.py | 6 +- algoperf/workloads/lm/input_pipeline.py | 8 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 74 ++++++----- algoperf/workloads/lm/lm_pytorch/workload.py | 118 ++++++------------ pyproject.toml | 2 +- submission_runner.py | 3 +- 6 files changed, 90 insertions(+), 121 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..c7537a884 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -39,7 +39,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) @@ -47,10 +47,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: profiler.set_local_rank(rank) # Only log once (for local rank == 0). if rank != 0: - def logging_pass(*args): pass - logging.info = logging_pass # Initialize the process group. 
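    # With no explicit init_method, torch.distributed falls back to the env://
    # rendezvous and reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from
    # the environment (as set by torchrun).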
dist.init_process_group('nccl') diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..ee54427e1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -107,7 +107,13 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).prefetch(tf.data.experimental.AUTOTUNE) + ) + ds = ds.map(lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + }) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 225b98767..5de5bf310 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -159,7 +159,7 @@ def __init__(self, cfg): if cfg.tie_embeddings: self.tie_weights() - def forward(self, x): + def forward(self, x, targets=None): # x: (bsz, seqlen) x = self.embed_tokens(x) # (bsz, seqlen, dim) L = x.shape[1] @@ -178,7 +178,12 @@ def forward(self, x): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) - return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100) + return out, loss + return out def predict(self, x, k=1): """Generate k tokens autoregressively. @@ -190,11 +195,6 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ - # For debugging - predictions = [] - - batch_size = x.shape[0] - seq_len = x.shape[1] # Store original input original_input = x.clone() @@ -202,6 +202,7 @@ def predict(self, x, k=1): # Generate k tokens autoregressively for i in range(k): + # Get logits for the entire sequence logits = self(generated_input) @@ -212,24 +213,20 @@ def predict(self, x, k=1): # This is a common issue - the model gets stuck repeating the last token last_token_id = generated_input[:, -1] next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - - # Print top 5 tokens for debugging - if i == 0: - print("\nPyTorch detailed prediction:") - top5_values, top5_indices = torch.topk(next_token_logits[0], 5) - for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): - prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() - print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") - + # Get the most likely token next_token = torch.argmax(next_token_logits, dim=-1) - predictions.append(next_token.item()) # Append the predicted token to the sequence next_token = next_token.unsqueeze(1) # Add sequence dimension generated_input = torch.cat([generated_input, next_token], dim=1) - print(f" Full predictions step by step: {predictions}") + # For debugging, print predictions for the first item in the batch + print("\nPyTorch detailed prediction (first item in batch):") + predicted_sequence = generated_input[0, -k:].tolist() + print(f" Predicted token IDs: {predicted_sequence}") + for i, token_id in enumerate(predicted_sequence): + print(f" Step {i+1}: Predicted token {token_id}") # Return all tokens, not just the last k return original_input, generated_input[:, -k:] @@ -269,30 +266,43 @@ def count_params(self, non_embedding=True): def main(): 
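  # Illustrative usage of the optional fused-loss path added to forward() above
  # (a sketch; `target_ids` is a hypothetical (bsz, seqlen) tensor of token ids):
  #   logits, loss = model(input_ids, targets=target_ids)
  # whereas model(input_ids) alone returns only the logits.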
print("Initializing transformer model and running forward pass...") - seq_length = 512 + seq_length = 1024 # Define model configuration config = ModelConfig( - vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece seq_len=seq_length, # Maximum sequence length - dim=768, # Embedding dimension + dim=1024, # Embedding dimension expand=4.0, # MLP expansion factor n_layers=12, # Number of transformer layers - n_heads=12, # Number of attention heads + n_heads=8, # Number of attention heads rmsnorm_eps=1e-6, # RMSNorm epsilon tie_embeddings=True # Tie embedding and output weights ) - def tie_weights(self): - self.lm_head.weight = self.embed_tokens.weight + # Instantiate the model + model = Transformer(config) + print(f"Model has {model.count_params():,} parameters.") - def count_params(self, non_embedding=True): - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.embed_tokens.weight.numel() - if (not self.lm_head.weight - is self.embed_tokens.weight): # if no weight tying - n_params -= self.lm_head.weight.numel() - return n_params + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f"Running forward pass with input shape: {input_ids.shape}") + logits = model(input_ids) + print(f"Output logits shape: {logits.shape}") + # Run prediction + print("Running prediction...") + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f"Original input shape for prediction: {original_input.shape}") + print(f"Predicted IDs shape: {predicted_ids.shape}") + print(f"Predicted IDs: {predicted_ids}") +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 71a8afd93..e4c03c4f5 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -14,6 +14,7 @@ Transformer, ) from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.input_pipeline import get_data_iter USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -37,10 +38,11 @@ def init_model_fn( cfg = ModelConfig( vocab_size=self._vocab_size, seq_len=self._seq_len, - dim=512, # Model dimension - expand=4, # MLP expansion factor - n_layers=6, # Number of transformer layers - n_heads=8, # Number of attention heads + dim=self._emb_dim, # Model dimension + expand=self._mlp_dim // self._emb_dim, # MLP expansion factor + # FIXME(rka97): fix expansion factor + n_layers=self._n_layers, # Number of transformer layers + n_heads=self._n_heads, # Number of attention heads rmsnorm_eps=1e-6, tie_embeddings=True ) @@ -65,7 +67,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state, rng, update_batch_norm, dropout_rate model = params @@ -87,10 +89,8 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_lm_dataset local_batch_size 
= global_batch_size // N_GPUS - - loader = get_lm_dataset( + loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -99,33 +99,12 @@ def _build_input_queue( ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) - seq_len = self._seq_len - weights = None - dtype = torch.int32 - is_train = split == 'train' - for batch in loader: - inputs = batch['inputs'] - targets = batch['targets'] - - if USE_PYTORCH_DDP: - if not is_train: - # During eval, the batch size of the remainder might be different - per_device_batch_size = torch.tensor( - targets.shape[0], dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - local_batch_size = per_device_batch_size.item() - # Broadcast to all devices - #dist.broadcast(inputs, src=0) - #dist.broadcast(targets, src=0) - - if weights is None: - weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), - 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), - 'weights': weights, + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), + 'weights': None, } yield batch @@ -133,66 +112,41 @@ def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - model = params - logits, _ = self.model_fn( - model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - - # Handle both one-hot and token ID targets - targets = batch['targets'] - if targets.dim() == 3: # one-hot - loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) - else: # token IDs - # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
- loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - reduction='sum' - ) - return loss - - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + # FIXME(rka97): Implement label smoothing + def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in PyTorch.""" - vocab_size = logits_batch.shape[-1] + vocab_size = logits.size(-1) - if len(label_batch.shape) == len(logits_batch.shape): + if len(labels.shape) == len(logits.shape): # One-hot labels - log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) - loss = -torch.sum(label_batch * log_probs, dim=-1) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(labels * log_probs, dim=-1) else: # Dense labels loss = torch.nn.functional.cross_entropy( - logits_batch, - label_batch, + logits.view(-1, vocab_size), + labels.view(-1), reduction='none') - if mask_batch is not None: - loss = loss * mask_batch + loss = loss.view_as(labels) + + if weights is not None: + loss = loss * weights - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device) return { 'summed': loss.sum(), 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss, } -def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, Any] - ) -> Dict[str, float]: - """Normalize eval metrics.""" - del num_examples - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - total_metrics = {k: v.item() for k, v in total_metrics.items()} - eval_denominator = total_metrics.pop('denominator') - return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 76bcfb7ca..b93c9794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.25.4", "datasets==3.6.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..1c50cd6d9 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -784,7 +784,8 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = (base_workload != 'lm') + pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads) # TODO: remove once issue resolved. 
if FLAGS.pytorch_eval_num_workers != 0: From 1f0439aaf6bbb7f0670a4dc0564a41c86e509270 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 06:41:33 +0000 Subject: [PATCH 62/63] Fix OOM bug in lm eval --- algoperf/random_utils.py | 4 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++++++++++++----- algoperf/workloads/lm/workload.py | 15 +++++++--- .../pytorch_nadamw_full_budget.py | 2 ++ 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e4c03c4f5..b2ffac18e 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,5 +1,6 @@ """LM workload implemented in PyTorch.""" +import contextlib from itertools import islice from typing import Any, Dict, Iterator, Optional, Tuple @@ -8,7 +9,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( ModelConfig, Transformer, @@ -72,12 +73,23 @@ def model_fn( del model_state, rng, update_batch_norm, dropout_rate model = params - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) + return logits, None def _build_input_queue( @@ -90,12 +102,14 @@ def _build_input_queue( repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" local_batch_size = global_batch_size // N_GPUS + # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np + # from seeing all GPUs via torch.cuda.device_count() loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=local_batch_size, - num_batches=num_batches + num_batches=num_batches, ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) @@ -104,7 +118,7 @@ def _build_input_queue( batch = { 'inputs': torch.tensor(batch['inputs'], 
device=DEVICE, dtype=dtype), 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), - 'weights': None, + 'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 466769d96..73e784f3a 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 64 + return 256 @property def train_mean(self): @@ -138,6 +138,11 @@ def _eval_model_on_split( ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( @@ -159,7 +164,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']).item() + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results @@ -173,9 +178,11 @@ def _eval_batch(self, params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), } diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..9b544e380 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,6 +372,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From b11c1938447c3cb68a9635ffa75648ec97c3e5d2 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 20:42:14 +0000 Subject: [PATCH 63/63] repeat dataset --- algoperf/workloads/lm/input_pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7a55e81fd..ab7c64479 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -98,14 +98,12 @@ def get_lm_dataset( }, num_parallel_calls=AUTOTUNE, ) - - # batch + sequences_ds = sequences_ds.repeat() if split == 'train': - shuffled_sequences_ds = sequences_ds.shuffle( + ds = sequences_ds.shuffle( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed ) - repeated_sequences_dataset = shuffled_sequences_ds.repeat() - ds = repeated_sequences_dataset.batch( + ds = ds.batch( global_batch_size, drop_remainder=False ) ds = ds.map(lambda x: {