Skip to content

Conversation

@fangfangssj
Copy link
Contributor

@fangfangssj fangfangssj commented Sep 4, 2025

PR Category

Feature Enhancement

Description

graph_net.torch.test_compiler支持后端使用Xla编译器
需要配置 --compiler "xla",同时需要设置环境变量PJRT_DEVICE来指定cuda或者cpu(不设置的话自动检测,优先使用cuda)

测试版本 torch_xla - 2.7 (CUDA 12.6 + Python 3.11)whl包,torch_xla 的版本和torch的版本是一一对应的(CUDA版本也是需要对应的,官方编译的现在大部分都是CPU的),在最新版本的xla中,已经移除了对CUDA的支持,只支持TPU,CPU后端

对模型中创建张量时硬编码的设备处理上,使用抽象语法树来替代字符串提取替换,同时新增device参数
对于xla编译器,需要把所有张量都创建在xla设备上,xla仅支持转化到cpu上
对比速度的时候,前向运行在cuda上,编译过的模型运行在xla(实际设备为cuda)上

@paddle-bot
Copy link

paddle-bot bot commented Sep 4, 2025

Thanks for your contribution!

@lixinqi lixinqi merged commit c0e7b6d into PaddlePaddle:develop Sep 5, 2025
3 checks passed
JewelRoam pushed a commit to JewelRoam/GraphNet that referenced this pull request Oct 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants