77use std:: autodiff:: autodiff;
88
99#[ autodiff( d_square3, Forward , Dual , DualOnly ) ]
10- #[ no_mangle]
11- fn squaref ( x : & f32 ) -> f32 {
12- 2.0 * x * x
13- }
14-
1510#[ autodiff( d_square2, Forward , 4 , Dual , DualOnly ) ]
16- #[ autodiff( d_square , Forward , 4 , Dual , Dual ) ]
11+ #[ autodiff( d_square1 , Forward , 4 , Dual , Dual ) ]
1712#[ no_mangle]
1813fn square ( x : & f32 ) -> f32 {
1914 x * x
2015}
2116
22- // CHECK:define internal fastcc void @diffe4square([4 x ptr] %"x'"
23- // CHECK-NEXT:invertstart:
24- // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
25- // CHECK-NEXT: %1 = load double, ptr %0, align 8, !alias.scope !15950, !noalias !15953
26- // CHECK-NEXT: %2 = fadd fast double %1, 6.000000e+00
27- // CHECK-NEXT: store double %2, ptr %0, align 8, !alias.scope !15950, !noalias !15953
28- // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 1
29- // CHECK-NEXT: %4 = load double, ptr %3, align 8, !alias.scope !15958, !noalias !15959
30- // CHECK-NEXT: %5 = fadd fast double %4, 6.000000e+00
31- // CHECK-NEXT: store double %5, ptr %3, align 8, !alias.scope !15958, !noalias !15959
32- // CHECK-NEXT: %6 = extractvalue [4 x ptr] %"x'", 2
33- // CHECK-NEXT: %7 = load double, ptr %6, align 8, !alias.scope !15960, !noalias !15961
34- // CHECK-NEXT: %8 = fadd fast double %7, 6.000000e+00
35- // CHECK-NEXT: store double %8, ptr %6, align 8, !alias.scope !15960, !noalias !15961
36- // CHECK-NEXT: %9 = extractvalue [4 x ptr] %"x'", 3
37- // CHECK-NEXT: %10 = load double, ptr %9, align 8, !alias.scope !15962, !noalias !15963
38- // CHECK-NEXT: %11 = fadd fast double %10, 6.000000e+00
39- // CHECK-NEXT: store double %11, ptr %9, align 8, !alias.scope !15962, !noalias !15963
40- // CHECK-NEXT: ret void
41- // CHECK-NEXT:}
17+ // d_square2
18+ // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
19+ // CHECK-NEXT: start:
20+ // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
21+ // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !38, !noalias !39
22+ // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
23+ // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !40, !noalias !41
24+ // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
25+ // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !42, !noalias !43
26+ // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
27+ // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !44, !noalias !45
28+ // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
29+ // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
30+ // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
31+ // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
32+ // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
33+ // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
34+ // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
35+ // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
36+ // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
37+ // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
38+ // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
39+ // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
40+ // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
41+ // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
42+ // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
43+ // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
44+ // CHECK-NEXT: ret [4 x float] %19
45+ // CHECK-NEXT: }
46+
47+ // d_square1, the extra float is the original return value (x * x)
48+ // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
49+ // CHECK-NEXT: start:
50+ // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
51+ // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !46, !noalias !47
52+ // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
53+ // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !48, !noalias !49
54+ // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
55+ // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !50, !noalias !51
56+ // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
57+ // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !52, !noalias !53
58+ // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
59+ // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
60+ // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
61+ // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
62+ // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
63+ // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
64+ // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
65+ // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
66+ // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
67+ // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
68+ // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
69+ // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
70+ // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
71+ // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
72+ // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
73+ // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
74+ // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
75+ // CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
76+ // CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
77+ // CHECK-NEXT: ret { float, [4 x float] } %21
78+ // CHECK-NEXT: }
4279
4380fn main ( ) {
4481 let x = std:: hint:: black_box ( 3.0 ) ;
4582 let output = square ( & x) ;
4683 dbg ! ( & output) ;
4784 assert_eq ! ( 9.0 , output) ;
48- dbg ! ( squaref ( & x) ) ;
85+ dbg ! ( square ( & x) ) ;
4986
5087 let mut df_dx1 = 1.0 ;
5188 let mut df_dx2 = 2.0 ;
@@ -54,7 +91,7 @@ fn main() {
5491 let [ o1, o2, o3, o4] = d_square2 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
5592 dbg ! ( o1, o2, o3, o4) ;
5693 let [ output2, o1, o2, o3, o4] =
57- d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
94+ d_square1 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
5895 dbg ! ( o1, o2, o3, o4) ;
5996 assert_eq ! ( output, output2) ;
6097 assert ! ( ( 6.0 - o1) . abs( ) < 1e-10 ) ;
0 commit comments