Skip to content

Commit 9f067bd

Browse files
authored
add kernelNames getter to profile object (#4144)
FEATURE
1 parent cfa756f commit 9f067bd

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

tfjs-core/src/engine.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export type ProfileInfo = {
6666
newBytes: number; newTensors: number; peakBytes: number;
6767
kernels: KernelInfo[];
6868
result: TensorContainer;
69+
kernelNames: string[];
6970
};
7071

7172
export interface TimingInfo extends BackendTimingInfo {
@@ -119,8 +120,16 @@ class EngineState {
119120
}>();
120121

121122
profiling = false;
122-
activeProfile: ProfileInfo =
123-
{newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
123+
activeProfile: ProfileInfo = {
124+
newBytes: 0,
125+
newTensors: 0,
126+
peakBytes: 0,
127+
kernels: [],
128+
result: null,
129+
get kernelNames() {
130+
return Array.from(new Set(this.kernels.map(k => k.name)));
131+
}
132+
};
124133

125134
dispose() {
126135
for (const variableName in this.registeredVariables) {

tfjs-core/src/engine_test.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,19 @@ describeWithFlags('profile', ALL_ENVS, () => {
482482
'extraInfo': profile.kernels[0].extraInfo
483483
});
484484
});
485+
486+
it('reports correct kernelNames', async () => {
487+
const profile = await tf.profile(() => {
488+
const x = tf.tensor1d([1, 2, 3]);
489+
const x2 = x.square();
490+
const x3 = x2.abs();
491+
return x3;
492+
});
493+
494+
expect(profile.kernelNames).toEqual(jasmine.arrayWithExactContents([
495+
'Square', 'Abs'
496+
]));
497+
});
485498
});
486499

487500
describeWithFlags('disposeVariables', ALL_ENVS, () => {

tfjs-core/src/globals.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ export function memory(): MemoryInfo {
120120
* - `kernels`: an array of objects for each kernel involved that reports
121121
* their input and output shapes, number of bytes used, and number of new
122122
* tensors created.
123+
* - `kernelNames`: an array of unique strings with just the names of the
124+
* kernels in the `kernels` array.
123125
*
124126
* ```js
125127
* const profile = await tf.profile(() => {

0 commit comments

Comments
 (0)