|
| 1 | +/** |
| 2 | + * @license |
| 3 | + * Copyright 2020 Google LLC. All Rights Reserved. |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + * ============================================================================= |
| 16 | + */ |
| 17 | + |
| 18 | +import {backend_util, KernelConfig, KernelFunc, Prod, ProdAttrs, ProdInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; |
| 19 | + |
| 20 | +import {BackendWasm} from '../backend_wasm'; |
| 21 | + |
| 22 | +import {permuteAxesAndTranspose} from './kernel_utils'; |
| 23 | + |
| 24 | +import {CppDType} from './types'; |
| 25 | + |
| 26 | +let wasmProd: ( |
| 27 | + xId: number, reduceSize: number, |
| 28 | + dtype: number, outId: number) => void; |
| 29 | + |
| 30 | +function setup(backend: BackendWasm): void { |
| 31 | + wasmProd = backend.wasm.cwrap(Prod, null /*void*/, [ |
| 32 | + 'number', |
| 33 | + 'number', |
| 34 | + 'number', |
| 35 | + 'number' |
| 36 | + ]); |
| 37 | +} |
| 38 | + |
| 39 | +function prod(args: { |
| 40 | + backend: BackendWasm, |
| 41 | + inputs: ProdInputs, |
| 42 | + attrs: ProdAttrs |
| 43 | +}): TensorInfo { |
| 44 | + const {backend, inputs, attrs} = args; |
| 45 | + const {axis, keepDims} = attrs; |
| 46 | + const {x} = inputs; |
| 47 | + const xId = backend.dataIdMap.get(x.dataId).id; |
| 48 | + let inputId = xId; |
| 49 | + let input = x; |
| 50 | + |
| 51 | + const {transposed, axes, originalAxes, inputWasTransposed} = |
| 52 | + permuteAxesAndTranspose(x, axis, backend); |
| 53 | + |
| 54 | + let reductionAxes = axes; |
| 55 | + if (inputWasTransposed) { |
| 56 | + const transposedId = backend.dataIdMap.get(transposed.dataId).id; |
| 57 | + if (transposedId !== xId) { |
| 58 | + // transpose was not a no-op. We will need to dispose of this |
| 59 | + // once we are done. |
| 60 | + input = transposed; |
| 61 | + inputId = transposedId; |
| 62 | + reductionAxes = backend_util.getInnerMostAxes( |
| 63 | + reductionAxes.length, input.shape.length); |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + backend_util.assertAxesAreInnerMostDims( |
| 68 | + 'prod', reductionAxes, input.shape.length); |
| 69 | + const [outShape, reduceShape] = |
| 70 | + backend_util.computeOutAndReduceShapes(input.shape, reductionAxes); |
| 71 | + const reduceSize = util.sizeFromShape(reduceShape); |
| 72 | + |
| 73 | + const out = backend.makeOutput(outShape, input.dtype); |
| 74 | + if (util.sizeFromShape(input.shape) !== 0) { |
| 75 | + const outId = backend.dataIdMap.get(out.dataId).id; |
| 76 | + wasmProd(inputId, reduceSize, CppDType[out.dtype], outId); |
| 77 | + } |
| 78 | + |
| 79 | + if (inputWasTransposed) { |
| 80 | + // dispose of the transposed tensor. |
| 81 | + backend.disposeData(transposed.dataId); |
| 82 | + } |
| 83 | + |
| 84 | + if (keepDims) { |
| 85 | + // reshape |
| 86 | + const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes); |
| 87 | + out.shape = newShape; |
| 88 | + } |
| 89 | + |
| 90 | + return out; |
| 91 | +} |
| 92 | + |
| 93 | +export const prodConfig: KernelConfig = { |
| 94 | + kernelName: Prod, |
| 95 | + backendName: 'wasm', |
| 96 | + setupFunc: setup, |
| 97 | + kernelFunc: prod as {} as KernelFunc |
| 98 | +}; |
0 commit comments