99import shutil
1010import subprocess
1111import sys
12+ import warnings
1213from pathlib import Path
1314
14- from setuptools import Extension , find_packages , setup
15+ from setuptools import Command , Extension , find_packages , setup
1516from setuptools .command .build_ext import build_ext
1617
1718ROOT_DIR = Path (__file__ ).parent .resolve ()
@@ -28,11 +29,38 @@ def get_python_executable():
2829 return python_executable
2930
3031
31- class clean (distutils .command .clean .clean ):
32- def run (self ):
33- # Run default behavior first
34- distutils .command .clean .clean .run (self )
32+ def check_cmake_version ():
33+ """Check if CMake version is sufficient."""
34+ try :
35+ result = subprocess .run (
36+ ["cmake" , "--version" ], capture_output = True , text = True , check = True
37+ )
38+ version_line = result .stdout .split ("\n " )[0 ]
39+ version_str = version_line .split ()[2 ]
40+ major , minor = map (int , version_str .split ("." )[:2 ])
41+ if major < 3 or (major == 3 and minor < 18 ):
42+ warnings .warn (
43+ f"CMake version { version_str } may be too old. Recommended: 3.18+"
44+ )
45+ return True
46+ except (subprocess .CalledProcessError , FileNotFoundError , ValueError ):
47+ warnings .warn ("Could not determine CMake version" )
48+ return False
49+
50+
51+ class clean (Command ):
52+ """Custom clean command to remove tensordict extensions."""
53+
54+ description = "remove tensordict extensions and build files"
55+ user_options = []
56+
57+ def initialize_options (self ):
58+ pass
59+
60+ def finalize_options (self ):
61+ pass
3562
63+ def run (self ):
3664 # Remove tensordict extension
3765 for path in (ROOT_DIR / "tensordict" ).glob ("**/*.so" ):
3866 logging .info (f"removing '{ path } '" )
@@ -53,6 +81,9 @@ def __init__(self, name, sourcedir=""):
5381
5482class CMakeBuild (build_ext ):
5583 def run (self ):
84+ # Check CMake version before building
85+ check_cmake_version ()
86+
5687 for ext in self .extensions :
5788 self .build_extension (ext )
5889
@@ -64,6 +95,7 @@ def build_extension(self, ext):
6495 else :
6596 # For regular installs, place the extension in the build directory
6697 extdir = os .path .abspath (os .path .join (self .build_lib , "tensordict" ))
98+
6799 cmake_args = [
68100 f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={ extdir } " ,
69101 f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={ extdir } " ,
@@ -78,12 +110,21 @@ def build_extension(self, ext):
78110 os .makedirs (self .build_temp )
79111 if sys .platform == "win32" :
80112 build_args += ["--config" , "Release" ]
81- subprocess .check_call (
82- ["cmake" , ext .sourcedir ] + cmake_args , cwd = self .build_temp
83- )
84- subprocess .check_call (
85- ["cmake" , "--build" , "." , "--verbose" ] + build_args , cwd = self .build_temp
86- )
113+
114+ try :
115+ subprocess .check_call (
116+ ["cmake" , ext .sourcedir ] + cmake_args , cwd = self .build_temp
117+ )
118+ subprocess .check_call (
119+ ["cmake" , "--build" , "." , "--verbose" ] + build_args , cwd = self .build_temp
120+ )
121+ except subprocess .CalledProcessError as e :
122+ warnings .warn (
123+ f"Error building extension: { e } \n "
124+ "This might be due to missing dependencies or incompatible compiler. "
125+ "Please ensure you have CMake 3.18+ and a C++17 compatible compiler."
126+ )
127+ raise
87128
88129
89130def get_extensions ():
0 commit comments