Skip to content

Commit 8f23ac0

Browse files
Merge pull request #2137 from j2kun:poly-approx-subpipeilne
PiperOrigin-RevId: 799153545
2 parents 29b93c6 + e6e6943 commit 8f23ac0

File tree

7 files changed

+60
-12
lines changed

7 files changed

+60
-12
lines changed

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@
3535
#include "lib/Transforms/LayoutOptimization/LayoutOptimization.h"
3636
#include "lib/Transforms/LayoutPropagation/LayoutPropagation.h"
3737
#include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h"
38-
#include "lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.h"
3938
#include "lib/Transforms/OperationBalancer/OperationBalancer.h"
4039
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
41-
#include "lib/Transforms/PolynomialApproximation/PolynomialApproximation.h"
4240
#include "lib/Transforms/PopulateScale/PopulateScale.h"
4341
#include "lib/Transforms/PropagateAnnotation/PropagateAnnotation.h"
4442
#include "lib/Transforms/SecretInsertMgmt/Passes.h"
@@ -127,10 +125,6 @@ void mlirToSecretArithmeticPipelineBuilder(
127125
convertToDataObliviousPipelineBuilder(pm);
128126
pm.addPass(createSelectRewrite());
129127
pm.addPass(createCompareToSignRewrite());
130-
pm.addPass(createPolynomialApproximation());
131-
pm.addPass(createLowerPolynomialEval());
132-
pm.addPass(createCanonicalizerPass());
133-
pm.addPass(createCSEPass());
134128

135129
// Simplify linalg ops for kernel lowering
136130
// Linalg canonicalization
@@ -154,6 +148,8 @@ void mlirToSecretArithmeticPipelineBuilder(
154148
pm.addPass(
155149
createConvertToCiphertextSemantics(convertToCiphertextSemanticsOptions));
156150

151+
mathToPolynomialApproximationBuilder(pm);
152+
157153
// Balance Operations
158154
pm.addPass(createOperationBalancer());
159155

lib/Pipelines/BUILD

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ cc_library(
1818
"@heir//lib/Dialect/Polynomial/Conversions/PolynomialToModArith",
1919
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
2020
"@heir//lib/Dialect/TOSA/Conversions/TosaToSecretArith",
21-
"@heir//lib/Transforms/CompareToSignRewrite",
2221
"@heir//lib/Transforms/ConvertIfToSelect",
2322
"@heir//lib/Transforms/ConvertSecretExtractToStaticExtract",
2423
"@heir//lib/Transforms/ConvertSecretForToStaticFor",
@@ -29,6 +28,7 @@ cc_library(
2928
"@heir//lib/Transforms/MemrefToArith:ExpandCopy",
3029
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration",
3130
"@heir//lib/Transforms/PolynomialApproximation",
31+
"@llvm-project//llvm:Support",
3232
"@llvm-project//mlir:AffineToStandard",
3333
"@llvm-project//mlir:AffineTransforms",
3434
"@llvm-project//mlir:ArithTransforms",
@@ -131,11 +131,9 @@ cc_library(
131131
"@heir//lib/Transforms/LayoutOptimization",
132132
"@heir//lib/Transforms/LayoutPropagation",
133133
"@heir//lib/Transforms/LinalgCanonicalizations",
134-
"@heir//lib/Transforms/LowerPolynomialEval",
135134
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration",
136135
"@heir//lib/Transforms/OperationBalancer",
137136
"@heir//lib/Transforms/OptimizeRelinearization",
138-
"@heir//lib/Transforms/PolynomialApproximation",
139137
"@heir//lib/Transforms/PopulateScale",
140138
"@heir//lib/Transforms/PropagateAnnotation",
141139
"@heir//lib/Transforms/SecretInsertMgmt",

lib/Pipelines/PipelineRegistration.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
#include "lib/Transforms/ConvertSecretInsertToStaticInsert/ConvertSecretInsertToStaticInsert.h"
99
#include "lib/Transforms/ConvertSecretWhileToStaticFor/ConvertSecretWhileToStaticFor.h"
1010
#include "lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.h"
11+
#include "lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.h"
1112
#include "lib/Transforms/MemrefToArith/MemrefToArith.h"
13+
#include "lib/Transforms/PolynomialApproximation/PolynomialApproximation.h"
1214
#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
1315
#include "mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project
1416
#include "mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" // from @llvm-project
@@ -91,6 +93,13 @@ void tosaPipelineBuilder(OpPassManager& manager, bool unroll) {
9193
manager.addPass(createSymbolDCEPass());
9294
}
9395

96+
void mathToPolynomialApproximationBuilder(OpPassManager& pm) {
97+
pm.addPass(createPolynomialApproximation());
98+
pm.addPass(createLowerPolynomialEval());
99+
pm.addPass(createCanonicalizerPass());
100+
pm.addPass(createCSEPass());
101+
}
102+
94103
void polynomialToLLVMPipelineBuilder(OpPassManager& manager) {
95104
// Poly
96105
manager.addPass(createElementwiseToAffine());

lib/Pipelines/PipelineRegistration.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
#ifndef LIB_PIPELINES_PIPELINEREGISTRATION_H_
22
#define LIB_PIPELINES_PIPELINEREGISTRATION_H_
33

4-
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
5-
#include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
6-
#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project
4+
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
5+
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
6+
#include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
7+
#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project
78

89
namespace mlir::heir {
910

1011
void tosaToLinalg(OpPassManager& manager);
1112

1213
void oneShotBufferize(OpPassManager& manager);
1314

15+
void mathToPolynomialApproximationBuilder(OpPassManager& pm);
16+
1417
struct TosaToArithOptions : public PassPipelineOptions<TosaToArithOptions> {
1518
PassOptions::Option<bool> unroll{*this, "full-unroll",
1619
llvm::cl::desc("Full unroll all loops."),
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
load("//bazel:lit.bzl", "glob_lit_tests")
2+
3+
package(default_applicable_licenses = ["@heir//:license"])
4+
5+
glob_lit_tests(
6+
name = "all_tests",
7+
data = ["@heir//tests:test_utilities"],
8+
driver = "@heir//tests:run_lit.sh",
9+
test_file_exts = ["mlir"],
10+
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: heir-opt --math-to-polynomial-approximation %s | FileCheck %s
2+
3+
// CHECK: @test_maximumf
4+
func.func @test_maximumf(%x: tensor<10xf32>) -> tensor<10xf32> {
5+
// CHECK-NOT: arith.maximumf
6+
// CHECK-NOT: polynomial.eval
7+
8+
// CHECK-DAG: arith.constant dense<0.03831
9+
// CHECK-DAG: arith.constant dense<5.0{{0*}}e-01> : tensor<10xf32>
10+
// CHECK-DAG: arith.constant dense<0.9370
11+
// CHECK-DAG: arith.constant dense<-0.5062
12+
// CHECK: arith.mulf
13+
// CHECK: arith.addf
14+
// CHECK: arith.mulf
15+
// CHECK: arith.addf
16+
// CHECK: arith.mulf
17+
// CHECK: arith.addf
18+
// CHECK: arith.mulf
19+
// CHECK: arith.addf
20+
// CHECK: arith.mulf
21+
// CHECK: arith.addf
22+
// CHECK: return
23+
%c0 = arith.constant dense<0.0> : tensor<10xf32>
24+
%0 = arith.maximumf %x, %c0 : tensor<10xf32>
25+
return %0 : tensor<10xf32>
26+
}

tools/heir-opt.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,12 @@ int main(int argc, char** argv) {
457457
"Transforms a native program to data-oblivious program",
458458
convertToDataObliviousPipelineBuilder);
459459

460+
PassPipelineRegistration<>(
461+
"math-to-polynomial-approximation",
462+
"Approximate math operations that cannot be expressed in FHE using "
463+
"polynomial approximations.",
464+
mathToPolynomialApproximationBuilder);
465+
460466
return asMainReturnCode(
461467
MlirOptMain(argc, argv, "HEIR Pass Driver", registry));
462468
}

0 commit comments

Comments
 (0)