
Conversation

@HeyDavid633 (Contributor) commented Aug 30, 2025

PR Category

Feature Enhancement

Description

graph_net.torch.test_compiler now supports BladeDISC as a compiler backend: with --compiler "bladedisc", it reads the subgraphs under GraphNet/samples, runs them successfully, and produces correct evaluation results.

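Taking BERT as an example (see "Optimize and Inference BERT with TorchBlade"), the main execution flow is:

  1. Convert the PyTorch model to TorchScript with torch.jit.trace or torch.jit.script.
  2. Compile and optimize it with BladeDISC's torch_blade.optimize, producing compiled_model.
  3. Call the compiled model on the inputs, compiled_model(input), to run the forward pass.
    The compile-and-optimize step, combined with torch.jit.trace or torch.jit.script, can be abstracted as: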
# allow_tracing=True: uses torch.jit.trace(model, inputs)
compiled_model = torch_blade.optimize(model, allow_tracing=True, model_inputs=tuple(inputs))
# allow_tracing=False: uses torch.jit.script(model) (tried in this example)
compiled_model = torch_blade.optimize(model, allow_tracing=False)

This integration uses torch.jit.trace. The official image bladedisc/bladedisc:latest-runtime-torch1.12.0-cu113 is used to obtain compiler performance data quickly. For the detailed technical report, see BladeDISC_tech_report.md.

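Test report:

  • As of 2025-08-30, BladeDISC for torch (import torch_blade) has no category in the existing /samples for which every model fails to run.
  • For all models under /samples/cosyvoice, batch performance results on an A100-SXM-40GB GPU are in BladeDISC_batch_test.txt.
  • One model from each category under /samples was tested; the report is in BladeDISC_validation_report.txt. A performance summary follows: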
| Model | Eager (ms) | Compiled (ms) |
| --- | --- | --- |
| cosyvoice/CosyVoice-300M | 8.4000 | 8.3600 |
| mmpose/2xmspn_50 | 17.1000 | 14.1000 |
| mmseg/ANN_R50 | 21.7000 | 21.8000 |
| nemo/parakeet-ctc-0.6b | 55.3000 | 54.4000 |
| torchaudio/convtasnet_base_libri2mix | 99.4000 | 99.6000 |
| torchgeometric/LINKX | 1.0300 | 0.7280 |
| timm/darknet17 | 2.1500 | 2.1300 |
| torchvision/deeplabv3_resnet50 | 8.4300 | 7.6200 |
| transformers-auto-model/hf-tiny-model-private_tiny-random-AltCLIPModel | 6.0000 | 4.4200 |
| ultralytics/yolo11l-cls | 17.6000 | 14.8000 |
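
The latencies above can be turned into per-model speedups (eager time divided by compiled time). The snippet below is only arithmetic on the reported numbers; the names LATENCIES_MS and speedup are illustrative and not part of GraphNet.

```python
# Latencies in milliseconds, copied from the summary above: (eager, compiled).
LATENCIES_MS = {
    "cosyvoice/CosyVoice-300M": (8.40, 8.36),
    "mmpose/2xmspn_50": (17.1, 14.1),
    "mmseg/ANN_R50": (21.7, 21.8),
    "nemo/parakeet-ctc-0.6b": (55.3, 54.4),
    "torchaudio/convtasnet_base_libri2mix": (99.4, 99.6),
    "torchgeometric/LINKX": (1.03, 0.728),
    "timm/darknet17": (2.15, 2.13),
    "torchvision/deeplabv3_resnet50": (8.43, 7.62),
    "transformers-auto-model/hf-tiny-model-private_tiny-random-AltCLIPModel": (6.00, 4.42),
    "ultralytics/yolo11l-cls": (17.6, 14.8),
}


def speedup(eager_ms, compiled_ms):
    """Return eager / compiled; values above 1.0 mean the compiled model is faster."""
    return eager_ms / compiled_ms


speedups = {name: speedup(*pair) for name, pair in LATENCIES_MS.items()}

# Model with the largest gain, and models that got slower after compilation.
fastest = max(speedups, key=speedups.get)
slower = sorted(name for name, s in speedups.items() if s < 1.0)

for name, s in sorted(speedups.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {s:.3f}x")
```

On these numbers the largest gain is on torchgeometric/LINKX (about 1.41x), while mmseg/ANN_R50 and torchaudio/convtasnet_base_libri2mix are marginally slower after compilation.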

@paddle-bot (bot) commented Aug 30, 2025

Thanks for your contribution!

}


class BladeDISCBackend(GraphCompilerBackend):
Collaborator

Move the base class GraphCompilerBackend into its own file, graph_compiler_backend.py.
Implement BladeDISCBackend in blade_disc_backend.py.

import torch

try:
    import torch_tensorrt
Collaborator

This code should not live in the base class; it is not something the base class should care about.

Comment on lines 56 to 59
if cls == InductorBackend:
    return InductorBackend()
elif cls == TensorRTBackend:
    return TensorRTBackend()
Collaborator

Revert this.
This used to be a factory pattern, and it has been turned into this kind of low-quality code.
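
For illustration, a registry-style factory along the lines the reviewer asks for might look like the sketch below. The class names and the registry_backend dict appear in this PR, but the class bodies here are placeholder stubs, and get_compiler_backend is a hypothetical helper, not necessarily the function used in GraphNet.

```python
class GraphCompilerBackend:
    """Minimal stand-in for the base class (the real one has more methods)."""

    def __call__(self, model):
        raise NotImplementedError


class InductorBackend(GraphCompilerBackend):
    def __call__(self, model):
        return model  # placeholder: the real backend would compile here


class TensorRTBackend(GraphCompilerBackend):
    def __call__(self, model):
        return model  # placeholder


class BladeDISCBackend(GraphCompilerBackend):
    def __call__(self, model):
        return model  # placeholder


# A single table maps the --compiler string to a backend instance, so adding
# a backend means adding one entry instead of another if/elif branch.
registry_backend = {
    "inductor": InductorBackend(),
    "tensorrt": TensorRTBackend(),
    "bladedisc": BladeDISCBackend(),
}


def get_compiler_backend(name):
    """Hypothetical lookup helper: return the backend registered under name."""
    try:
        return registry_backend[name]
    except KeyError:
        known = ", ".join(sorted(registry_backend))
        raise ValueError(f"Unknown compiler {name!r}; expected one of: {known}") from None
```

With this shape, the calling code never compares classes against each other; it only looks up the name it was given.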

    import torch_blade
except ImportError:
    torch_blade = None

Collaborator

class BladeDISCCompiledModule(torch.nn.Module):

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.counter = 0

    def forward(self, *args, **kwargs):
        if self.counter == 0:
            self.module = self.compile(self.module, *args, **kwargs)
        ret = self.module(*args, **kwargs)
        self.counter += 1
        return ret

    def compile(self, module, *args, **kwargs):
        dummy_input = tuple([*args, *kwargs.values()])
        return torch_blade.optimize(
                module, allow_tracing=True, model_inputs=dummy_input
        )
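
The suggestion above defers compilation to the first forward call because tracing needs concrete example inputs, and those only become available when the module is first invoked. The same compile-on-first-call pattern can be shown without torch or torch_blade; in this sketch LazyCompiled and fake_compile are hypothetical stand-ins, with fake_compile playing the role of torch_blade.optimize.

```python
class LazyCompiled:
    """Compile the wrapped callable on the first call, then reuse the result."""

    def __init__(self, fn, compile_fn):
        self.fn = fn
        self.compile_fn = compile_fn
        self.compiled = False

    def __call__(self, *args, **kwargs):
        if not self.compiled:
            # Flatten positional and keyword arguments into one tuple of
            # example inputs, mirroring dummy_input in the suggestion above.
            example_inputs = (*args, *kwargs.values())
            self.fn = self.compile_fn(self.fn, example_inputs)
            self.compiled = True
        return self.fn(*args, **kwargs)


compile_calls = []


def fake_compile(fn, example_inputs):
    """Stand-in compiler: record the example inputs and return fn unchanged."""
    compile_calls.append(example_inputs)
    return fn


add = LazyCompiled(lambda a, b: a + b, fake_compile)
first = add(1, b=2)   # triggers the one-time compile step
second = add(3, b=4)  # reuses the already compiled callable
```

One consequence of flattening kwargs by value is that the example inputs depend on keyword order, so callers would need to pass keyword arguments in a consistent order.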

Comment on lines 11 to 12
def __init__(self, input_dict):
    self.input_dict = input_dict
Collaborator

Remove this.

torch.cuda.synchronize()


registry_backend = {
Collaborator

Add this back.


registry_backend = {
    "inductor": InductorBackend(),
    "tensorrt": TensorRTBackend(),
Collaborator

Add "bladedisc": BladeDISCBackend() here.

    self.input_dict = input_dict

def __call__(self, model):
    torch_config = torch_blade.config.Config()
Collaborator

return BladeDISCCompiledModule(model)

Comment on lines 17 to 20
from .graph_compiler_backend import GraphCompilerBackend
from .inductor_backend import InductorBackend
from .tensorrt_backend import TensorRTBackend
from .blade_disc_backend import BladeDISCBackend
Collaborator

Prefer absolute import paths here.

@fangfangssj (Contributor) commented Sep 1, 2025

] = f"TensorRT {torch_tensorrt.__version__}"
Could you also add the BladeDISC version information somewhere near here?
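
One possible way to report a backend version without failing when the package is absent is sketched below. The helper name describe_backend is hypothetical, and reading a top-level __version__ attribute is an assumption: torch_blade may expose its version under a different attribute, in which case the helper falls back to "unknown".

```python
import importlib


def describe_backend(display_name, module_name):
    """Return '<display_name> <version>', or a note that the module is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return f"{display_name} not installed"
    # Assumption: the package exposes a top-level __version__ string.
    version = getattr(module, "__version__", "unknown")
    return f"{display_name} {version}"


print(describe_backend("BladeDISC", "torch_blade"))
```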

@lixinqi lixinqi merged commit d82ce76 into PaddlePaddle:develop Sep 1, 2025
3 checks passed
@HeyDavid633 HeyDavid633 changed the title [Feature Enhancement] add BladeDISC for compiler backend 【Hackathon 9th No.100】[Feature Enhancement] add BladeDISC for compiler backend Sep 3, 2025
JewelRoam pushed a commit to JewelRoam/GraphNet that referenced this pull request Oct 29, 2025
…e#242)

* [Feature Enhancement] add BladeDISC for compiler backend

* Fix test_compiler

* Fix test_compiler

* Fix test_compiler and blade_disc_backend

* update bladedisc version info