Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/scripts/linux-post-script.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#!/bin/bash

yum update gcc
yum update libstdc++
if [ "$(uname)" != "Darwin" ]; then
yum update gcc
yum update libstdc++
else
brew update
brew upgrade gcc
fi
18 changes: 18 additions & 0 deletions .github/scripts/linux-pre-script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

if [ "$(uname)" != "Darwin" ]; then
yum update gcc
yum update libstdc++
else
echo $(gcc --version)
echo $(clang --version)
brew update
brew upgrade gcc
brew upgrade clang

# For OSX
# export CXXFLAGS="-march=armv8-a+fp16+sha3"
export CMAKE_OSX_ARCHITECTURES=arm64
fi

${CONDA_RUN} conda install -c conda-forge pybind11 -y
8 changes: 8 additions & 0 deletions .github/scripts/version_script.bat
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
set TENSORDICT_BUILD_VERSION=0.8.0
echo TENSORDICT_BUILD_VERSION is set to %TENSORDICT_BUILD_VERSION%

if "%CONDA_RUN%"=="" (
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
exit /b 1
)

:: Run the pip install command
%CONDA_RUN% conda install -c conda-forge pybind11 -y

@echo on

set VC_VERSION_LOWER=17
Expand Down
17 changes: 17 additions & 0 deletions .github/scripts/version_script.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
#!/bin/bash

export TENSORDICT_BUILD_VERSION=0.8.0

if [ "$(uname)" == "Darwin" ]; then
# For OSX
echo $(gcc --version)
echo $(clang --version)
brew update
brew install gcc
brew install clang-build-analyzer
brew install --cask clay
brew install llvm
# brew upgrade gcc
# brew upgrade clang
# export CXXFLAGS="-march=armv8-a+fp16+sha3"
export CMAKE_OSX_ARCHITECTURES=arm64
fi

${CONDA_RUN} conda install -c conda-forge pybind11 -y
17 changes: 17 additions & 0 deletions .github/scripts/win-pre-script.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@echo off
:: Check if CONDA_RUN is set, if not, set it to a default value
if "%CONDA_RUN%"=="" (
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
exit /b 1
)

:: Run the pip install command
%CONDA_RUN% conda install -c conda-forge pybind11 -y

:: Check if the installation was successful
if errorlevel 1 (
echo Failed to install cmake and pybind11.
exit /b 1
) else (
echo Successfully installed cmake and pybind11.
)
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ dependencies:
- coverage
- h5py
- orjson
- ninja
- numpy<2.0.0
3 changes: 3 additions & 0 deletions .github/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda install anaconda::cmake -y
conda install -c conda-forge pybind11 -y

#if [[ $OSTYPE == 'darwin'* ]]; then
# printf "* Installing C++ for OSX\n"
# conda install -c conda-forge cxx-compiler -y
Expand Down
1 change: 1 addition & 0 deletions .github/unittest/rl_linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ dependencies:
- pyyaml
- scipy
- orjson
- ninja
- numpy<2.0.0
3 changes: 3 additions & 0 deletions .github/unittest/rl_linux_optdeps/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda install anaconda::cmake -y
conda install -c conda-forge pybind11 -y

#yum makecache
#yum -y install glfw-devel
#yum -y install libGLEW
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
include:
- repository: pytorch/tensordict
smoke-test-script: test/smoke_test.py
pre-script: .github/scripts/linux-pre-script.sh
post-script: .github/scripts/linux-post-script.sh
package-name: tensordict
name: pytorch/tensordict
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/build-wheels-m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ jobs:
include:
- repository: pytorch/tensordict
smoke-test-script: test/smoke_test.py
env-script: .github/scripts/install-deps-smoke-test.sh
pre-script: .github/scripts/linux-pre-script.sh
post-script: .github/scripts/linux-post-script.sh
package-name: tensordict
name: pytorch/tensordict
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main
Expand All @@ -44,7 +45,7 @@ jobs:
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
package-name: ${{ matrix.package-name }}
runner-type: macos-m1-stable
runner-type: macos-m2-15
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/version_script.sh
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
matrix:
include:
- repository: pytorch/tensordict
pre-script: ""
pre-script: .github/scripts/win-pre-script.bat
env-script: .github/scripts/version_script.bat
post-script: "python packaging/wheel/relocate.py"
smoke-test-script: test/smoke_test.py
Expand Down
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ possible.
Install the library as suggested in the README. For advanced features,
it is preferable to install the nightly built of pytorch.

