Skip to content

Commit 69f5803

Browse files
Add Prod kernel to WASM backend. (#4138)
FEATURE Co-authored-by: Ann Yuan <[email protected]>
1 parent 5154050 commit 69f5803

File tree

4 files changed

+187
-1
lines changed

4 files changed

+187
-1
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/* Copyright 2020 Google LLC. All Rights Reserved.
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ===========================================================================*/
14+
15+
#ifdef __EMSCRIPTEN__
16+
#include <emscripten.h>
17+
#endif
18+
19+
#include <cstddef>
20+
21+
#include "src/cc/backend.h"
22+
#include "src/cc/util.h"
23+
24+
namespace {
25+
26+
template <typename T>
27+
void prod(const size_t x_id, const size_t reduce_size, const size_t out_id) {
28+
auto& x_info = tfjs::backend::get_tensor_info(x_id);
29+
auto& out_info = tfjs::backend::get_tensor_info_out(out_id);
30+
31+
const T* x_buf = reinterpret_cast<const T *>(x_info.memory_offset);
32+
const size_t x_size = x_info.size;
33+
34+
T* out_buf = reinterpret_cast<T *>(out_info.memory_offset);
35+
const size_t out_size = out_info.size;
36+
37+
const T* x_offset = x_buf;
38+
39+
for (size_t i = 0; i < out_size; ++i) {
40+
const size_t offset = i * reduce_size;
41+
T product = 1;
42+
43+
const T* x_iter_end = x_offset + reduce_size;
44+
45+
for (const T* x = x_offset; x < x_iter_end; ++x) {
46+
product *= (*x);
47+
}
48+
49+
x_offset += reduce_size;
50+
out_buf[i] = product;
51+
}
52+
}
53+
54+
} // namespace
55+
56+
namespace tfjs {
57+
namespace wasm {
58+
59+
// We use C-style API to interface with Javascript.
60+
extern "C" {
61+
62+
#ifdef __EMSCRIPTEN__
63+
EMSCRIPTEN_KEEPALIVE
64+
#endif
65+
void Prod(const size_t x_id, const size_t reduce_size,
66+
const DType dtype, const size_t out_id) {
67+
switch (dtype) {
68+
case DType::float32:
69+
prod<float>(x_id, reduce_size, out_id);
70+
break;
71+
case DType::int32:
72+
prod<int32_t>(x_id, reduce_size, out_id);
73+
break;
74+
case DType::boolean:
75+
prod<bool>(x_id, reduce_size, out_id);
76+
break;
77+
default:
78+
util::warn("Prod failed. Unknown dtype %d", dtype);
79+
}
80+
}
81+
82+
83+
} // extern "C"
84+
} // namespace wasm
85+
} // namespace tfjs
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, KernelConfig, KernelFunc, Prod, ProdAttrs, ProdInputs, TensorInfo, util} from '@tensorflow/tfjs-core';
19+
20+
import {BackendWasm} from '../backend_wasm';
21+
22+
import {permuteAxesAndTranspose} from './kernel_utils';
23+
24+
import {CppDType} from './types';
25+
26+
let wasmProd: (
27+
xId: number, reduceSize: number,
28+
dtype: number, outId: number) => void;
29+
30+
function setup(backend: BackendWasm): void {
31+
wasmProd = backend.wasm.cwrap(Prod, null /*void*/, [
32+
'number',
33+
'number',
34+
'number',
35+
'number'
36+
]);
37+
}
38+
39+
function prod(args: {
40+
backend: BackendWasm,
41+
inputs: ProdInputs,
42+
attrs: ProdAttrs
43+
}): TensorInfo {
44+
const {backend, inputs, attrs} = args;
45+
const {axis, keepDims} = attrs;
46+
const {x} = inputs;
47+
const xId = backend.dataIdMap.get(x.dataId).id;
48+
let inputId = xId;
49+
let input = x;
50+
51+
const {transposed, axes, originalAxes, inputWasTransposed} =
52+
permuteAxesAndTranspose(x, axis, backend);
53+
54+
let reductionAxes = axes;
55+
if (inputWasTransposed) {
56+
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
57+
if (transposedId !== xId) {
58+
// transpose was not a no-op. We will need to dispose of this
59+
// once we are done.
60+
input = transposed;
61+
inputId = transposedId;
62+
reductionAxes = backend_util.getInnerMostAxes(
63+
reductionAxes.length, input.shape.length);
64+
}
65+
}
66+
67+
backend_util.assertAxesAreInnerMostDims(
68+
'prod', reductionAxes, input.shape.length);
69+
const [outShape, reduceShape] =
70+
backend_util.computeOutAndReduceShapes(input.shape, reductionAxes);
71+
const reduceSize = util.sizeFromShape(reduceShape);
72+
73+
const out = backend.makeOutput(outShape, input.dtype);
74+
if (util.sizeFromShape(input.shape) !== 0) {
75+
const outId = backend.dataIdMap.get(out.dataId).id;
76+
wasmProd(inputId, reduceSize, CppDType[out.dtype], outId);
77+
}
78+
79+
if (inputWasTransposed) {
80+
// dispose of the transposed tensor.
81+
backend.disposeData(transposed.dataId);
82+
}
83+
84+
if (keepDims) {
85+
// reshape
86+
const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
87+
out.shape = newShape;
88+
}
89+
90+
return out;
91+
}
92+
93+
export const prodConfig: KernelConfig = {
94+
kernelName: Prod,
95+
backendName: 'wasm',
96+
setupFunc: setup,
97+
kernelFunc: prod as {} as KernelFunc
98+
};

tfjs-backend-wasm/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import {onesLikeConfig} from './kernels/OnesLike';
7070
import {padV2Config} from './kernels/PadV2';
7171
import {powConfig} from './kernels/Pow';
7272
import {preluConfig} from './kernels/Prelu';
73+
import {prodConfig} from './kernels/Prod';
7374
import {reluConfig} from './kernels/Relu';
7475
import {relu6Config} from './kernels/Relu6';
7576
import {reshapeConfig} from './kernels/Reshape';
@@ -149,6 +150,7 @@ const kernelConfigs: KernelConfig[] = [
149150
padV2Config,
150151
powConfig,
151152
preluConfig,
153+
prodConfig,
152154
reluConfig,
153155
relu6Config,
154156
reshapeConfig,

tfjs-backend-wasm/src/setup_test.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ const TEST_FILTERS: TestFilter[] = [
376376
startsWith: 'onesLike',
377377
// Complex numbers not supported yet.
378378
excludes: ['complex'],
379-
}
379+
},
380+
{include: 'prod'},
380381
];
381382

382383
const customInclude = (testName: string) => {

0 commit comments

Comments
 (0)