Add design patterns / architecture overview to docs #102
You could add a section about the flow:

GRPO Repeated Sampling Flow Implementation

This document explains how this GRPO implementation generates multiple different completions for each prompt when

Overview

When

Complete Flow

1. RepeatSampler Creates the Sampling Pattern

The

return RepeatSampler(
    data_source=self.train_dataset,
    mini_repeat_count=self.num_generations,  # Each prompt index repeated num_generations times
    batch_size=self.generation_batch_size // self.num_generations,
    repeat_count=self.num_iterations * self.gradient_accumulation_steps,
    shuffle=self.shuffle_dataset,
    seed=self.args.seed,
)

Key insight:

2. Sampling Pattern Example

From the helpful comment in the code, here's what happens with
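The index pattern that step 1 relies on can be sketched without the trainer. The function below is a minimal, hypothetical stand-in (`repeat_indices` is not a name from the codebase); it assumes the sampler shuffles indices, groups them into chunks of `batch_size` unique prompts, drops an incomplete trailing chunk, replays each chunk `repeat_count` times, and repeats every index `mini_repeat_count` times in a row. The parameter names mirror the call above, but the behavior is an assumption, not the library's implementation.

```python
import random
from typing import Iterator, Optional


def repeat_indices(
    num_samples: int,
    mini_repeat_count: int,
    batch_size: int = 1,
    repeat_count: int = 1,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Iterator[int]:
    """Yield dataset indices in a repeated pattern (illustrative sketch only)."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)

    # Group indices into chunks of `batch_size` unique prompts.
    chunks = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    # Assumption: an incomplete trailing chunk is dropped.
    chunks = [chunk for chunk in chunks if len(chunk) == batch_size]

    for chunk in chunks:
        for _replay in range(repeat_count):          # replay the whole chunk
            for index in chunk:
                for _copy in range(mini_repeat_count):  # consecutive duplicates
                    yield index


# Two prompts, three generations each, no shuffling:
print(list(repeat_indices(2, mini_repeat_count=3, batch_size=2, shuffle=False)))
# [0, 0, 0, 1, 1, 1]
```

Because duplicates of an index are consecutive, a later reshape to `(num_prompts, num_generations)` groups all completions of one prompt in the same row.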
3. DataLoader Returns Same Prompts Multiple Times

When the dataloader uses the repeated indices, it fetches the same prompt multiple times:

# For num_generations=3, the batch might look like:
[
    {'prompt': "What is 2+2?", 'answer': "4"},
    {'prompt': "What is 2+2?", 'answer': "4"},  # Same prompt again
    {'prompt': "What is 2+2?", 'answer': "4"},  # Same prompt again
    {'prompt': "Solve x+1=5", 'answer': "x=4"},
    {'prompt': "Solve x+1=5", 'answer': "x=4"},  # Same prompt again
    {'prompt': "Solve x+1=5", 'answer': "x=4"}  # Same prompt again
]

4. Submission to AsyncBatchGenerator

The

request = BatchRequest(
    batch_id=batch_id,
    env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks, 'info': all_infos},
    # ... other parameters
)
self.async_generator.submit_batch(request)

5. Environment Receives Multiple Identical Prompts

The environment receives this list of prompts where the same prompt appears multiple times:

# env_inputs passed to AsyncBatchGenerator:
{
    'prompt': ["What is 2+2?", "What is 2+2?", "What is 2+2?", "Solve x+1=5", "Solve x+1=5", "Solve x+1=5"],
    'answer': ["4", "4", "4", "x=4", "x=4", "x=4"]
}

6. Environment Processes Each Prompt Individually

The key insight is in the

rollout_tasks = [
    self._run_single(semaphore, client, model, prompt, answer, task, info, sampling_args, **kwargs)
    for prompt, answer, task, info in zip(prompts, answers, tasks, infos)  # Each prompt processed separately
]

7. Independent API Calls with Stochastic Sampling

For each prompt in the list (including duplicates), the environment calls its

response = client.chat.completions.create(
    model=model,
    messages=prompt,
    **sanitized_args  # Contains temperature, top_p, etc.
)

The
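Steps 6 and 7 combine two ideas: one task per prompt entry (duplicates included), and a semaphore that caps how many requests are in flight. A minimal sketch of that pattern follows. It uses a hypothetical `fake_completion` coroutine instead of a real client so that it runs offline; `run_single` and `run_rollouts` are illustrative names, and nothing here is the environment's actual code.

```python
import asyncio
import random


async def fake_completion(prompt: str, temperature: float, rng: random.Random) -> str:
    """Hypothetical stand-in for a chat completion request (no network involved)."""
    await asyncio.sleep(0)  # yield control, as a real request would while waiting
    candidates = [f"{prompt} -> answer A", f"{prompt} -> answer B", f"{prompt} -> answer C"]
    if temperature == 0.0:
        return candidates[0]       # greedy decoding: always the same output
    return rng.choice(candidates)  # stochastic sampling: duplicates may differ


async def run_single(semaphore: asyncio.Semaphore, prompt: str,
                     temperature: float, rng: random.Random) -> str:
    async with semaphore:          # limit the number of concurrent requests
        return await fake_completion(prompt, temperature, rng)


async def run_rollouts(prompts, temperature=1.0, max_concurrent=4, seed=0):
    rng = random.Random(seed)
    semaphore = asyncio.Semaphore(max_concurrent)
    # One task per list entry: a prompt that appears three times gets three calls.
    tasks = [run_single(semaphore, p, temperature, rng) for p in prompts]
    return await asyncio.gather(*tasks)  # results keep the input order


prompts = ["What is 2+2?"] * 3 + ["Solve x+1=5"] * 3
completions = asyncio.run(run_rollouts(prompts))
for prompt, completion in zip(prompts, completions):
    print(prompt, "=>", completion)
```

Since `asyncio.gather` returns results in input order, completion `i` still belongs to prompt `i`, which is what lets the rewards be grouped by prompt afterwards. With a temperature of 0 every duplicate would return the same text, so the group would carry no learning signal.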
8. Reward Computation and Advantage Calculation

In

def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
    # Reshape rewards to group by prompt: (num_prompts, num_generations)
    mean_grouped = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped = rewards.view(-1, self.num_generations).std(dim=1)
    # Expand back to original shape for normalization
    mean_grouped = mean_grouped.repeat_interleave(self.num_generations, dim=0)
    std_grouped = std_grouped.repeat_interleave(self.num_generations, dim=0)
    # Compute advantages (rewards - baseline)
    advantages = rewards - mean_grouped
    if self.scale_rewards:
        advantages = advantages / (std_grouped + 1e-4)
    return advantages

9. Shuffling Before Training

After collecting all completions and computing advantages, the data is shuffled before being split for gradient accumulation:

# Concatenate all data for shuffling
full_batch = {
    "prompt_ids": prompt_ids,
    "prompt_mask": prompt_mask,
    "completion_ids": completion_ids,
    "completion_mask": completion_mask,
    "old_per_token_logps": None,
    "advantages": advantages,
}
# Shuffle and split for gradient accumulation
full_batch = shuffle_tensor_dict(full_batch)
self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps)

This shuffling ensures that completions from the same prompt are mixed across different gradient accumulation steps, improving training stability.

Complete Example Flow

Let's trace a concrete example with

Step 1: RepeatSampler Creates Repeated Indices

# Original dataset: ["What is 2+2?", "Solve x+1=5"]
# RepeatSampler yields: [0, 0, 0, 1, 1, 1]

Step 2: DataLoader Returns Repeated Prompts

# Batch from dataloader:
[
    {'prompt': "What is 2+2?", 'answer': "4"},
    {'prompt': "What is 2+2?", 'answer': "4"},  # Same prompt
    {'prompt': "What is 2+2?", 'answer': "4"},  # Same prompt
    {'prompt': "Solve x+1=5", 'answer': "x=4"},
    {'prompt': "Solve x+1=5", 'answer': "x=4"},  # Same prompt
    {'prompt': "Solve x+1=5", 'answer': "x=4"}  # Same prompt
]

Step 3: Environment Receives Repeated Prompts

# env_inputs passed to AsyncBatchGenerator:
{
    'prompt': ["What is 2+2?", "What is 2+2?", "What is 2+2?", "Solve x+1=5", "Solve x+1=5", "Solve x+1=5"],
    'answer': ["4", "4", "4", "x=4", "x=4", "x=4"]
}

Step 4: Independent API Calls with Stochastic Sampling

The environment calls

# Call 1: "What is 2+2?" → client.chat.completions.create(...) → "2+2=4"
# Call 2: "What is 2+2?" → client.chat.completions.create(...) → "Let me calculate: 2+2 equals 4"
# Call 3: "What is 2+2?" → client.chat.completions.create(...) → "The answer is 4"
# Call 4: "Solve x+1=5" → client.chat.completions.create(...) → "x+1=5, so x=4"
# Call 5: "Solve x+1=5" → client.chat.completions.create(...) → "Subtract 1: x=5-1=4"
# Call 6: "Solve x+1=5" → client.chat.completions.create(...) → "x=4"

Step 5: Advantages Computed Across Groups

# Rewards: [0.9, 0.8, 0.7, 0.6, 0.9, 0.5]
# Grouped by prompt: [[0.9, 0.8, 0.7], [0.6, 0.9, 0.5]]
# Mean per group: [0.8, 0.67]
# Advantages: [0.1, 0.0, -0.1, -0.07, 0.23, -0.17]

Step 6: Shuffling Before Training

# Before shuffling (grouped by prompt):
# prompt_ids: [prompt0_gen0, prompt0_gen1, prompt0_gen2, prompt1_gen0, prompt1_gen1, prompt1_gen2]
# advantages: [0.1, 0.0, -0.1, -0.07, 0.23, -0.17]
# After shuffle_tensor_dict():
# prompt_ids: [prompt1_gen1, prompt0_gen0, prompt1_gen2, prompt0_gen2, prompt1_gen0, prompt0_gen1]
# advantages: [0.23, 0.1, -0.17, -0.1, -0.07, 0.0]
# Split into gradient_accumulation_steps=2:
# Step 0: [prompt1_gen1, prompt0_gen0, prompt1_gen2]
# Step 1: [prompt0_gen2, prompt1_gen0, prompt0_gen1]

Implementation Details
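As one implementation detail worth making concrete, Steps 5 and 6 of the walkthrough can be reproduced in plain Python, without tensors. This is a sketch under stated assumptions: `group_advantages` and `shuffle_and_split` are hypothetical helpers written for illustration, standing in for `_compute_advantages` and for `shuffle_tensor_dict` / `split_tensor_dict`, whose internals are not shown here. The sample standard deviation (`statistics.stdev`) is used because `torch.Tensor.std` applies Bessel's correction by default.

```python
import random
from statistics import mean, stdev


def group_advantages(rewards, num_generations, scale_rewards=False, eps=1e-4):
    """Subtract each prompt group's mean reward; optionally divide by its std."""
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]  # completions of one prompt
        baseline = mean(group)
        scale = stdev(group) + eps if scale_rewards else 1.0
        advantages.extend((r - baseline) / scale for r in group)
    return advantages


def shuffle_and_split(batch, num_chunks, seed=None):
    """Apply one shared permutation to every field, then cut into equal chunks."""
    n = len(next(iter(batch.values())))
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    # The same permutation for all keys keeps each completion with its advantage.
    shuffled = {key: [values[i] for i in perm] for key, values in batch.items()}
    size = n // num_chunks
    return [
        {key: values[j * size:(j + 1) * size] for key, values in shuffled.items()}
        for j in range(num_chunks)
    ]


rewards = [0.9, 0.8, 0.7, 0.6, 0.9, 0.5]
advantages = group_advantages(rewards, num_generations=3)
labels = ["prompt0_gen0", "prompt0_gen1", "prompt0_gen2",
          "prompt1_gen0", "prompt1_gen1", "prompt1_gen2"]
batch = {"completion": labels, "advantages": advantages}
micro_batches = shuffle_and_split(batch, num_chunks=2, seed=0)

print([round(a, 2) for a in advantages])
for step, micro_batch in enumerate(micro_batches):
    print(f"Step {step}:", micro_batch["completion"])
```

The advantages are computed before shuffling, while completions of the same prompt are still adjacent; after that point the grouping is no longer needed, because each completion carries its own advantage value.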
Good contribution. I've incorporated it with some light edits.
This is a more architecturally focused addition to the docs. Did it mainly for my own understanding but think it might be useful to others!