Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tfjs-layers/src/engine/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ export abstract class Container extends Layer {
* subclasses: ['LayersModel']
* }
*/
getLayer(name?: string, index?: number): Layer {
getLayer(nameOrIndex?: string|number, index?: number): Layer {
if (index != null) {
if (this.layers.length <= index) {
throw new ValueError(
Expand All @@ -982,17 +982,17 @@ export abstract class Container extends Layer {
return this.layers[index];
}
} else {
if (name == null) {
if (nameOrIndex == null) {
throw new ValueError('Provide either a layer name or layer index');
}
}

for (const layer of this.layers) {
if (layer.name === name) {
if (layer.name === nameOrIndex) {
return layer;
}
}
throw new ValueError(`No such layer: ${name}`);
throw new ValueError(`No such layer: ${nameOrIndex}`);
}

/**
Expand Down
8 changes: 4 additions & 4 deletions tfjs-layers/src/engine/container_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,9 @@ describeMathCPUAndGPU('Container', () => {

it('getLayer works by index', () => {
const [container, layers] = createSimpleTwoLayerContainer();
expect(container.getLayer(null, 0)).toEqual(layers[0]);
expect(container.getLayer(null, 1)).toEqual(layers[1]);
expect(container.getLayer(null, 2)).toEqual(layers[2]);
expect(container.getLayer(0)).toEqual(layers[0]);
expect(container.getLayer(1)).toEqual(layers[1]);
expect(container.getLayer(2)).toEqual(layers[2]);
});

it('getLayer throws error for nonexistent layer name', () => {
Expand All @@ -400,7 +400,7 @@ describeMathCPUAndGPU('Container', () => {

it('getLayer throws error for index out of bound', () => {
const container = createSimpleTwoLayerContainer()[0];
expect(() => container.getLayer(null, 3)).toThrowError(/only has 3 layer/);
expect(() => container.getLayer(3)).toThrowError(/only has 3 layer/);
});

it('getLayer throws error when neither name or index is specified', () => {
Expand Down
30 changes: 14 additions & 16 deletions tfjs-layers/src/models_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,13 @@ describeMathCPU('Nested model topology', () => {
{units: 1, kernelInitializer: 'ones', biasInitializer: 'ones'})
]
});
expect(outerModel.getLayer(null, 0) instanceof tfl.Sequential)
.toEqual(true);
expect(outerModel.getLayer(0) instanceof tfl.Sequential).toEqual(true);
// Expect all-one values based on the kernel and bias initializers specified
// above.
expectTensorsClose(
outerModel.getLayer(null, 1).getWeights()[0], ones([2, 1]));
expectTensorsClose(outerModel.getLayer(null, 1).getWeights()[1], ones([1]));
expectTensorsClose(outerModel.getLayer(1).getWeights()[0], ones([2, 1]));
expectTensorsClose(outerModel.getLayer(1).getWeights()[1], ones([1]));
// Expect there to be only two layers.
expect(() => outerModel.getLayer(null, 2)).toThrow();
expect(() => outerModel.getLayer(2)).toThrow();
});

it('getWeights() works for nested sequential model', () => {
Expand Down Expand Up @@ -1080,8 +1078,8 @@ describeMathCPU('loadLayersModel from URL', () => {
});
});

const model = await loadLayersModel(
io.browserHTTPRequest('model/model.json', {
const model =
await loadLayersModel(io.browserHTTPRequest('model/model.json', {
requestInit: {
headers: {'header_key_1': 'header_value_1'},
credentials: 'include',
Expand Down Expand Up @@ -1294,7 +1292,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 * 8 + 4 * 1 + 4);
.toEqual(4 * 8 + 4 * 1 + 4);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1353,8 +1351,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// The second part comes from the bias of the dense layer, which has 1
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData)
.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1414,7 +1412,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1472,7 +1470,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);
.toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1535,7 +1533,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1594,7 +1592,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
.toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down Expand Up @@ -1652,7 +1650,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => {
// element and is also 4 bytes.
const weightData = savedArtifacts.weightData;
expect(new io.CompositeArrayBuffer(weightData).byteLength)
.toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);
.toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);

// Load the model back, with the optimizer.
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
Expand Down