Efficient World Models with Context-Aware Tokenization
Vincent Micheli*, Eloi Alonso*, François Fleuret
TL;DR Δ-IRIS is a reinforcement learning agent trained in the imagination of its world model.
(Demo video: delta-iris.mp4)
- Pin the pip version: `pip install pip==23.0`
- Install dependencies: `pip install -r requirements.txt`
- Warning: Atari ROMs will be downloaded with the Atari dependencies, which means that you acknowledge that you have the license to use them.
Crafter: `python src/main.py`

The run will be located in `outputs/YYYY-MM-DD/hh-mm-ss/`. By default, logs are synced to Weights & Biases; set `wandb.mode=disabled` to turn logging off.
Atari: `python src/main.py env=atari params=atari env.train.id=BreakoutNoFrameskip-v4`

Note that this Atari configuration achieves slightly higher aggregate metrics than those reported in the paper. See the updated table of results.
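These launch commands accept additional Hydra-style `key=value` overrides on the command line; for instance, combining the options already mentioned above:

```bash
# Crafter run with W&B logging turned off
python src/main.py wandb.mode=disabled

# Atari run on Breakout, also without logging
python src/main.py env=atari params=atari env.train.id=BreakoutNoFrameskip-v4 wandb.mode=disabled
```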
- All configuration files are located in `config/`; the main configuration file is `config/trainer.yaml`.
- The simplest way to customize the configuration is to edit these files directly.
- Please refer to Hydra for more details regarding configuration management.
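Since the configuration is composed by Hydra, you can also inspect the fully resolved configuration without launching a run by using Hydra's standard `--cfg` flag. This is a generic Hydra feature rather than something documented for this codebase, so it assumes `src/main.py` does not intercept the flag:

```bash
# Print the composed configuration built from config/trainer.yaml
python src/main.py --cfg job

# Print it with overrides applied, e.g. the Atari settings used above
python src/main.py env=atari params=atari --cfg job
```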
Each new run is located in `outputs/YYYY-MM-DD/hh-mm-ss/`. This folder is structured as:
outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│ │ last.pt
│ │ optimizer.pt
│ │ ...
│ │
│ └── dataset
│ │
│ └─ train
│ │ info.pt
│ │ ...
│ │
│ └─ test
│ │ info.pt
│ │ ...
│
└─── config
│ │ trainer.yaml
│ │ ...
│
└─── media
│ │
│ └── episodes
│ │ ...
│ │
│ └── reconstructions
│ │ ...
│
└─── scripts
│ │ resume.sh
│ │ play.sh
│
└─── src
│ │ main.py
│ │ ...
│
└─── wandb
│       ...

- `checkpoints`: contains the last checkpoint of the model, its optimizer and the dataset.
- `media`:
  - `episodes`: contains train / test episodes for visualization purposes.
  - `reconstructions`: contains original frames alongside their reconstructions with the autoencoder.
- `scripts`: from the run folder, you can use the following scripts (see the example after this list).
  - `resume.sh`: Launch `./scripts/resume.sh` to resume a training run that crashed.
  - `play.sh`: Tool to visualize the agent and interact with the world model.
    - Launch `./scripts/play.sh` to watch the agent play live in the environment.
    - Launch `./scripts/play.sh -w` to play live in the world model. Note that for faster interaction, the memory of the world model is flushed after a few seconds.
    - Launch `./scripts/play.sh -a` to watch the agent play live in the world model. Note that for faster interaction, the memory of the world model is flushed after a few seconds.
    - Launch `./scripts/play.sh -e` to visualize the episodes contained in `media/episodes`.
    - Add the flag `-h` to display a header with additional information.
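Putting the run folder layout and these scripts together, a typical session might look like this (the timestamped folder name is the placeholder used throughout this section):

```bash
cd outputs/YYYY-MM-DD/hh-mm-ss/   # enter a run folder created by a previous launch

./scripts/resume.sh               # either: resume a training run that crashed
./scripts/play.sh -w              # or: play live in the world model
```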
An agent checkpoint (Crafter 5M frames) is available on the Hugging Face Hub.
To visualize the agent or play in its world model:
- Download the checkpoint `last.pt`.
- Create a `checkpoints` directory.
- Copy the checkpoint to `checkpoints/last.pt`.
- Run `./scripts/play.sh` with the flags of your choice, as described above (the steps are condensed in the snippet below).
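Condensed into a shell session (the download path is a placeholder, and the commands assume you are at the repository root so that `checkpoints/` and `./scripts/play.sh` resolve as described above):

```bash
# last.pt is assumed to have been downloaded from the Hugging Face Hub already;
# /path/to/last.pt is a placeholder for wherever it was saved.
mkdir -p checkpoints
cp /path/to/last.pt checkpoints/last.pt
./scripts/play.sh -a   # e.g. watch the agent play live in the world model
```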
If you find this code or paper useful, please use the following reference:
@inproceedings{
micheli2024efficient,
title={Efficient World Models with Context-Aware Tokenization},
author={Vincent Micheli and Eloi Alonso and François Fleuret},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=BiWIERWBFX}
}
- https://github.com/pytorch/pytorch
- https://github.com/karpathy/minGPT
- You might also want to check out our codebase for IRIS