Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tfjs-backend-wasm/src/cc/kernels/ResizeBilinear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ EMSCRIPTEN_KEEPALIVE

void ResizeBilinear(size_t x_id, size_t batch, size_t old_height,
size_t old_width, size_t num_channels, size_t new_height,
size_t new_width, bool align_corners, size_t out_id) {
size_t new_width, bool align_corners,
bool half_pixel_centers, size_t out_id) {
auto& x_info = backend::get_tensor_info(x_id);
auto& out_info = backend::get_tensor_info_out(out_id);

Expand All @@ -59,8 +60,12 @@ void ResizeBilinear(size_t x_id, size_t batch, size_t old_height,

xnn_operator_t resize_bilinear_op = nullptr;

const uint32_t flags = XNN_FLAG_TENSORFLOW_LEGACY_MODE |
(align_corners ? XNN_FLAG_ALIGN_CORNERS : 0);
uint32_t flags = 0;
if (align_corners) {
flags |= XNN_FLAG_ALIGN_CORNERS;
} else if (!half_pixel_centers) {
flags |= XNN_FLAG_TENSORFLOW_LEGACY_MODE;
}

OperatorCacheKey cache_key = {num_channels, flags};

Expand Down
3 changes: 2 additions & 1 deletion tfjs-backend-wasm/src/cc/kernels/ResizeBilinear.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ extern "C" {

void ResizeBilinear(size_t x_id, size_t batch, size_t old_height,
size_t old_width, size_t num_channels, size_t new_height,
size_t new_width, bool align_corners, size_t out_id);
size_t new_width, bool align_corners,
bool half_pixel_centers, size_t out_id);
}

} // namespace wasm
Expand Down
11 changes: 6 additions & 5 deletions tfjs-backend-wasm/src/cc/kernels/ResizeBilinear_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,24 @@ TEST(BATCH_MATMUL, xnn_operator_lifetime) {
const size_t new_height0 = 2;
const size_t new_width0 = 2;
bool align_corners0 = false;
bool half_pixel_centers0 = false;
tfjs::wasm::ResizeBilinear(x0_id, batch_size, old_height0, old_width0,
num_channels0, new_height0, new_width0,
align_corners0, out_id);
align_corners0, half_pixel_centers0, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// No new xnn_operators should be created for the second call to
// ResizeBilinear with the same arguments.
tfjs::wasm::ResizeBilinear(x0_id, batch_size, old_height0, old_width0,
num_channels0, new_height0, new_width0,
align_corners0, out_id);
align_corners0, half_pixel_centers0, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// No new xnn_operators should be created for the second call to
// ResizeBilinear with a new x id but same arguments.
tfjs::wasm::ResizeBilinear(x1_id, batch_size, old_height0, old_width0,
num_channels0, new_height0, new_width0,
align_corners0, out_id);
align_corners0, half_pixel_centers0, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// One new xnn_operator should be created for another call to ResizeBilinear
Expand All @@ -78,15 +79,15 @@ TEST(BATCH_MATMUL, xnn_operator_lifetime) {
const size_t new_width1 = 4;
tfjs::wasm::ResizeBilinear(x0_id, batch_size, old_height1, old_width1,
num_channels1, new_height1, new_width1,
align_corners0, out_id);
align_corners0, half_pixel_centers0, out_id);
ASSERT_EQ(2, tfjs::backend::xnn_operator_count);

// One new xnn_operator should be created for another call to ResizeBilinear
// with a different align_corners argument
bool align_corners1 = true;
tfjs::wasm::ResizeBilinear(x0_id, batch_size, old_height1, old_width1,
num_channels1, new_height1, new_width1,
align_corners1, out_id);
align_corners1, half_pixel_centers0, out_id);
ASSERT_EQ(3, tfjs::backend::xnn_operator_count);

tfjs::wasm::dispose();
Expand Down
11 changes: 3 additions & 8 deletions tfjs-backend-wasm/src/kernels/ResizeBilinear.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {cast} from './Cast';
let wasmResizeBilinear: (
xId: number, batch: number, oldHeight: number, oldWidth: number,
numChannels: number, newHeight: number, newWidth: number,
alignCorners: number, outId: number) => void;
alignCorners: number, halfPixelCenters: number, outId: number) => void;

function setup(backend: BackendWasm): void {
wasmResizeBilinear = backend.wasm.cwrap(ResizeBilinear, null /*void*/, [
Expand All @@ -36,6 +36,7 @@ function setup(backend: BackendWasm): void {
'number', // newHeight
'number', // newWidth
'number', // alignCorners
'number', // halfPixelCenters
'number' // outId
]);
}
Expand All @@ -54,12 +55,6 @@ function resizeBilinear(args: {
const [batch, oldHeight, oldWidth, numChannels] = images.shape;
const outShape = [batch, newHeight, newWidth, numChannels];

if (halfPixelCenters) {
throw new Error(
`The wasm resizeBilinear kernel does not yet support ` +
`halfPixelCenters being true.`);
}

let xData = backend.dataIdMap.get(images.dataId);
let castedData;
if (xData.dtype !== 'float32') {
Expand All @@ -77,7 +72,7 @@ function resizeBilinear(args: {

wasmResizeBilinear(
xId, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth,
alignCorners ? 1 : 0, outId);
alignCorners ? 1 : 0, halfPixelCenters ? 1 : 0, outId);

if (castedData != null) {
backend.disposeData(castedData.dataId);
Expand Down
3 changes: 1 addition & 2 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ const TEST_FILTERS: TestFilter[] = [
{
include: 'resizeBilinear',
excludes: [
'gradients', // Not yet implemented.
'halfPixelCenters' // Not yet implemented.
'gradients', // Not yet implemented.
]
},
{
Expand Down