Skip to content

Commit aa1c591

Browse files
[webgl] Modularize several unary kernels (#4260)
* sigmoid * softplus * asin * acos * atan * sinh * cosh * tanh * asinh * acosh * atanh * erf
1 parent 55a6e12 commit aa1c591

File tree

15 files changed

+447
-158
lines changed

15 files changed

+447
-158
lines changed

tfjs-backend-webgl/src/backend_webgl.ts

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,66 +1530,6 @@ export class MathBackendWebGL extends KernelBackend {
15301530
return this.compileAndRun<Tensor>(program, inputs) as T;
15311531
}
15321532

1533-
sigmoid<T extends Tensor>(x: T): T {
1534-
const program = new UnaryOpProgram(x.shape, unary_op.SIGMOID);
1535-
return this.compileAndRun(program, [x]);
1536-
}
1537-
1538-
softplus<T extends Tensor>(x: T): T {
1539-
const program = new UnaryOpProgram(x.shape, unary_op.SOFTPLUS);
1540-
return this.compileAndRun(program, [x]);
1541-
}
1542-
1543-
asin<T extends Tensor>(x: T): T {
1544-
const program = new UnaryOpProgram(x.shape, unary_op.ASIN);
1545-
return this.compileAndRun(program, [x]);
1546-
}
1547-
1548-
acos<T extends Tensor>(x: T): T {
1549-
const program = new UnaryOpProgram(x.shape, unary_op.ACOS);
1550-
return this.compileAndRun(program, [x]);
1551-
}
1552-
1553-
atan<T extends Tensor>(x: T): T {
1554-
const program = new UnaryOpProgram(x.shape, unary_op.ATAN);
1555-
return this.compileAndRun(program, [x]);
1556-
}
1557-
1558-
sinh<T extends Tensor>(x: T): T {
1559-
const program = new UnaryOpProgram(x.shape, unary_op.SINH);
1560-
return this.compileAndRun(program, [x]);
1561-
}
1562-
1563-
cosh<T extends Tensor>(x: T): T {
1564-
const program = new UnaryOpProgram(x.shape, unary_op.COSH);
1565-
return this.compileAndRun(program, [x]);
1566-
}
1567-
1568-
tanh<T extends Tensor>(x: T): T {
1569-
const program = new UnaryOpProgram(x.shape, unary_op.TANH);
1570-
return this.compileAndRun(program, [x]);
1571-
}
1572-
1573-
asinh<T extends Tensor>(x: T): T {
1574-
const program = new UnaryOpProgram(x.shape, unary_op.ASINH);
1575-
return this.compileAndRun(program, [x]);
1576-
}
1577-
1578-
acosh<T extends Tensor>(x: T): T {
1579-
const program = new UnaryOpProgram(x.shape, unary_op.ACOSH);
1580-
return this.compileAndRun(program, [x]);
1581-
}
1582-
1583-
atanh<T extends Tensor>(x: T): T {
1584-
const program = new UnaryOpProgram(x.shape, unary_op.ATANH);
1585-
return this.compileAndRun(program, [x]);
1586-
}
1587-
1588-
erf<T extends Tensor>(x: T): T {
1589-
const program = new UnaryOpProgram(x.shape, unary_op.ERF);
1590-
return this.compileAndRun(program, [x]);
1591-
}
1592-
15931533
step<T extends Tensor>(x: T, alpha: number): T {
15941534
const program = new UnaryOpProgram(x.shape, unary_op.STEP(alpha));
15951535
return this.compileAndRun(program, [x]);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Acos, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
22+
23+
const ACOS = CHECK_NAN_SNIPPET + `
24+
if (abs(x) > 1.) {
25+
return NAN;
26+
}
27+
return acos(x);
28+
`;
29+
30+
export const acos = unaryKernelFunc({opSnippet: ACOS});
31+
32+
export const acosConfig: KernelConfig = {
33+
kernelName: Acos,
34+
backendName: 'webgl',
35+
kernelFunc: acos,
36+
};
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
/**
3+
* @license
4+
* Copyright 2020 Google LLC. All Rights Reserved.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
* =============================================================================
17+
*/
18+
19+
import {Acosh, KernelConfig} from '@tensorflow/tfjs-core';
20+
21+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
22+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
23+
24+
const ACOSH = CHECK_NAN_SNIPPET + `
25+
if (x < 1.0) return NAN;
26+
return log(x + sqrt(x * x - 1.0));`;
27+
28+
export const acosh = unaryKernelFunc({opSnippet: ACOSH});
29+
30+
export const acoshConfig: KernelConfig = {
31+
kernelName: Acosh,
32+
backendName: 'webgl',
33+
kernelFunc: acosh,
34+
};
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Asin, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
22+
23+
const ASIN = CHECK_NAN_SNIPPET + `
24+
if (abs(x) > 1.) {
25+
return NAN;
26+
}
27+
return asin(x);
28+
`;
29+
30+
export const asin = unaryKernelFunc({opSnippet: ASIN});
31+
32+
export const asinConfig: KernelConfig = {
33+
kernelName: Asin,
34+
backendName: 'webgl',
35+
kernelFunc: asin,
36+
};
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Asinh, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
22+
23+
const ASINH = CHECK_NAN_SNIPPET + `return log(x + sqrt(x * x + 1.0));`;
24+
25+
export const asinh = unaryKernelFunc({opSnippet: ASINH});
26+
27+
export const asinhConfig: KernelConfig = {
28+
kernelName: Asinh,
29+
backendName: 'webgl',
30+
kernelFunc: asinh,
31+
};
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Atan, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
22+
23+
const ATAN = CHECK_NAN_SNIPPET + `
24+
return atan(x);
25+
`;
26+
27+
export const atan = unaryKernelFunc({opSnippet: ATAN});
28+
29+
export const atanConfig: KernelConfig = {
30+
kernelName: Atan,
31+
backendName: 'webgl',
32+
kernelFunc: atan,
33+
};
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Atanh, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
import {CHECK_NAN_SNIPPET} from '../unaryop_gpu';
22+
23+
const ATANH = CHECK_NAN_SNIPPET + `
24+
if ((x < -1.0) || (x > 1.0)) return NAN;
25+
return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
26+
27+
export const atanh = unaryKernelFunc({opSnippet: ATANH});
28+
29+
export const atanhConfig: KernelConfig = {
30+
kernelName: Atanh,
31+
backendName: 'webgl',
32+
kernelFunc: atanh,
33+
};
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Cosh, KernelConfig} from '@tensorflow/tfjs-core';
19+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
20+
21+
const COSH = `
22+
float e2x = exp(-x);
23+
return (e2x + 1.0 / e2x) / 2.0;
24+
`;
25+
26+
export const cosh = unaryKernelFunc({opSnippet: COSH});
27+
28+
export const coshConfig: KernelConfig = {
29+
kernelName: Cosh,
30+
backendName: 'webgl',
31+
kernelFunc: cosh,
32+
};
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, Erf, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
22+
const ERF = `
23+
// Error function is calculated approximately with elementary function.
24+
// See "Handbook of Mathematical Functions with Formulas,
25+
// Graphs, and Mathematical Tables", Abramowitz and Stegun.
26+
float p = ${backend_util.ERF_P};
27+
float a1 = ${backend_util.ERF_A1};
28+
float a2 = ${backend_util.ERF_A2};
29+
float a3 = ${backend_util.ERF_A3};
30+
float a4 = ${backend_util.ERF_A4};
31+
float a5 = ${backend_util.ERF_A5};
32+
33+
float sign = sign(x);
34+
x = abs(x);
35+
float t = 1.0 / (1.0 + p * x);
36+
return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
37+
`;
38+
39+
export const erf = unaryKernelFunc({opSnippet: ERF});
40+
41+
export const erfConfig: KernelConfig = {
42+
kernelName: Erf,
43+
backendName: 'webgl',
44+
kernelFunc: erf,
45+
};
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, Sigmoid} from '@tensorflow/tfjs-core';
19+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
20+
21+
const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;
22+
23+
export const sigmoid = unaryKernelFunc({opSnippet: SIGMOID});
24+
25+
export const sigmoidConfig: KernelConfig = {
26+
kernelName: Sigmoid,
27+
backendName: 'webgl',
28+
kernelFunc: sigmoid,
29+
};

0 commit comments

Comments
 (0)