You will need the following packages to be installed:
```bash
pip install ninja cmake pybind11 -U
```

Make sure you install tensordict in develop mode by running
```
python setup.py develop
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ include LICENSE

recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include tensordict *.so
103 changes: 56 additions & 47 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import argparse
import distutils.command.clean
import glob
import logging
import os
import shutil
Expand All @@ -15,11 +14,23 @@
from pathlib import Path
from typing import List

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

ROOT_DIR = Path(__file__).parent.resolve()


def get_python_executable():
# Check if we're running in a virtual environment
if "VIRTUAL_ENV" in os.environ:
# Get the virtual environment's Python executable
python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python")
else:
# Fall back to sys.executable
python_executable = sys.executable
return python_executable


try:
sha = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=ROOT_DIR)
Expand Down Expand Up @@ -69,7 +80,7 @@ def _get_pytorch_version(is_nightly, is_local):
return "torch>=2.7.0.dev"
if is_local:
return "torch"
return "torch>=2.6.0"
return "torch>=2.5.0"


def _get_packages():
Expand Down Expand Up @@ -99,51 +110,45 @@ def run(self):
shutil.rmtree(str(path), ignore_errors=True)


def get_extensions():
extension = CppExtension

extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3",
"-std=c++17",
"-fdiagnostics-color=always",
class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
super().__init__(name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
def run(self):
for ext in self.extensions:
self.build_extension(ext)

def build_extension(self, ext):
# extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
is_develop = self.distribution.get_command_obj("develop").finalized
# Set the output directory based on the mode
if is_develop:
extdir = os.path.abspath(os.path.join(ROOT_DIR, "tensordict"))
else:
extdir = os.path.abspath(os.path.join(self.build_lib, "tensordict"))
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPYTHON_EXECUTABLE={get_python_executable()}",
f"-DPython3_EXECUTABLE={get_python_executable()}",
]
}
debug_mode = os.getenv("DEBUG", "0") == "1"
if debug_mode:
logging.info("Compiling in debug mode")
extra_compile_args = {
"cxx": [
"-O0",
"-fno-inline",
"-g",
"-std=c++17",
"-fdiagnostics-color=always",
]
}
extra_link_args = ["-O0", "-g"]

this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "tensordict", "csrc")

extension_sources = {
os.path.join(extensions_dir, p)
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
}
sources = list(extension_sources)

ext_modules = [
extension(
"tensordict._C",
sources,
include_dirs=[this_dir],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,

build_args = []
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
subprocess.check_call(
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
)
subprocess.check_call(
["cmake", "--build", "."] + build_args, cwd=self.build_temp
)
]

return ext_modules

def get_extensions():
extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc")
return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)]


def _main(argv):
Expand Down Expand Up @@ -181,7 +186,7 @@ def _main(argv):
),
ext_modules=get_extensions(),
cmdclass={
"build_ext": BuildExtension.with_options(),
"build_ext": CMakeBuild,
"clean": clean,
},
install_requires=[
Expand Down Expand Up @@ -212,6 +217,10 @@ def _main(argv):
"Programming Language :: Python :: 3.13",
"Development Status :: 4 - Beta",
],
# include_package_data=True,
package_data={
"tensordict": ["*.so", "*.pyd"],
},
)


Expand Down
34 changes: 34 additions & 0 deletions tensordict/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.12)
project(tensordict)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Set the Python executable to the one from your virtual environment

find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
find_package(pybind11 2.13 REQUIRED)

file(GLOB SOURCES "*.cpp")

add_library(_C MODULE ${SOURCES})

set_target_properties(_C PROPERTIES
OUTPUT_NAME "_C"
PREFIX "" # Remove 'lib' prefix
SUFFIX ".so" # Ensure correct suffix for macOS/Linux
)
set_target_properties(_C PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}"
)

target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
target_link_libraries(_C PRIVATE Python3::Python pybind11::module)

if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -fsanitize=address")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
endif()
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum OS X deployment version")
set(CMAKE_VERBOSE_MAKEFILE ON)
Loading