sim #74 (Draft)

2 changes: 2 additions & 0 deletions conf/base.yaml
@@ -5,9 +5,11 @@ defaults:
  - _self_

seed: 42
desired_num_gpus: 64

finetune:
  seed: ${..seed}
  desired_num_gpus: ${..desired_num_gpus}

actor:
  log_each_n_secs: 0
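
For reference, a minimal sketch (not part of this PR) of how the new `desired_num_gpus` key reaches the `finetune` section through the `${..desired_num_gpus}` interpolation, assuming plain OmegaConf resolution semantics:

```python
from omegaconf import OmegaConf

# Stand-in for the relevant slice of conf/base.yaml.
cfg = OmegaConf.create({
    "seed": 42,
    "desired_num_gpus": 64,
    "finetune": {"seed": "${..seed}", "desired_num_gpus": "${..desired_num_gpus}"},
})

# ${..desired_num_gpus} walks up from finetune.desired_num_gpus to the root node,
# so the top-level value is mirrored under finetune.
assert cfg.finetune.desired_num_gpus == 64
assert cfg.finetune.seed == 42
```
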
89 changes: 89 additions & 0 deletions conf/debug_finetune_preprocessor_sft.yaml
@@ -0,0 +1,89 @@
# Debug configuration for finetune+preprocessor+sft mode
# This configuration runs all three components together for testing

# @package _global_
defaults:
  - _self_
  - streams: local
  - wandb: default

# Debug mode configuration
debug:
  mode: "finetune+preprocessor+sft"
  streams_from: null
  place_inference_workers: true
  use_existing_llms: false

# Experiment configuration
output_dir: ???
seed: 42
force_restart: false

# Model configuration
model_path: /mnt/llmd/base_models/Qwen2.5-7B
max_seq_length: 2048
batch_size: 100

# Dataset configuration
dataset_loader: pipelinerl.domains.math.load_datasets
dataset_loader_params: {}
train_dataset_names:
  - open_reasoner_zero_57k
test_dataset_names:
  - aime_2024

# World configuration
world:
  replicas: 1
  actor_fraction: 1
  preprocessor_fraction: 1
  finetune_fraction: 2
  env_replicas: 0
  actor_group_port: 9000
  environment_start_port: 7777

# LLM configuration
me:
  llm_urls: "http://localhost:8000"

llm:
  parameters:
    temperature: 1.0
    top_p: 0.95
    top_k: 50

# Finetune configuration
finetune:
  input: "sft_data"  # Use SFT data as input
  model_class: "causal-language-modeling"
  train_batch_size: 1
  gradient_accumulation_passes: 1
  seq_parallel: 1
  seq_packing: false
  rl:
    kl_coef: 0.0
    value_loss_coef: 0.0

# Preprocess configuration
preprocess:
  input: "actor"
  output: "sft_data"
  dataset_buffer_size: 0
  ring_buffer_size: 1000

# Streams configuration
streams:
  backend: local
  base_path: null
  port: 6379
  save: ""

# Wandb configuration
wandb:
  use_wandb: true
  project: "debug-finetune-preprocessor-sft"
  name: null
  tags: ["debug", "finetune", "preprocessor", "sft"]

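As a rough illustration (an assumption about how such fractions are typically consumed, not code from this repository), the `world` fractions above could translate into per-role GPU counts like this:

```python
# Hypothetical: splitting an 8-GPU node according to the fractions in the config above.
fractions = {"actor": 1, "preprocessor": 1, "finetune": 2}
gpus_per_node = 8  # made-up node size

total = sum(fractions.values())
gpus = {role: gpus_per_node * share // total for role, share in fractions.items()}
print(gpus)  # {'actor': 2, 'preprocessor': 2, 'finetune': 4}
```
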
72 changes: 69 additions & 3 deletions pipelinerl/actor.py
@@ -374,6 +374,9 @@ def run(self, dataset: list[tuple[str, dict]]):
published_samples = 0
submitted_groups = 0
finished_groups = 0
cumulative_time_for_current_num_llms = 0
cumulative_time_for_desired_num_llms = 0
cumulative_time_to_deduct = 0
expected_rollouts = -1 if self.is_training else len(dataset)
if expected_rollouts > 0:
logger.info(f"Will stop after {expected_rollouts} rollouts")
@@ -410,7 +413,24 @@ def run(self, dataset: list[tuple[str, dict]]):
logger.info(
f"Max lag is {self.cfg.finetune.max_lag} samples, that makes {lag_groups} additional starting chunks"
)
can_submit_before_update = lag_groups + groups_per_update
#TODO: rm conv RL code
times_for_current_num_llms = []
time_for_desired_num_of_llms = 0
desired_number_of_llms = self.cfg.desired_num_gpus
current_number_of_llms = len(self.llms) # assumes 1 llm per gpu
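# With fewer llms than desired, shrink the per-update submission quota proportionally:
# each llm then handles the same number of groups per (simulated) weight update as it
# would in a run with desired_num_gpus llms.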
assert (groups_per_update * current_number_of_llms) % desired_number_of_llms == 0, (
f"groups_per_update * current_number_of_llms {groups_per_update * current_number_of_llms} "
f"should be divisible by desired_number_of_llms {desired_number_of_llms}"
)
groups_per_update_adjusted = groups_per_update * current_number_of_llms // desired_number_of_llms
can_submit_before_update_non_adjusted = lag_groups + groups_per_update
can_submit_before_update = lag_groups + groups_per_update_adjusted
logger.info(
f"We only have {current_number_of_llms} llms instead of {desired_number_of_llms},"
f" thus instead of {groups_per_update} groups per update,"
f" we can submit {groups_per_update_adjusted} groups per update,"
)
start_sampling_time = time.time()
else:
groups_per_update = None
can_submit_before_update = math.inf
@@ -426,11 +446,54 @@

if self.trainer_state.propagated_weight_version > last_trainer_version:
if max_lag is not None:
assert groups_per_update is not None
can_submit_before_update += groups_per_update
assert groups_per_update_adjusted is not None
can_submit_before_update += groups_per_update_adjusted
can_submit_before_update_non_adjusted += groups_per_update
# the weights have been updated, publish the stats of the previous trainer version
trainer_version_to_publish = last_trainer_version
last_trainer_version = self.trainer_state.propagated_weight_version
start_sampling_time = time.time()
times_for_current_num_llms = []
elif published_samples == can_submit_before_update and published_samples < can_submit_before_update_non_adjusted:
end_time = time.time()
time_for_current_num_of_llms = end_time - start_sampling_time
logger.info(
f"Published {published_samples} samples which is less than {can_submit_before_update_non_adjusted}, took {time_for_current_num_of_llms:.2f} seconds."
f" will now increment the number of samples that can be submitted before update to {can_submit_before_update+groups_per_update_adjusted}"
)
times_for_current_num_llms.append(time_for_current_num_of_llms)
start_sampling_time = end_time
if max_lag is not None:
can_submit_before_update += groups_per_update_adjusted
elif published_samples == can_submit_before_update_non_adjusted:
if len(times_for_current_num_llms) < desired_number_of_llms // current_number_of_llms:
end_time = time.time()
time_for_current_num_of_llms = end_time - start_sampling_time
logger.info(
f"Published {published_samples} samples which is equal to {can_submit_before_update}, took {time_for_current_num_of_llms:.2f} seconds."
f" will now increment the number of samples that can be submitted before update to {can_submit_before_update+groups_per_update_adjusted}"
)
times_for_current_num_llms.append(time_for_current_num_of_llms)
time_for_desired_num_of_llms = max(times_for_current_num_llms)
assert len(times_for_current_num_llms) == desired_number_of_llms // current_number_of_llms, (
f"Expected {desired_number_of_llms // current_number_of_llms} times for current number of llms,"
f" but got {len(times_for_current_num_llms)}"
)
time_for_current_num_of_llms = sum(times_for_current_num_llms)
time_to_deduct = time_for_current_num_of_llms - time_for_desired_num_of_llms
cumulative_time_to_deduct += time_to_deduct
cumulative_time_for_current_num_llms += time_for_current_num_of_llms
cumulative_time_for_desired_num_llms += time_for_desired_num_of_llms
wandb.log({
"actor/cumulative_time_for_current_num_llms2": cumulative_time_for_current_num_llms,
"actor/cumulative_time_for_desired_num_llms2": cumulative_time_for_desired_num_llms,
"actor/cumulative_time_to_deduct2": cumulative_time_to_deduct,
})
logger.info(
f"Time on current number of llms {time_for_current_num_of_llms},"
f" time on desired number of llms: {time_for_desired_num_of_llms:.2f} seconds"
f" time to deduct {time_to_deduct} seconds. Total time to deduct {cumulative_time_to_deduct:.2f} seconds"
)

# First, submit all problems you can until the problem queue is full
if not self.is_scheduling_paused:
@@ -499,6 +562,9 @@ def run(self, dataset: list[tuple[str, dict]]):
"trainer_model_version": trainer_version_to_publish,
"time_since_start": time.time() - loop_start_time,
"groups_in_progress": in_progress,
"cumulative_time_to_deduct": cumulative_time_to_deduct,
"cumulative_time_for_current_num_llms": cumulative_time_for_current_num_llms,
"cumulative_time_for_desired_num_llms": cumulative_time_for_desired_num_llms,
}
trainer_version_to_publish = None
else:
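
To make the new pacing logic concrete, here is a worked example with made-up numbers (nothing below comes from the PR itself beyond the formulas):

```python
# Simulating 64 desired llms with only 16 real ones.
desired_number_of_llms = 64
current_number_of_llms = 16        # len(self.llms) in the actor
groups_per_update = 128            # what a full-size run would submit per weight update

# 128 * 16 is divisible by 64, so the new assert passes.
groups_per_update_adjusted = groups_per_update * current_number_of_llms // desired_number_of_llms
assert groups_per_update_adjusted == 32

# The actor submits 64 / 16 = 4 chunks of 32 groups per weight update and times each one.
# The slowest chunk stands in for the duration of one update at full scale; the rest of
# the elapsed time becomes "time to deduct" from the simulated wall clock.
times_for_current_num_llms = [110.0, 95.0, 102.0, 98.0]  # seconds, hypothetical
time_for_desired_num_of_llms = max(times_for_current_num_llms)                # 110.0
time_for_current_num_of_llms = sum(times_for_current_num_llms)                # 405.0
time_to_deduct = time_for_current_num_of_llms - time_for_desired_num_of_llms  # 295.0
```
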
25 changes: 25 additions & 0 deletions pipelinerl/finetune_loop.py
@@ -565,6 +565,10 @@ def batch_generator_fn():
rl_config = RLConfig(**args.rl)
# samples_per_step will be used to normalize the loss
rl_config.batch_size = samples_per_step
desired_num_of_processes = args.desired_num_gpus
cumulative_time_to_deduct = 0.0
cumulative_time_for_desired_num_of_processes = 0.0
cumulative_time_for_current_num_of_processes = 0.0
while training_metrics.completed_steps < final_train_steps:
# We include time waiting for data in the step time
if first_pass:
@@ -581,6 +585,7 @@
logger.info("next batch should be a sentinel batch")

time_waiting_for_data += time.time() - before_getting_next_batch
after_getting_next_batch = time.time()
# check if too old, don't drop but count
if (
args.max_lag is not None
@@ -683,8 +688,17 @@ def toggle_sync(sync: bool):
writer.write(trigger_message)

if not do_optimizer_step:
forward_pass_took = time.time() - after_getting_next_batch
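# The (implicit) assumption here is linear data-parallel scaling: with desired_num_gpus
# processes the same micro-batches would be spread over more workers, so only the
# proportional share of the measured time is credited to the simulated run.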
forward_pass_took_for_desired_num_of_processes = (
forward_pass_took * (get_accelerator().state.num_processes / desired_num_of_processes)
)
time_to_deduct = forward_pass_took - forward_pass_took_for_desired_num_of_processes
cumulative_time_to_deduct += time_to_deduct
cumulative_time_for_desired_num_of_processes += forward_pass_took_for_desired_num_of_processes
cumulative_time_for_current_num_of_processes += forward_pass_took
continue


target_samples_per_lead += samples_per_lead_per_step
target_samples += samples_per_step

@@ -710,6 +724,14 @@
optimizer_step_and_zero_grad()
lr_scheduler.step()

forward_pass_took = time.time() - after_getting_next_batch
forward_pass_took_for_desired_num_of_processes = (
forward_pass_took * (get_accelerator().state.num_processes / desired_num_of_processes)
)
time_to_deduct = forward_pass_took - forward_pass_took_for_desired_num_of_processes
cumulative_time_to_deduct += time_to_deduct
cumulative_time_for_desired_num_of_processes += forward_pass_took_for_desired_num_of_processes
cumulative_time_for_current_num_of_processes += forward_pass_took
metrics_dict = {}
time_to_stop = training_metrics.completed_steps >= final_train_steps
time_to_log = training_metrics.completed_steps % args.log_each_n_steps == 0
@@ -739,6 +761,9 @@
"stats/queue/batches": batch_queue.qsize(),
"stats/time_waiting_for_data": training_metrics.time_waiting_for_data,
"stats/lag": training_metrics.last_broadcasted_version - lag_stats["min_version"],
"stats/cumulative_time_to_deduct": cumulative_time_to_deduct,
"stats/cumulative_time_for_desired_num_of_processes": cumulative_time_for_desired_num_of_processes,
"stats/cumulative_time_for_current_num_of_processes": cumulative_time_for_current_num_of_processes,
"throughput/tokens_perGPU_per_sec": this_worker_tokens / sum(passes_took) if passes_took else 0,
"throughput/tokens_per_step": this_worker_tokens * get_accelerator().state.num_processes,
"throughput/micro_batches_per_step": len(tokens_processed),
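
The trainer-side counters follow the same idea; a worked example with hypothetical values (only the formula is taken from the change above):

```python
# 16 real training processes standing in for 64 desired GPUs.
num_processes = 16                # get_accelerator().state.num_processes in the real code
desired_num_of_processes = 64     # args.desired_num_gpus

forward_pass_took = 2.0  # seconds, measured for one pass
forward_pass_took_for_desired_num_of_processes = forward_pass_took * (
    num_processes / desired_num_of_processes
)
time_to_deduct = forward_pass_took - forward_pass_took_for_desired_num_of_processes

# 0.5 s is credited to the simulated 64-GPU run, 1.5 s goes into cumulative_time_to_deduct.
assert forward_pass_took_for_desired_num_of_processes == 0.5
assert time_to_deduct == 1.5
```
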
5 changes: 5 additions & 0 deletions pipelinerl/preprocess.py
@@ -16,6 +16,7 @@
from queue import Empty, Full
from typing import List

import random
import datasets
import transformers
from litellm import BaseModel, Field
@@ -447,6 +448,7 @@ def run_preprocessing_loop(
current_length = 0
batch_boundary = published_samples + train_batch_size
target_samples_per_lead = samples_per_trainer[0] + samples_per_lead_per_step
cumulative_writing_took = 0.0

# Per-trainer sample tracking (similar to finetune_loop.py)
total_filtered_out = 0 # Track total filtered samples across all batches
@@ -554,6 +556,7 @@

batch_done = False
start_writing = time.time()
random.shuffle(processed_entries_queue)
while (len(processed_entries_queue) > 0 and not batch_done) or (cfg.preprocess.dataset_buffer_size and not batch_done):
logger.debug(f"[inner loop] trainer {trainer_id} has {samples_per_trainer[trainer_id]} samples, target is {target_samples_per_lead}")
if cfg.finetune.seq_packing:
@@ -622,6 +625,7 @@ def run_preprocessing_loop(
f"batch done: {batch_done}"
)
writing_took += time.time() - start_writing
cumulative_writing_took += writing_took

if (
published_samples > last_published_samples
@@ -638,6 +642,7 @@
"preprocessor/filtered_out_samples": num_filtered_out,
"preprocessor/total_filtered_out_samples": total_filtered_out,
"preprocessor/dropped_after_preprocessing": processed_entries_queue_popped_data,
"preprocessor/cumulative_writing_took": cumulative_writing_took,
}
if stats_aggregator.has_enough_data():
stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
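
Finally, a tiny sketch of the new in-place shuffle in the preprocessor (illustrative only; it assumes `processed_entries_queue` is a plain Python list, since `random.shuffle` requires a mutable sequence):

```python
import random

random.seed(0)  # deterministic just for this example
processed_entries_queue = [("group_a", i) for i in range(3)] + [("group_b", i) for i in range(3)]

# Mutates the list in place and returns None, so entries from the same rollout group
# are no longer guaranteed to be written out back-to-back.
random.shuffle(processed_entries_queue)
print(processed_entries_queue)
```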