diff --git a/conf/base.yaml b/conf/base.yaml
index 3d426f4c..995db7c5 100644
--- a/conf/base.yaml
+++ b/conf/base.yaml
@@ -47,7 +47,7 @@ llm:
     temperature: 1.0
 test_llm:
   parameters:
-    max_tokens: 16000
+    max_tokens: 8192
     temperature: 1.0
     top_p: 0.95
     top_k: 50
@@ -67,6 +67,7 @@ vllm_config:
     tensor-parallel-size: 1
     pipeline-parallel-size: 1
     generation-config: vllm
+    max_model_len: 10000
 
 world:
   replicas: 1
@@ -75,7 +76,8 @@ world:
   preprocessor_fraction: 0
   finetune_fraction: 4
-  env_replicas: 2
+  # Number of environment servers per actor VLLM server
+  env_replicas_per_actor: 1
   actor_group_port: 9000
   environment_start_port: 7777
diff --git a/conf/miniwob.yaml b/conf/miniwob.yaml
index a5bf8bc2..1454e774 100644
--- a/conf/miniwob.yaml
+++ b/conf/miniwob.yaml
@@ -1,34 +1,32 @@
 defaults:
   - base
+  - override streams: redis
+  - override finetune: ppo
+  - _self_
 
 world:
-  actor_fraction: 4
-  preprocessor_fraction: 1
-  finetune_fraction: 3
+  actor_fraction: 2
+  preprocessor_fraction: 0
+  finetune_fraction: 6
 
 # debug:
 #   mode: actor
 
 save_tapes: False
-output_dir: results/miniwob_debug/${now:%Y-%m-%d}/${now:%H-%M-%S}
+output_dir: results/miniwob/${now:%Y-%m-%d}/${now:%H-%M-%S}
 model_path: meta-llama/Llama-3.1-8B-Instruct
 
 finetune:
-  save_checkpoint_steps: 10
-  seq_length: 4096
+  seq_length: 16384 # input + output tokens
+  max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples
   train_batch_size: 1
   gradient_accumulation_passes: 1024
-  learning_rate: 1e-6
-  optim: adamw_torch
-  rl:
-    kl_coef: 0.01 # GRPO beta coefficient
-    reward_minus_kl_coef: 0.0 # RLOO beta coefficient
-    use_advantages: true
-    algo: grpo
+
+eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps"
 
 llm:
   parameters:
-    max_tokens: 3072
+    max_tokens: 4096 # output tokens
     temperature: 1.0
 test_llm:
   parameters:
@@ -39,24 +37,37 @@ test_llm:
 
 vllm_config:
   vllm_kwargs:
-    enable-auto-tool-choice: ""
-    tool-call-parser: llama3_json # use hermes for qwen
-    chat_template: pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja # copy pasted from https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.1_json.jinja
-    enforce-eager: "" # speed the actor llm startup a bit
+    max_model_len: 16384 # input + output tokens
 
 actor:
-  rollout_policy: pipelinerl.miniwob.rollouts.generate_miniwob_rollout
+  rollout_policy: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout
   shared_memory_entry_size: 100000000
+  llm_max_rollouts: 32
 
 preprocess:
-  shared_memory_entry_size: 1000000000
+  n_workers: 32 # Increase from 8
+  chunk_n_groups: 8 # Increase from 2 for better throughput
+  # queue for loaded raw groups
+  raw_queue_size: 32 # Increase from 8
+  # queue for processed chunks of multiple groups
+  input_queue_size: 64 # Increase from 32
+  # queue for ready chunks of multiple groups
+  output_queue_size: 64 # Increase from 32
+  # ring buffer to replace old samples with new ones when training is slow
+  ring_buffer_size: 1024 # Increase from 128
+  # "virtual" sample queue per lead trainer
+  max_ready_samples_per_lead: 256 # Increase from 64
+  shared_memory_entry_size: 1000000000 # Increase from 100M
 
 # AGENT CONFIGURATION
 agent_max_loops: 10 # max number of agent - environment interactions for each task
+agent_attempts: 3 # number of attempts to run the agent (retry on errors)
+rollout_timeout: 600 # overall timeout for the entire rollout in seconds (10 minutes)
+reward_computation: nico
 agent:
   _target_: tapeagents.agent.Agent
   name: web_agent
-  max_iterations: 4 # max number of iterations (make_prompt + llm? + generate_steps) for each loop
+  max_iterations: 4 # max number of iterations (make_prompt + llm + generate_steps) for each loop
   store_llm_calls: true
   templates:
     system_prompt: |
@@ -65,50 +76,64 @@ agent:
       Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
       You will be provided with the content of the current page and a task from the user.
       Do not express your emotions or opinions about the user question.
-    allowed_tools: |
-      You have access to the following tools:
-      {tools_description}
-    thought_format: |
-      Important! Respond with the plain text, do not include any JSON or code.
-      Do not output anything besides what I asked in this message.
+    allowed_steps: |
+      You are allowed to produce ONLY steps with the following json schemas:
+      {allowed_steps}
+      Do not reproduce the schema when producing steps; use it as a reference.
+    json_format: |
+      Important! Respond with very simple parsable JSON!
+      Do not use any special characters or code. Do not use new lines, tabs, or any other formatting inside the JSON.
+      Do not output anything besides one simple JSON object.
   nodes:
     - _target_: examples.rl_webagent.agent.WebNode
      name: set_goal
       system_prompt: ${agent.templates.system_prompt}
       guidance: |
-        Produce the thought that describes the intended solution to the task. In the reasoning lines:
+        Produce the reasoning_thought step that describes the intended solution to the task. In the reasoning lines:
         - review the instructions from the user and the content of the page.
        - outline the main task to be accomplished and the steps to be taken to achieve it.
         - produce a definition of done that will be checked later to verify if the task was completed.
-        ${agent.templates.thought_format}
-      steps_prompt: ${agent.templates.allowed_tools}
+        Produce only one reasoning_thought step!
+        ${agent.templates.json_format}
+      steps_prompt: ${agent.templates.allowed_steps}
+      steps:
+        - tapeagents.steps.ReasoningThought
       trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages
       max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps
     - _target_: examples.rl_webagent.agent.WebNode
       name: reflect
       system_prompt: ${agent.templates.system_prompt}
       guidance: |
-        Review the current state of the page and previous steps to find the best possible next action to accomplish the task.
-        Produce the reflection_thought to describe the current page state, reflect on your last action, describe what is left to do, and what will be the immediate next action.
-        Produce only one reflection_thought step!
-        ${agent.templates.thought_format}
-      steps_prompt: ${agent.templates.allowed_tools}
+        Produce the reasoning_thought step that describes the current state of the page, the previous actions, and what should be the next best action to accomplish the task. In the reasoning lines:
+        - think about which information could be relevant to the given task, note relevant BIDs and coordinates.
+        - describe the last action taken, what were its expected effects on the page, versus the actual effects you can observe. Are they the same or not? If not, what could have gone wrong?
+        - check if you are stuck repeating the same action over and over again; if so, try something else and change the action.
+        - check if you think the task is done; if not, give a detailed list of actions to do next to accomplish the task.
+        - finally, if the task is not done, describe the immediate next action to be performed and its expected effect on the page.
+        Produce only one reasoning_thought step! Be brief and to the point. You can skip some details if they are not relevant for this step.
+        ${agent.templates.json_format}
+      steps_prompt: ${agent.templates.allowed_steps}
+      steps:
+        - tapeagents.steps.ReasoningThought
       trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages
       max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps
     - _target_: examples.rl_webagent.agent.WebNode
       name: act
       system_prompt: ${agent.templates.system_prompt}
       guidance: |
-        Produce the single next tool call to be performed with the current page.
-        If you think that the task is solved, call the FinalAnswer.
+        Produce the next action to be performed with the current page.
+        If you think that the task is solved, produce the final_answer_action.
         You can interact with the page elements using their BIDs or coordinates as arguments for actions.
         HINTS:
         - You can use the BIDs of the elements or the mouse position in x, y coordinates to interact with them.
-        - To select value in a dropdown or combobox, ALWAYS use SelectOption tool.
+        - To select a value in a dropdown or combobox, ALWAYS use select_action.
         - To click on a checkbox or radio button, ALWAYS use BID (or coordinates) of the corresponding Text and not the BID (or coordinates) of the element itself.
         - Press enter key to submit the search query.
+        - Always produce only one step at a time.
+        - Step kind is always lowercase and underscore-separated.
+        ${agent.templates.json_format}
+      steps_prompt: ${agent.templates.allowed_steps}
       use_known_actions: true
-      use_function_calls: true
       steps:
         - examples.rl_webagent.steps.FinalAnswerAction
       trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages
@@ -119,18 +144,18 @@ agent:
 
 # ENVIRONMENT CONFIGURATION
 start_attempts: 3 # number of attempts to start each task
 environment:
-  _target_: pipelinerl.miniwob.environment_server.WebEnvironmentServer
-  miniwob_url: file:///home/toolkit/miniwob-plusplus/miniwob/html/miniwob/
-  n_envs: 64
+  _target_: pipelinerl.domains.miniwob.environment_server.WebEnvironmentServer
+  miniwob_url: ???
+  n_envs: 32
   host: "0.0.0.0"
-  max_session_inactivity_secs: 300
+  env_call_timeout: 60 # timeout for each environment call (e.g. start_task, act, etc.)
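+  # dotted path to the web environment class; the server instantiates it as a Hydra _target_ (see environment_server.py)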
   web_env_target: examples.rl_webagent.environment.WebEnvironment
-  exp_path: ${output_dir}/env_server
+  exp_path: null
   headless: true
   observation_format: html
 
 # DATASET CONFIGURATION
-dataset_loader: pipelinerl.miniwob.load_tasks.load_tasks
+dataset_loader: pipelinerl.domains.miniwob.load_tasks.load_tasks
 dataset_loader_params:
   train_split: 0.6 # 0.6 of tasks for training, 0.4 for testing
   seeds: [0, 42, 1337, 900, 103]
diff --git a/conf/miniwob_grpo.yaml b/conf/miniwob_grpo.yaml
new file mode 100644
index 00000000..f6cfeed3
--- /dev/null
+++ b/conf/miniwob_grpo.yaml
@@ -0,0 +1,10 @@
+defaults:
+  - miniwob
+  - override finetune: grpo
+  - _self_
+
+finetune:
+  seq_length: 16384 # input + output tokens
+  max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples
+  train_batch_size: 1
+  gradient_accumulation_passes: 1024
diff --git a/conf/miniwob_massimo_grpo.yaml b/conf/miniwob_massimo_grpo.yaml
new file mode 100644
index 00000000..b61dcf32
--- /dev/null
+++ b/conf/miniwob_massimo_grpo.yaml
@@ -0,0 +1,15 @@
+defaults:
+  - miniwob_grpo
+  - _self_
+
+train_dataset_names:
+  - massimo_train
+test_dataset_names:
+  - massimo_test
+
+reward_computation: massimo
+
+finetune:
+  gradient_accumulation_passes: 512
+
+eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps"
diff --git a/conf/miniwob_massimo_ppo.yaml b/conf/miniwob_massimo_ppo.yaml
new file mode 100644
index 00000000..53703d56
--- /dev/null
+++ b/conf/miniwob_massimo_ppo.yaml
@@ -0,0 +1,15 @@
+defaults:
+  - miniwob
+  - _self_
+
+train_dataset_names:
+  - massimo_train
+test_dataset_names:
+  - massimo_test
+
+reward_computation: massimo
+
+finetune:
+  gradient_accumulation_passes: 512
+
+eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps"
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index a85f156e..bcce006b 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -196,6 +196,7 @@ async def rollout_and_maybe_produce_result(
                 f"groups in progress: {len(group_rollouts)}, "
                 f"rollouts started so far: {started_rollouts}, "
                 f"rollouts finished so far: {finished_rollouts}, "
+                f"groups started so far: {group_id}, "
                 f"max group size in bytes: {result_queue.max_actual_entry_size()}, "
             )
             last_logged = time.time()
@@ -463,6 +464,9 @@ def run(self, dataset: list[tuple[str, dict]]):
 
             assert isinstance(rollout_results, list)
             assert isinstance(rollout_results[0], RolloutResult)
+            assert len(rollout_results) == attempts, (
+                f"Expected {attempts} rollouts, got {len(rollout_results)}"
+            )
             group_samples = sum(len(r.training_texts) for r in rollout_results)
             published_samples += group_samples
@@ -479,7 +483,6 @@ def run(self, dataset: list[tuple[str, dict]]):
                 f" {in_progress} groups in progress"
             )
-
             self.update_stats(rollout_results=rollout_results)
             finished_groups += 1
diff --git a/pipelinerl/domains/miniwob/README.md b/pipelinerl/domains/miniwob/README.md
new file mode 100644
index 00000000..e9af1b42
--- /dev/null
+++ b/pipelinerl/domains/miniwob/README.md
@@ -0,0 +1,34 @@
+# Miniwob example
+
+## Prerequisites
+
+### TapeAgents
+
+Clone [TapeAgents](https://github.com/ServiceNow/TapeAgents/) in your parent folder and install it.
+```bash
+cd ..
+git clone git@github.com:ServiceNow/TapeAgents.git
+cd TapeAgents
+pip install -e .
+pip install 'tapeagents[finetune,converters]==0.1.12'
+cd ../PipelineRL
+```
+
+Make sure to add the TapeAgents folder to your Python path.
+```bash
+export PYTHONPATH="/path/to/TapeAgents:$PYTHONPATH"
+```
+
+### Miniwob
+
+See the setup instructions here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/README.md
+
+### Playwright
+
+The environment server needs Playwright installed:
+
+`playwright install`
+
+## Launch Command
+
+`python -m pipelinerl.launch --config-name miniwob environment.miniwob_url=file:///PATH/TO/miniwob-plusplus/miniwob/html/miniwob/`
diff --git a/pipelinerl/miniwob/environment_server.py b/pipelinerl/domains/miniwob/environment_server.py
similarity index 80%
rename from pipelinerl/miniwob/environment_server.py
rename to pipelinerl/domains/miniwob/environment_server.py
index 13839f7a..b30f9ef7 100644
--- a/pipelinerl/miniwob/environment_server.py
+++ b/pipelinerl/domains/miniwob/environment_server.py
@@ -13,12 +13,14 @@ def __init__(self,
         exp_path: str,
         headless: bool = True,
         observation_format: str = "html",
-        max_session_inactivity_secs: int = 600,
+        env_call_timeout: int = 60,
     ):
         os.environ["MINIWOB_URL"] = miniwob_url
+        # Remote environment server configuration
         self.n_envs = n_envs
         self.host = host
-        self.max_session_inactivity_secs = max_session_inactivity_secs
+        self.env_call_timeout = env_call_timeout
+        # Individual web environment configuration
         self.web_env_target = web_env_target
         self.exp_path = exp_path
         self.headless = headless
@@ -29,7 +31,7 @@ def launch(self, port: int):
         """
         Serve the web environment in TapeAgent.
         """
-        env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port, max_session_inactivity_secs=self.max_session_inactivity_secs)
+        env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port, env_call_timeout=self.env_call_timeout)
         env_server.launch(OmegaConf.create({
             "_target_": self.web_env_target,
             "exp_path": self.exp_path,
diff --git a/pipelinerl/domains/miniwob/load_tasks.py b/pipelinerl/domains/miniwob/load_tasks.py
new file mode 100644
index 00000000..a056a311
--- /dev/null
+++ b/pipelinerl/domains/miniwob/load_tasks.py
@@ -0,0 +1,216 @@
+import random
+from browsergym.miniwob import ALL_MINIWOB_TASKS
+
+DEBUG_SPLIT = [
+    "miniwob.buy-ticket",
+    "miniwob.bisect-angle",
+    "miniwob.choose-list",
+    "miniwob.click-checkboxes-large",
+    "miniwob.click-checkboxes-soft",
+]
+EASY_SPLIT = [
+    "miniwob.click-color",
+    "miniwob.click-test-2",
+    "miniwob.click-test-transfer",
+    "miniwob.enter-password",
+    "miniwob.focus-text-2",
+    "miniwob.identify-shape",
+    "miniwob.navigate-tree",
+    "miniwob.phone-book",
+    "miniwob.read-table",
+    "miniwob.use-autocomplete",
+    "miniwob.buy-ticket",
+    "miniwob.click-checkboxes-soft",
+    "miniwob.click-collapsible-2",
+    "miniwob.click-collapsible-2-nodelay",
+    "miniwob.click-collapsible-nodelay",
+    "miniwob.click-dialog-2",
+    "miniwob.click-tab-2",
+    "miniwob.click-tab-2-medium",
+    "miniwob.form-sequence-3",
+    "miniwob.hot-cold",
+    "miniwob.multi-orderings",
+    "miniwob.tic-tac-toe",
+    "miniwob.use-autocomplete-nodelay"
+]
+MASSIMO_TRAIN_SPLIT = [
+    "miniwob.ascending-numbers",
+    "miniwob.bisect-angle",
+    "miniwob.book-flight",
+    "miniwob.choose-date",
+    "miniwob.choose-date-easy",
+    "miniwob.choose-date-medium",
+    "miniwob.choose-date-nodelay",
+    "miniwob.choose-list",
+    "miniwob.circle-center",
+    "miniwob.click-button-sequence",
+    "miniwob.click-checkboxes-soft",
+    "miniwob.click-checkboxes-transfer",
+    "miniwob.click-collapsible-2",
+    "miniwob.click-collapsible-2-nodelay",
+    "miniwob.click-collapsible-nodelay",
+    "miniwob.click-color",
+    "miniwob.click-dialog",
+    "miniwob.click-dialog-2",
+    "miniwob.click-link",
+    "miniwob.click-menu",
+    "miniwob.click-menu-2",
+    "miniwob.click-scroll-list",
+    "miniwob.click-shape",
+    "miniwob.click-tab",
+    "miniwob.click-tab-2",
+    "miniwob.click-tab-2-hard",
+    "miniwob.click-tab-2-medium",
+    "miniwob.click-test",
+    "miniwob.click-test-2",
+    "miniwob.click-test-transfer",
+    "miniwob.click-widget",
+    "miniwob.copy-paste",
+    "miniwob.copy-paste-2",
+    "miniwob.count-shape",
+    "miniwob.count-sides",
+    "miniwob.daily-calendar",
+    "miniwob.drag-box",
+    "miniwob.drag-circle",
+    "miniwob.drag-cube",
+    "miniwob.drag-items",
+    "miniwob.drag-items-grid",
+    "miniwob.drag-shapes",
+    "miniwob.drag-shapes-2",
+    "miniwob.drag-sort-numbers",
+    "miniwob.draw-circle",
+    "miniwob.draw-line",
+    "miniwob.email-inbox",
+    "miniwob.email-inbox-delete",
+    "miniwob.email-inbox-forward",
+    "miniwob.email-inbox-forward-nl",
+    "miniwob.email-inbox-forward-nl-turk",
+    "miniwob.email-inbox-important",
+    "miniwob.email-inbox-noscroll",
+    "miniwob.email-inbox-reply",
+    "miniwob.email-inbox-star-reply",
+    "miniwob.enter-date",
+    "miniwob.enter-text",
+    "miniwob.enter-text-dynamic",
+    "miniwob.enter-time",
+    "miniwob.find-greatest",
+    "miniwob.find-word",
+    "miniwob.focus-text-2",
+    "miniwob.form-sequence",
+    "miniwob.form-sequence-2",
+    "miniwob.generate-number",
+    "miniwob.grid-coordinate",
+    "miniwob.guess-number",
+    "miniwob.highlight-text",
+    "miniwob.hot-cold",
+    "miniwob.identify-shape",
+    "miniwob.login-user",
+    "miniwob.login-user-popup",
+    "miniwob.multi-layouts",
+    "miniwob.multi-orderings",
+    "miniwob.navigate-tree",
+    "miniwob.odd-or-even",
+    "miniwob.order-food",
+    "miniwob.phone-book",
+    "miniwob.read-table",
+    "miniwob.read-table-2",
+    "miniwob.resize-textarea",
+    "miniwob.right-angle",
+    "miniwob.scroll-text",
+    "miniwob.scroll-text-2",
+    "miniwob.search-engine",
+    "miniwob.sign-agreement",
+    "miniwob.simple-algebra",
+    "miniwob.social-media",
+    "miniwob.social-media-all",
+    "miniwob.social-media-some",
+    "miniwob.text-editor",
+    "miniwob.text-transform",
+    "miniwob.tic-tac-toe",
+    "miniwob.use-autocomplete",
+    "miniwob.use-autocomplete-nodelay",
+    "miniwob.use-colorwheel",
+    "miniwob.use-colorwheel-2",
+    "miniwob.use-spinner",
+    "miniwob.visual-addition",
+]
+MASSIMO_TEST_SPLIT = [
+    "miniwob.buy-ticket",
+    "miniwob.click-button",
+    "miniwob.click-option",
+    "miniwob.click-pie-nodelay",
+    "miniwob.drag-single-shape",
+    "miniwob.email-inbox-nl-turk",
+    "miniwob.enter-text-2",
+    "miniwob.find-midpoint",
+    "miniwob.focus-text",
+    "miniwob.simple-arithmetic",
+    "miniwob.stock-market",
+    "miniwob.use-slider-2",
+    "miniwob.click-checkboxes",
+    "miniwob.click-checkboxes-large",
+    "miniwob.click-collapsible",
+    "miniwob.click-pie",
+    "miniwob.click-shades",
+    "miniwob.click-tab-2-easy",
+    "miniwob.enter-password",
+    "miniwob.form-sequence-3",
+    "miniwob.highlight-text-2",
+    "miniwob.unicode-test",
+    "miniwob.use-slider",
+]
+TRAIN_SPLIT = None
+TEST_SPLIT = None
+
+
+def load_tasks(dataset_names: list[str], train_split: float = 0.6, seeds: list[int] = [0, 1, 2, 3, 4]):
+    # set global variables if needed
+    global TRAIN_SPLIT, TEST_SPLIT
+    if TRAIN_SPLIT is None or TEST_SPLIT is None:
+        # Make a copy of tasks to avoid modifying the original
+        all_tasks = list(ALL_MINIWOB_TASKS)
+        # Use fixed seed for consistent shuffling
+        rng = random.Random(1406)
+        rng.shuffle(all_tasks)
+
+        n_train_tasks = int(len(ALL_MINIWOB_TASKS) * train_split)
+        # slice the shuffled copy (not ALL_MINIWOB_TASKS) so the shuffle actually takes effect
+        TRAIN_SPLIT = [t.get_task_id() for t in all_tasks[:n_train_tasks]]
+        TEST_SPLIT = [t.get_task_id() for t in all_tasks[n_train_tasks:]]
+
+    tasks = []
+    for name in dataset_names:
+        if name == "debug":
+            tasks.extend([
+                # {"dataset": "miniwob.debug", "task": task, "seed": 0} for task in DEBUG_SPLIT
+                {"dataset": task, "task": task, "seed": 0} for task in DEBUG_SPLIT
+            ])
+        elif name == "easy":
+            tasks.extend([
+                # {"dataset": "miniwob.easy", "task": task, "seed": 0} for task in EASY_SPLIT
+                {"dataset": task, "task": task, "seed": 0} for task in EASY_SPLIT
+            ])
+        elif name == "train":
+            tasks.extend([
+                # {"dataset": "miniwob.train", "task": task, "seed": seed}
+                {"dataset": task, "task": task, "seed": seed}
+                for task in TRAIN_SPLIT for seed in seeds
+            ])
+        elif name == "test":
+            tasks.extend([
+                # {"dataset": "miniwob.test", "task": task, "seed": seed}
+                {"dataset": task, "task": task, "seed": seed}
+                for task in TEST_SPLIT for seed in seeds
+            ])
+        elif name == "massimo_train":
+            tasks.extend([
+                {"dataset": task, "task": task, "seed": seed}
+                for task in MASSIMO_TRAIN_SPLIT for seed in range(3, 10)  # seeds 0-2 are reserved for held-out goals in Massimo's setup
+            ])
+        elif name == "massimo_test":
+            tasks.extend([
+                {"dataset": task, "task": task, "seed": seed}
+                for task in MASSIMO_TEST_SPLIT for seed in range(10)
+            ])
+    return tasks
+
diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py
new file mode 100644
index 00000000..ec71ff8e
--- /dev/null
+++ b/pipelinerl/domains/miniwob/rollouts.py
@@ -0,0 +1,341 @@
+import asyncio
+import json
+import logging
+import os
+import random
+import time
+import traceback
+
+import aiohttp
+from examples.rl_webagent.steps import WebTape
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from tapeagents.agent import DEFAULT, Agent
+from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation
+from tapeagents.io import save_json_tape
+from tapeagents.llms.trainable import TrainableLLM
+from tapeagents.orchestrator import async_execute_agent
+from tapeagents.remote_environment import AsyncRemoteEnvironment
+from tapeagents.tools.simple_browser import PageObservation
+
+from pipelinerl.async_llm import make_training_text
+from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.world import Job
+
+logger = logging.getLogger(__name__)
+
+
+class MiniwobMetrics(BaseMetrics):
+    reward: float
+    success: bool
+    no_error: bool
+    no_answer: bool
+    overflow: bool
+    n_llm_calls: int
+    n_step_errors: int
+    n_page_observations: int
+    n_steps: int
+    total_execution_time: float
+    agent_execution_time: float
+    environment_execution_time: float
+    env_step_time: float
+    agent_step_time: float
+
+
+def tape_contains_an_error(tape: WebTape) -> bool:
+    """
+    Returns true if the tape ends with an error, ie if one of the following is true:
+    - the last step is an LLMOutputParsingFailureAction
+    - the tape metadata has an error
+    - the last step is a PageObservation with an error
+    """
+    return (
+        len(tape.steps) == 0
+        or isinstance(tape.steps[-1], LLMOutputParsingFailureAction)
+        or tape.metadata.result.get("error") is not None
+        or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error)
+    )
+
+
+async def check_env_server_health(env_job: Job, session: aiohttp.ClientSession) -> dict:
+    """Check environment server health via HTTP API."""
+    try:
+        url = f"http://{env_job.hostname}:{env_job.port}/health"
+        async with session.get(url, timeout=5) as response:
+            if response.status == 200:
+                health_data = await response.json()
+                return {
+                    "healthy": True,
+                    "health_data": health_data,
+                    "last_check": time.time()
+                }
+            else:
+                error_text = await response.text()
+                return {"healthy": False, "error_message": f"HTTP {response.status}: {error_text}", "last_check": time.time()}
+    except Exception as e:
+        exception_type = type(e).__name__
+        exception_message = str(e) if str(e) else "No message available"
+        logger.exception(f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True)
+        return {"healthy": False, "error_message": f"Exception: {exception_type}: {exception_message}", "last_check": time.time(), "error_stacktrace": traceback.format_exc()}
+
+
+async def generate_miniwob_rollout(
+    cfg: DictConfig,
+    llm: TrainableLLM,
+    problem: dict,
+    session: aiohttp.ClientSession,
+) -> RolloutResult:
+    # choose a random environment server
+    # Generate environment
+    # Generate TapeAgent
+    # run the agent
+    # get llm calls from tape
+    # compute rewards
+    # get training text from llm calls
+
+    start_time = time.time()
+
+    # Overall timeout for the entire rollout to prevent hanging
+    rollout_timeout = getattr(cfg, 'rollout_timeout', 600)  # 10 minutes default
+
+    env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
+    env_jobs_url_tried = []
+
+    # Try each environment server with health checks until one of them returns a rollout result
+    for _ in range(len(env_jobs)):
+        # Choose the next environment server to try randomly from the ones that have not been tried yet
+        env_job = random.choice([job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried])
+        env_job_url = f"http://{env_job.hostname}:{env_job.port}"
+        env_jobs_url_tried.append(env_job_url)
+
+        # Check server health before using
+        health = await check_env_server_health(env_job, session)
+        if not health["healthy"]:
+            logger.warning(f"Environment server {env_job_url} is unhealthy: {health}")
+            # the stacktrace key is only present when the health check itself raised
+            if stack_trace := health.get("error_stacktrace"):
+                logger.warning(f"Get health error stacktrace: {stack_trace}")
+            continue
+        # Log health status for monitoring
+        logger.info(f"Using healthy environment server {env_job_url}: {health}")
+
+        try:
+            # Execute the entire rollout with a timeout
+            return await asyncio.wait_for(
+                _execute_rollout_with_timeout(cfg, llm, problem, session, start_time, env_job_url),
+                timeout=rollout_timeout
+            )
+        except asyncio.TimeoutError:
+            health = await check_env_server_health(env_job, session)
+            if stack_trace := health.get("error_stacktrace"):
+                logger.warning(f"Get health error stacktrace: {stack_trace}")
+            logger.warning(f"Rollout timeout error stacktrace: {traceback.format_exc()}")
+            logger.warning(f"Rollout timed out after {rollout_timeout} seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+            continue
+        except Exception as e:
+            health = await check_env_server_health(env_job, session)
+            if stack_trace := health.get("error_stacktrace"):
+                logger.warning(f"Get health error stacktrace: {stack_trace}")
+            logger.warning(f"Rollout failed error stacktrace: {traceback.format_exc()}")
+            logger.warning(f"Rollout failed for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+            continue
+    # If all servers failed
+    logger.error(f"All environment servers failed for task {problem['dataset']}/{problem['task']}/{problem['seed']}. Returning a failed rollout result.")
+    return _create_failed_rollout_result(problem, start_time, "all environment servers failed")
+
+
+async def _execute_rollout_with_timeout(
+    cfg: DictConfig,
+    llm: TrainableLLM,
+    problem: dict,
+    session: aiohttp.ClientSession,
+    start_time: float,
+    env_job_url: str,
+) -> RolloutResult:
+    # (2) Generate environment, TapeAgent, and run them to get a Tape
+    no_error = True  # track if there was an error in the tape
+    environment = AsyncRemoteEnvironment(server_url=env_job_url)  # type: ignore
+    async with environment.acontext(session, wait_for_env=True) as env:
+        start_attempts = cfg.start_attempts
+        t = time.perf_counter()
+        while start_attempts > 0:
+            try:
+                tape_dict, info = await env.start_task(problem)
+                if info.get("error"):
+                    raise ValueError(info['error'])
+                break
+            except Exception as e:
+                start_attempts -= 1
+                logger.warning(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}. {start_attempts} attempts remaining. Error: {e}")
+                if start_attempts <= 0:
+                    logger.error(f"Failed to start task after all retry attempts: {e}")
+                    no_error = False
+                    tape_dict = {}
+                    break
+                else:
+                    logger.warning("Retry start task after 5 seconds.")
+                    await asyncio.sleep(5)
+        logger.info(
+            f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds. Worker ID: {env.worker_id}. Tape dict: {tape_dict}"
+        )
+        tape: WebTape = WebTape(**tape_dict)  # convert http response dict to WebTape object
+        t = time.perf_counter()
+        if no_error:  # only run the agent if the task started successfully
+            logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}")
+            agent_attempts = cfg.agent_attempts
+            while agent_attempts > 0:
+                # check if the worker is alive.
+                try:
+                    # this will either raise RuntimeError if worker is not alive anymore, or return a dictionary with the worker status
+                    worker_status = await env.check_worker_alive()
+                    if worker_status.get("status") == "starting":
+                        logger.warning(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is starting, waiting 5 seconds for it to be fully started.")
+                        await asyncio.sleep(5)
+                        continue
+                except Exception as e:
+                    # if worker is dead, no need to retry
+                    logger.exception(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is dead. Error: {e}", stack_info=True)
+                    no_error = False
+                    break
+                # if worker is alive, run the agent
+                try:
+                    actions = await env.a_actions()
+                    tools_description = await env.a_tools_description()
+                    agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description)
+                    agent.llms = {DEFAULT: llm}
+                    tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops)
+                    # Check if the tape has an error from the orchestrator (e.g., SocketTimeoutError, RuntimeError: Worker is not alive, etc.)
+                    if tape.metadata.error:
+                        logger.error(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} returned a tape with error: {tape.metadata.error}")
+                        raise ValueError(tape.metadata.error)
+                    else:
+                        # Success - break out of retry loop
+                        logger.info(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} finished successfully")
+                        break
+                except Exception as e:
+                    agent_attempts -= 1
+                    logger.warning(f"Error occurred while running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}. {agent_attempts} attempts remaining. Error: {e}")
+                    if agent_attempts <= 0:
+                        logger.error(f"Agent execution failed after all retry attempts for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}: {e}")
+                        no_error = False
+                        break
+                    else:
+                        logger.warning(f"Retry agent execution after 5 seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}.")
+                        await asyncio.sleep(5)
+        logger.info(
+            f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']} in {time.perf_counter() - t:.2f} seconds with worker ID: {env.worker_id} and tape ID {tape.metadata.id}"
+        )
+        tape.metadata.result.update({"total_execution_time": time.perf_counter() - t})
+
+    # save the tape as we go
+    if cfg.save_tapes:
+        save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id)
+
+    # (3) Compute rewards
+    obs_steps = [step for step in tape if isinstance(step, Observation)]
+    if obs_steps:
+        last_obs = obs_steps[-1]
+        # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0
+        # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L188
+        # Let's take directly the RAW_REWARD_GLOBAL from the metadata
+        # raw_reward = last_obs.metadata.other.get("reward", 0.0)
+        raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0)
+    else:
+        raw_reward = -1.0
+
+    no_error = no_error and not tape_contains_an_error(tape)
+    # get the number of LLMOutputParsingFailureAction in the tape
+    n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)])
+    # get the number of PageObservation steps in the tape
+    n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)])
+
+    if cfg.reward_computation == "nico":
+        reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0
+    elif cfg.reward_computation == "massimo":
+        reward = float(raw_reward > 0)
+        if reward == 0.0:
+            reward = -1.0
+        reward *= 0.98 ** n_page_observations
+    else:
+        raise ValueError(f"Invalid reward configuration: {cfg.reward_computation}")
+
+    # (4) Get LLM calls from Tape
+    llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None]
+    n_llm_calls = len(llm_calls)
+    llm_calls: list[LLMCall] = [
+        LLMCall(**step.metadata.other["llm_call"])
+        if isinstance(step.metadata.other["llm_call"], dict)
+        else step.metadata.other["llm_call"]
+        for step in llm_calls
+    ]
+
+    # (5) For each LLM interaction in the tape, make a training example.
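+    # all_finished stays 1 only if every completion below ends with the EOS token;
+    # any truncated completion flips it to 0 and is reported as `overflow` in the metrics.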
+    all_finished = 1
+    prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls]
+    output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls]
+    training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls]
+    for text in training_texts:
+        text.reward = reward
+        all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0
+
+    latency = time.time() - start_time
+    agent_time = tape.metadata.result.get("agent_execution_time", -1.0)
+    env_time = tape.metadata.result.get("environment_execution_time", -1.0)
+    n_observations = len([s for s in tape.steps if isinstance(s, Observation)])  # TODO: is this not the same n_page_observations??
+    n_other_steps = len(tape.steps) - n_observations
+    metrics = MiniwobMetrics(
+        reward=reward,
+        success=reward > 0.5,
+        no_error=no_error,
+        no_answer=reward < 0,
+        overflow=not all_finished,
+        n_llm_calls=n_llm_calls,
+        n_step_errors=n_step_errors,
+        n_page_observations=n_page_observations,
+        n_steps=len(tape.steps),
+        total_execution_time=tape.metadata.result.get("total_execution_time", -1.0),
+        agent_execution_time=agent_time,
+        environment_execution_time=env_time,
+        env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0,
+        agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0,
+    )
+
+    return RolloutResult(
+        training_texts=training_texts,
+        metrics=metrics,
+        latency=latency,
+        dataset_name=problem["dataset"],
+        prompt_tokens=prompt_tokens,
+        output_tokens=output_tokens,
+    )
+
+
+def _create_failed_rollout_result(problem: dict, start_time: float, error_type: str) -> RolloutResult:
+    """Create a failed rollout result for timeout or other errors."""
+    latency = time.time() - start_time
+
+    # Create empty training texts and metrics for failed rollout
+    metrics = MiniwobMetrics(
+        reward=-1.0,
+        success=False,
+        no_error=False,
+        no_answer=True,
+        overflow=False,
+        n_llm_calls=0,
+        n_step_errors=0,
+        n_page_observations=0,
+        n_steps=0,
+        total_execution_time=latency,
+        agent_execution_time=-1.0,
+        environment_execution_time=-1.0,
+        env_step_time=-1.0,
+        agent_step_time=-1.0,
+    )
+
+    return RolloutResult(
+        training_texts=[],
+        metrics=metrics,
+        latency=latency,
+        dataset_name=problem["dataset"],
+        prompt_tokens=[],
+        output_tokens=[],
+    )
diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py
index b03ab8d7..ac87457e 100644
--- a/pipelinerl/launch.py
+++ b/pipelinerl/launch.py
@@ -71,6 +71,13 @@ def validate_config(cfg: DictConfig):
         if not hasattr(cfg.finetune.rl, "value_loss_coef") or cfg.finetune.rl.value_loss_coef <= 0.0:
             raise ValueError("value_loss_coef must be greater than 0 when using causal-language-modeling-with-value-head")
 
+    if cfg.finetune.seq_length < cfg.vllm_config.vllm_kwargs.max_model_len:
+        raise ValueError(
+            f"seq_length {cfg.finetune.seq_length} must be greater than or equal to "
+            f"vllm_kwargs.max_model_len {cfg.vllm_config.vllm_kwargs.max_model_len}"
+        )
+
+
 def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus: list[int], exp_dir: Path):
     kwargs = cfg.vllm_config.vllm_kwargs
diff --git a/pipelinerl/miniwob/load_tasks.py b/pipelinerl/miniwob/load_tasks.py
deleted file mode 100644
index e5056c80..00000000
--- a/pipelinerl/miniwob/load_tasks.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import random
-from browsergym.miniwob import ALL_MINIWOB_TASKS
-
-DEBUG_SPLIT = [
-    "miniwob.buy-ticket",
-    "miniwob.bisect-angle",
-    "miniwob.choose-list",
-    "miniwob.click-checkboxes-large",
-    "miniwob.click-checkboxes-soft",
-]
-EASY_SPLIT = [
-    "miniwob.click-color",
-    "miniwob.click-test-2",
-    "miniwob.click-test-transfer",
-    "miniwob.enter-password",
-    "miniwob.focus-text-2",
-    "miniwob.identify-shape",
-    "miniwob.navigate-tree",
-    "miniwob.phone-book",
-    "miniwob.read-table",
-    "miniwob.use-autocomplete",
-    "miniwob.use-autocomplete",
-    "miniwob.buy-ticket",
-    "miniwob.click-checkboxes-soft",
-    "miniwob.click-collapsible-2",
-    "miniwob.click-collapsible-2-nodelay",
-    "miniwob.click-collapsible-nodelay",
-    "miniwob.click-dialog-2",
-    "miniwob.click-tab-2",
-    "miniwob.click-tab-2-medium",
-    "miniwob.form-sequence-3",
-    "miniwob.hot-cold",
-    "miniwob.multi-orderings",
-    "miniwob.tic-tac-toe",
-    "miniwob.use-autocomplete-nodelay"
-]
-TRAIN_SPLIT = None
-TEST_SPLIT = None
-
-
-def load_tasks(dataset_names: list[str], train_split: float = 0.6, seeds: list[int] = [0, 1, 2, 3, 4]):
-    # set global variables if needed
-    global TRAIN_SPLIT, TEST_SPLIT
-    if TRAIN_SPLIT is None or TEST_SPLIT is None:
-        # Make a copy of tasks to avoid modifying the original
-        all_tasks = list(ALL_MINIWOB_TASKS)
-        # Use fixed seed for consistent shuffling
-        rng = random.Random(1406)
-        rng.shuffle(all_tasks)
-
-        n_train_tasks = int(len(ALL_MINIWOB_TASKS) * train_split)
-        TRAIN_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[:n_train_tasks]]
-        TEST_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[n_train_tasks:]]
-
-    tasks = []
-    for name in dataset_names:
-        if name == "debug":
-            tasks.extend([
-                {"dataset": "miniwob.debug", "task": task, "seed": 0} for task in DEBUG_SPLIT
-            ])
-        elif name == "easy":
-            tasks.extend([
-                {"dataset": "miniwob.easy", "task": task, "seed": 0} for task in EASY_SPLIT
-            ])
-        elif name == "train":
-            tasks.extend([
-                {"dataset": "miniwob.train", "task": task, "seed": seed}
-                for task in TRAIN_SPLIT for seed in seeds
-            ])
-        elif name == "test":
-            tasks.extend([
-                {"dataset": "miniwob.test", "task": task, "seed": seed}
-                for task in TEST_SPLIT for seed in seeds
-            ])
-    return tasks
-
diff --git a/pipelinerl/miniwob/rollouts.py b/pipelinerl/miniwob/rollouts.py
deleted file mode 100644
index bbf68860..00000000
--- a/pipelinerl/miniwob/rollouts.py
+++ /dev/null
@@ -1,152 +0,0 @@
-
-import asyncio
-import logging
-import os
-import random
-import time
-import aiohttp
-from hydra.utils import instantiate
-from omegaconf import DictConfig
-
-from pipelinerl.async_llm import llm_async_generate, make_training_text
-from pipelinerl.rollouts import RolloutResult
-from pipelinerl.world import Job
-from tapeagents.agent import Agent, DEFAULT
-from tapeagents.core import LLMOutputParsingFailureAction, Observation, LLMCall
-from tapeagents.llms.trainable import TrainableLLM
-from tapeagents.remote_environment import AsyncRemoteEnvironment
-from tapeagents.tools.simple_browser import PageObservation
-from tapeagents.orchestrator import async_execute_agent
-from tapeagents.io import save_json_tape
-from examples.rl_webagent.steps import WebTape
-
-
-logger = logging.getLogger(__name__)
-
-
-def tape_contains_an_error(tape: WebTape) -> bool:
-    """
-    Returns true if the tape ends with an error, ie if one of the following is true:
-    - the last step is an LLMOutputParsingFailureAction
-    - the tape metadata has an error
-    - the last step is a PageObservation with an error
-    """
-    return (
-        isinstance(tape.steps[-1], LLMOutputParsingFailureAction)
-        or tape.metadata.result.get("error") is not None
-        or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error)
-    )
-
-
-async def generate_miniwob_rollout(
-    cfg: DictConfig,
-    llm: TrainableLLM,
-    problem: dict,
-    session: aiohttp.ClientSession,
-) -> RolloutResult:
-    # choose a random environment server
-    # Generate environment
-    # Generate TapeAgent
-    # run the agent
-    # get llm calls from tape
-    # compute rewards
-    # get training text from llm calls
-
-    start_time = time.time()
-
-    # (1) Choose a random environment server
-    env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
-    # choose the env job randomly
-    env_job = random.choice(env_jobs)
-    assert env_job.port is not None
-    env_job_url = f"http://{env_job.hostname}:{env_job.port}"
-
-    # (2) Generate environment, TapeAgent, and run them to get a Tape
-    environment = AsyncRemoteEnvironment(server_url=env_job_url)  # type: ignore
-    async with environment.acontext(session, wait_for_env=True) as env:
-        start_attempts = cfg.start_attempts
-        t = time.perf_counter()
-        while True:
-            try:
-                tape_dict, _ = await env.start_task(problem)
-                break
-            except Exception as e:
-                start_attempts -= 1
-                if start_attempts <= 0:
-                    raise e
-                logger.warning(f"Failed to start task, retry after 5 seconds: {e}")
-                await asyncio.sleep(5)
-        logger.info(f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds")
-        tape: WebTape = WebTape(**tape_dict)  # convert http response dict to WebTape object
-        t = time.perf_counter()
-        try:
-            actions = await env.a_actions()
-            tools_description = await env.a_tools_description()
-            logger.debug(f"Available tools: {tools_description}")
-            agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description)
-            agent.llms = {DEFAULT: llm}
-            tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops)
-        except Exception as e:
-            logger.error(f"Error occurred while running agent: {e}")
-        tape.metadata.result = {"execution_time": time.perf_counter() - t}
-
-    # save the tape as we go
-    if cfg.save_tapes:
-        save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id)
-
-    # (3) Compute rewards
-    last_obs = [step for step in tape if isinstance(step, Observation)][-1]
-    # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0
-    # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L183
-    # Let's take directly the RAW_REWARD_GLOBAL from the metadata
-    # raw_reward = last_obs.metadata.other.get("reward", 0.0)
-    raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0)
-    no_error = not tape_contains_an_error(tape)
-    # get the number of LLMOutputParsingFailureAction in the tape
-    n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)])
-    # get the number of PageObservation steps in the tape
-    n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)])
-
-    reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0
-
-    # (3) Get LLM calls from Tape
-    llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None]
-    n_llm_calls = len(llm_calls)
-    llm_calls: list[LLMCall] = [
-        LLMCall(**step.metadata.other["llm_call"]) if isinstance(step.metadata.other["llm_call"], dict)
-        else step.metadata.other["llm_call"]
-        for step in llm_calls
-    ]
-
-    # (4) # For each LLM interaction in the tape, make a training example.
-    all_finished = 0
-    prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls]
-    output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls]
-    training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls]
-    for text in training_texts:
-        text.reward = reward
-        all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0
-
-    latency = time.time() - start_time
-
-    metrics = {
-        "reward": reward,
-        "success": 1 if reward > 0.5 else 0,
-        "no_error": no_error,
-        "no_answer": 1 if reward < 0 else 0,
-        "overflow": 0 if all_finished else 1,
-        "n_llm_calls": n_llm_calls,
-        "n_step_errors": n_step_errors,
-        "n_page_observations": n_page_observations,
-        "n_steps": len(tape.steps),
-    }
-
-    return RolloutResult(
-        training_texts=training_texts,
-        metrics=metrics,
-        latency=latency,
-        dataset_name=problem["dataset"],
-        prompt_tokens=prompt_tokens,
-        output_tokens=output_tokens,
-    )
-
diff --git a/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja b/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja
deleted file mode 100644
index a3bc9f02..00000000
--- a/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja
+++ /dev/null
@@ -1,120 +0,0 @@
-{{- bos_token }}
-{%- if custom_tools is defined %}
-    {%- set tools = custom_tools %}
-{%- endif %}
-{%- if not tools_in_user_message is defined %}
-    {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #}
-    {%- set tools_in_user_message = true %}
-{%- endif %}
-{%- if not date_string is defined %}
-    {%- if strftime_now is defined %}
-        {%- set date_string = strftime_now("%d %b %Y") %}
-    {%- else %}
-        {%- set date_string = "26 Jul 2024" %}
-    {%- endif %}
-{%- endif %}
-{%- if not tools is defined %}
-    {%- set tools = none %}
-{%- endif %}
-
-{#- This block extracts the system message, so we can slot it into the right place. #}
-{%- if messages[0]['role'] == 'system' %}
-    {%- if messages[0]['content'] is string %}
-        {%- set system_message = messages[0]['content']|trim %}
-    {%- else %}
-        {%- set system_message = messages[0]['content'][0]['text']|trim %}
-    {%- endif %}
-    {%- set messages = messages[1:] %}
-{%- else %}
-    {%- if tools is not none %}
-        {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %}
-    {%- else %}
-        {%- set system_message = "" %}
-    {%- endif %}
-{%- endif %}
-
-{#- System message #}
-{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
-{%- if tools is not none %}
-    {{- "Environment: ipython\n" }}
-{%- endif %}
-{{- "Cutting Knowledge Date: December 2023\n" }}
-{{- "Today Date: " + date_string + "\n\n" }}
-{%- if tools is not none and not tools_in_user_message %}
-    {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }}
-    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }}
-    {{- "Do not use variables.\n\n" }}
-    {%- for t in tools %}
-        {{- t | tojson(indent=4) }}
-        {{- "\n\n" }}
-    {%- endfor %}
-{%- endif %}
-{{- system_message }}
-{{- "<|eot_id|>" }}
-
-{#- Custom tools are passed in a user message with some extra guidance #}
-{%- if tools_in_user_message and not tools is none %}
-    {#- Extract the first user message so we can plug it in here #}
-    {%- if messages | length != 0 %}
-        {%- if messages[0]['content'] is string %}
-            {%- set first_user_message = messages[0]['content']|trim %}
-        {%- else %}
-            {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
-        {%- endif %}
-        {%- set messages = messages[1:] %}
-    {%- else %}
-        {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
-    {%- endif %}
-    {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
-    {{- "Given the following functions, please respond with a JSON for a function call " }}
-    {{- "with its proper arguments that best answers the given prompt.\n\n" }}
-    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }}
-    {{- "Do not use variables.\n\n" }}
-    {%- for t in tools %}
-        {{- t | tojson(indent=4) }}
-        {{- "\n\n" }}
-    {%- endfor %}
-    {{- first_user_message + "<|eot_id|>"}}
-{%- endif %}
-
-{%- for message in messages %}
-    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
-        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
-        {%- if message['content'] is string %}
-            {{- message['content'] | trim}}
-        {%- else %}
-            {%- for content in message['content'] %}
-                {%- if content['type'] == 'text' %}
-                    {{- content['text'] | trim }}
-                {%- endif %}
-            {%- endfor %}
-        {%- endif %}
-        {{- '<|eot_id|>' }}
-    {%- elif 'tool_calls' in message %}
-        {%- if not message.tool_calls|length == 1 %}
-            {{- raise_exception("This model only supports single tool-calls at once!") }}
-        {%- endif %}
-        {%- set tool_call = message.tool_calls[0].function %}
-        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
-        {{- '{"name": "' + tool_call.name + '", ' }}
-        {{- '"parameters": ' }}
-        {{- tool_call.arguments | tojson }}
-        {{- "}" }}
-        {{- "<|eot_id|>" }}
-    {%- elif message.role == "tool" or message.role == "ipython" %}
-        {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
-        {%- if message.content is string %}
-            {{- { "output": message.content } | tojson }}
-        {%- else %}
-            {%- for content in message['content'] %}
-                {%- if content['type'] == 'text' %}
-                    {{- { "output": content['text'] } | tojson }}
-                {%- endif %}
-            {%- endfor %}
-        {%- endif %}
-        {{- "<|eot_id|>" }}
-    {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
-    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
-{%- endif %}
\ No newline at end of file
diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py
index 65e29b4b..cd34b54d 100644
--- a/pipelinerl/preprocess.py
+++ b/pipelinerl/preprocess.py
@@ -160,7 +160,18 @@ def preprocess_dataset(
         entry["step_index"] = entry["metadata"]["step_index"]
     if not isinstance(tokenizer.eos_token_id, int):
         raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id")
-    dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
+    try:
+        dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
+    except Exception as e:
+        logger.error(f"Error in populate_rl_data: {e}")
+        logger.error(f"Data: {data}")
+        logger.error(f"Dataset: {dataset}")
+        logger.error(f"Tokenizer: {tokenizer}")
+        logger.error(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
+        logger.error(f"RL config: {rl_config}")
+        logger.error(f"LLM: {llm}")
+        logger.error(f"Seq length: {seq_length}")
+        raise e
     return dataset
diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py
index 2b0a252c..a6467271 100644
--- a/pipelinerl/utils.py
+++ b/pipelinerl/utils.py
@@ -239,6 +239,9 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]:
     if not isinstance(stats, list):
         raise TypeError(f"Expected stats to be a list, got {type(stats)}")
 
+    if len(stats) == 0:
+        return {}
+
     aggregated_stats = {
         "max": float(max(stats)),
         "min": float(min(stats)),
diff --git a/pipelinerl/world.py b/pipelinerl/world.py
index f41714e4..cc23afd0 100644
--- a/pipelinerl/world.py
+++ b/pipelinerl/world.py
@@ -188,7 +188,10 @@ def _place_pipeline_stages(self, cfg):
                 self.add_job(kind="preprocessor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True)
 
     def _place_environments(self, cfg):
-        for worker_idx in range(cfg.world.env_replicas):
+        # Scale environment servers to be the same as llm servers
+        env_replicas_per_actor = getattr(cfg.world, "env_replicas_per_actor", 1)
+        total_env_replicas = cfg.world.replicas * self.llms_per_actor * env_replicas_per_actor
+        for worker_idx in range(total_env_replicas):
            node = self.get_least_busy_node()
             envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"])
             self.add_job(
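
For intuition, the new placement math multiplies three factors. A minimal sketch with hypothetical values follows; the exact way `llms_per_actor` is derived inside `WorldMap` is an assumption here, as it is computed elsewhere in `world.py`:

```python
# Hypothetical sizing check for the new env_replicas_per_actor setting.
# All values below are example inputs, not defaults read from the code.
replicas = 2                 # cfg.world.replicas
llms_per_actor = 4           # assumed to come from the actor GPU allocation in WorldMap
env_replicas_per_actor = 1   # new key in conf/base.yaml (replaces env_replicas)

total_env_replicas = replicas * llms_per_actor * env_replicas_per_actor
print(total_env_replicas)    # -> 8 environment server jobs, spread over the least busy nodes
```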