Merged

49 commits
3a655f4
Add Tokenizer base class.
pforderique Jun 14, 2023
85e34c3
Update licence to 2023
pforderique Jun 14, 2023
8121219
Fix lint errors.
pforderique Jun 14, 2023
7b9c1f2
Only expose WhiteSpaceTokenizer in tests
pforderique Jun 15, 2023
cd2fc6e
Rename WhitespaceTokenizer to SimpleTokenizer
pforderique Jun 15, 2023
22e7b81
Add example to Tokenizers docstring.
pforderique Jun 15, 2023
992fb99
BytePairEncoding implementation started.
pforderique Jun 15, 2023
2f3e944
Update test name to Tokenizer
pforderique Jun 15, 2023
cf8c974
Destructure TokenizerOptions to assign default mode value
pforderique Jun 15, 2023
9eb393a
Use destructured mode in call method
pforderique Jun 15, 2023
06b4f26
Wrap tokenizer in beforeEach clause.
pforderique Jun 15, 2023
b4912ab
Make TokenizerOptions optional in call()
pforderique Jun 15, 2023
d805bb8
Add utils file for tokenizers.
pforderique Jun 15, 2023
9a3ef53
Update example in Tokenizers to print output.
pforderique Jun 15, 2023
2bd1943
Don't register test SimpleTokenizer
pforderique Jun 15, 2023
af76234
Bring in changes from Tokenizer
pforderique Jun 15, 2023
a00f45c
Add bytesToUnicode and HashTable for tensors.
pforderique Jun 16, 2023
b792a41
Bring in changes from main
pforderique Jun 16, 2023
e5b9a59
Merge in tokenizers_utils file
pforderique Jun 16, 2023
0a83a76
Add utils for BPE
pforderique Jun 16, 2023
9b909c2
Change test name to createStaticHashtable
pforderique Jun 16, 2023
836a5be
Add BytePairTokenizer Cache
pforderique Jun 16, 2023
f1a72d9
Add tests for BytePairTokenizerCache
pforderique Jun 16, 2023
ebed82c
Move BytePairTokenizerCache to correct location
pforderique Jun 16, 2023
509de9b
Add removeStringsFromInputs
pforderique Jun 16, 2023
3aebe00
Fix test case for removes strings successfully
pforderique Jun 16, 2023
48fd42b
Add createAltsForUnsplittableTokens.
pforderique Jun 16, 2023
a52b406
Fix regex pattern in createAlts
pforderique Jun 16, 2023
72b1836
Fix test case for createAlt
pforderique Jun 16, 2023
8819bdd
Switch to using await data() rather than dataSync().
pforderique Jun 16, 2023
b4b316d
Fix lint errors.
pforderique Jun 16, 2023
c28bb38
Merge branch 'orderique' into all-bpe-utils
pforderique Jun 16, 2023
e11a6a8
Remove dataSyncs().
pforderique Jun 16, 2023
a17937d
Switch to using tensor instead of tensor1d
pforderique Jun 16, 2023
9e77b5d
Add whitespace Regex strings
pforderique Jun 20, 2023
4817427
Add regexSplit
pforderique Jun 20, 2023
f902617
Fix removeStringsFromInputs
pforderique Jun 20, 2023
3a81199
Implement and fix regexSplit
pforderique Jun 21, 2023
57637b0
splitstringsforbpe progress
pforderique Jun 21, 2023
30a1ffd
splitStringForBpe passing 1/2 tests
pforderique Jun 21, 2023
6b6a348
Add mergeLastTwoDims
pforderique Jun 21, 2023
b60b52f
Implement splitStringsForBpe and add tests.
pforderique Jun 21, 2023
18018d8
Merge branch 'main' into all-bpe-utils
pforderique Jun 21, 2023
39de746
Replace lookahead regex due to Safari not supporting it
pforderique Jun 21, 2023
3bde5d3
address comments
pforderique Jun 23, 2023
ed3c41a
Merge branch 'main' into all-bpe-utils
pforderique Jun 26, 2023
911b41e
Use polyfill matchAll instead
pforderique Jun 26, 2023
d6b639d
Merge branch 'main' into all-bpe-utils
pforderique Jun 27, 2023
1b28b0e
Utilize tensor transforms
pforderique Jun 27, 2023
161 changes: 161 additions & 0 deletions tfjs-layers/src/layers/nlp/tokenizers_utils.ts
@@ -19,6 +19,7 @@

import { Tensor, tensor } from '@tensorflow/tfjs-core';
import { ValueError } from '../../errors';
import { matchAll } from './match_all_polyfill';

export function bytesToUnicode(): [Uint8Array, string[]] {
const inclusiveRange = (start: number, end: number) =>
@@ -141,3 +142,163 @@ export class BytePairTokenizerCache {
return arrKeys.map(key => this._cache.get(key));
}
}

/**
* Remove certain strings from the input tensors.
*/
export async function removeStringsFromInputs(
inputs: Tensor[], stringToRemove: string): Promise<Tensor[]> {

const stringArrInputs = await tensorArrToString2DArr(inputs);

const filteredStrArrays = stringArrInputs
.map(arr => arr.filter(s => s !== stringToRemove))
.filter(arr => arr.length > 0);

const filteredTensors = filteredStrArrays.map(arr => tensor(arr));

return filteredTensors;
}
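// Illustrative usage (a sketch with hypothetical inputs):
//   const out = await removeStringsFromInputs(
//     [tensor(['a', '६']), tensor(['६'])], '६');
//   // -> [tensor(['a'])]; tensors left empty after filtering are dropped.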

/**
* Create alternates for all special tokens that will not be split during
* tokenization.
*/
export function createAltsForUnsplittableTokens(
unsplittableTokens: string[]): string[] {

const prefix = 'ĵ';

// Trim out splitters.
const replacePattern: RegExp = /'|\s+|[^\p{L}\p{N}]+/gu;
return unsplittableTokens.map(
token => prefix + token.replace(replacePattern, ''));
}
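// Illustrative usage (hypothetical token, for the sketch only):
//   createAltsForUnsplittableTokens(['<|endoftext|>'])
//   // -> ['ĵendoftext']: the splitter characters are stripped and the
//   // 'ĵ' prefix marks the alternate.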

// TypeScript and TF handle special spaces differently, so we need to
// handle special spaces manually during string splitting.
const SPECIAL_WHITESPACES = /\u00A0\u2009\u202f\u3000/;

// String splitting regex pattern.
const pL = 'a-zA-ZáàâäãåçéèêëíìîïñóòôöõúùûüýÿæœÁÀÂÄÃÅÇÉÈÊËÍÌÎÏÑÓÒÔÖÕÚÙÛÜÝŸÆŒĵ';
const pN = '0-9';
export const SPLIT_PATTERN_1 = new RegExp(
`'s|'t|'re|'ve|'m|'ll|'d` +
`|[\\s${SPECIAL_WHITESPACES.source}]+` +
`[\\n\\r\\t\\f६${SPECIAL_WHITESPACES.source}]| ?${pL}+|`+
` ?${pN}+| ?[^\\s${pL}${pN}${SPECIAL_WHITESPACES.source}]+`,
'gu'
);
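// For example (a sketch of the intended behavior): SPLIT_PATTERN_1 splits
// "they're 2 cats" into ["they", "'re", " 2", " cats"]; contractions,
// numbers, and words (with an optional leading space) each match separately.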

const SPLIT_PATTERN_2 = new RegExp(`[\\s६${SPECIAL_WHITESPACES.source}]\$`);

function flatten<T>(inputs: T[][]): T[] {
return inputs.reduce(
(accumulator, value) => accumulator.concat(value), []);
}

export function regexSplit(
strs: string[]|string[][],
delimRegexPattern: RegExp | string,
keepDelimRegexPattern = false): string[][] {

if (strs[0] instanceof Array) {
const mapped = strs.map(arr => regexSplit(
arr as string[], delimRegexPattern, keepDelimRegexPattern));
return mapped.map(flatten);
}

strs = strs as string[];

if (!(delimRegexPattern instanceof RegExp)) {
if (keepDelimRegexPattern) {
delimRegexPattern = new RegExp(`(${delimRegexPattern})`);
}
return strs.map(str => str.split(delimRegexPattern).filter(s => s));
}

const regexPattern = delimRegexPattern.flags.includes('g') ?
delimRegexPattern
: new RegExp(delimRegexPattern.source, delimRegexPattern.flags + 'g');

return strs.map(str => {
const matches = matchAll(str, regexPattern);

const splitString = [];
let currIdx = 0;
for (const match of matches) {
splitString.push(str.slice(currIdx, match.index));
if (keepDelimRegexPattern) {
splitString.push(
str.slice(match.index, match.index! + match[0].length));
}
currIdx = match.index! + match[0].length;
}
if (currIdx !== str.length) {
splitString.push(str.slice(currIdx, str.length));
}
return splitString.filter(s => s);
});
}
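// Illustrative usage (a sketch with hypothetical inputs):
//   regexSplit(['a,b', 'c'], ',', true);   // -> [['a', ',', 'b'], ['c']]
//   regexSplit(['hello  there'], /\s+/);   // -> [['hello', 'there']]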

export async function tensorToStringArr(input: Tensor): Promise<string[]> {
return await input.data() as unknown as string[];
}

export async function tensorArrToString2DArr(
inputs: Tensor[]): Promise<string[][]> {
return Promise.all(inputs.map(
async input => tensorToStringArr(input)));
}

export async function splitStringsForBpe(
inputs: Tensor, unsplittableTokens?: string[]): Promise<Tensor[]> {

// We need to recreate the exact behavior of token presplitting in the
// original gpt2 implementation, which uses a lookahead. We use an
// alternative: insert a special token "६" before the leading space of
// non-space characters and after the trailing space, e.g.,
// " tf" becomes "६ tf".
const pattern1 = new RegExp(`( )([^\\s${SPECIAL_WHITESPACES.source}])`);
const pattern2 = new RegExp(`([\\s${SPECIAL_WHITESPACES.source}])$`);

const inputsStr = (await inputs.data() as unknown as string[]).map(str =>
str.replace(pattern1, `६$1$2`).replace(pattern2, `$1६`)
);

let alts: string[];
let rawTokens: string[][];

function escape(input: string): string {
return input.replace(/[-\/\\^$*+?.()|[\]{}]/g, '\\$&');
}

if (unsplittableTokens && unsplittableTokens.length > 0) {
alts = createAltsForUnsplittableTokens(unsplittableTokens);
for (const [idx, token] of unsplittableTokens.entries()) {
const alt = alts[idx];
const escapedToken = escape(token);

rawTokens = regexSplit(rawTokens !== undefined ?
rawTokens : inputsStr, escapedToken, true);
rawTokens = rawTokens.map(
arr => arr.map(t => t.replace(escapedToken, alt)));
}
}
rawTokens = regexSplit(rawTokens !== undefined ?
rawTokens : inputsStr, SPLIT_PATTERN_1, true);
// Second pass splits out the last whitespace char or "६".
rawTokens = regexSplit(rawTokens, SPLIT_PATTERN_2, true);

if (unsplittableTokens && unsplittableTokens.length > 0) {
// Replace special tokens alternate with originals.
for (const [idx, token] of unsplittableTokens.entries()) {
const alt = alts[idx];
const escapedAlt = escape(alt);
rawTokens = rawTokens.map(
arr => arr.map(t => t.replace(escapedAlt, token)));
}
}

return removeStringsFromInputs(rawTokens.map(tokens => tensor(tokens)), '६');
}
122 changes: 114 additions & 8 deletions tfjs-layers/src/layers/nlp/tokenizers_utils_test.ts
@@ -15,9 +15,11 @@
* =============================================================================
*/

import { tensor1d, test_util } from '@tensorflow/tfjs-core';
import { tensor, test_util } from '@tensorflow/tfjs-core';

import { BytePairTokenizerCache, bytesToUnicode, createStaticHashtable } from './tokenizers_utils';
import { BytePairTokenizerCache, SPLIT_PATTERN_1, bytesToUnicode,
createAltsForUnsplittableTokens, createStaticHashtable, regexSplit,
removeStringsFromInputs, splitStringsForBpe } from './tokenizers_utils';
import { expectTensorsClose } from '../../utils/test_utils';

describe('bytesToUnicode', () => {
@@ -74,8 +76,8 @@ describe('createStaticHashtable', () => {
expect(byte2Unicode.get(-1)).toBe('');

expectTensorsClose(
(await byte2Unicode.lookup([tensor1d([33, 133])]))[0],
tensor1d(['!', '\x85']));
(await byte2Unicode.lookup([tensor([33, 133])]))[0],
tensor(['!', '\x85']));
});

it('creates StaticHashTable<string, number> correctly', async () => {
@@ -87,8 +89,8 @@
expect(unicode2Byte.get('😁')).toBe(-1);

expectTensorsClose(
(await unicode2Byte.lookup([tensor1d(['!', '{'])]))[0],
tensor1d([33, 123]));
(await unicode2Byte.lookup([tensor(['!', '{'])]))[0],
tensor([33, 123]));
});
});

@@ -109,9 +111,113 @@ describe('BytePairTokenizerCache', () => {

it('inserts tensors and retrieves correctly', async () => {
await cache.insert(
tensor1d(['butterfly', 'dragonfly']), ['but ter fly', 'dragon fly']);
tensor(['butterfly', 'dragonfly']), ['but ter fly', 'dragon fly']);

test_util.expectArraysEqual(
await cache.lookup(tensor1d(['dragonfly'])), ['dragon fly']);
await cache.lookup(tensor(['dragonfly'])), ['dragon fly']);
});
});

describe('removeStringsFromInputs', () => {
it ('removes nothing successfully', async () => {
const inputs = [tensor(['butterfly']), tensor(['butter'])];
const stringToRemove = '६';

const result = await removeStringsFromInputs(inputs, stringToRemove);

expect(result.length).toBe(2);
expectTensorsClose(result[0], tensor(['butterfly']));
expectTensorsClose(result[1], tensor(['butter']));
});

it ('removes strings successfully', async () => {
const inputs = [tensor(['butterfly']), tensor(['butter'])];
const stringToRemove = 'butter';

const result = await removeStringsFromInputs(inputs, stringToRemove);

expect(result.length).toBe(1);
expectTensorsClose(result[0], tensor(['butterfly']));
});
});

describe('createAltsForUnsplittableTokens', () => {
it ('creates alts with no matching regex pattern', () => {
const unsplittableTokens = ['s', 'p'];

const result = createAltsForUnsplittableTokens(unsplittableTokens);

expect(result.length).toBe(2);
test_util.expectArraysEqual(result, ['ĵs', 'ĵp']);
});

it ('creates alts with matching regex pattern', () => {
const unsplittableTokens = [' s', 'p'];

const result = createAltsForUnsplittableTokens(unsplittableTokens);

expect(result.length).toBe(2);
test_util.expectArraysEqual(result, ['ĵs', 'ĵp']);
});

it ('regex works correctly', () => {
const unsplittableTokens = ['😊,_五$ñü]aA5{\'\n~`'];

const result = createAltsForUnsplittableTokens(unsplittableTokens);

expect(result.length).toBe(1);
test_util.expectArraysEqual(result, ['ĵ五ñüaA5']);
});
});

describe('regexSplit', () => {
it ('splits with regex and string', () => {
const strResult = regexSplit(['hello there'], /\s/g);
const regexResult = regexSplit(['hello there'], ' ');
const expected = [['hello', 'there']];

test_util.expectArraysEqual(strResult, expected);
test_util.expectArraysEqual(regexResult, expected);
});

it ('keeps string delimiter', () => {
test_util.expectArraysEqual(regexSplit(['sp'], 's', true), [['s', 'p']]);
test_util.expectArraysEqual(
regexSplit(['\xc4\xb4s', 'p'], 'p', true), [['Ä´s'], ['p']] );
});

it('splits regex delimiter', () => {
const result = regexSplit(['ĵs', 'ĵp'], SPLIT_PATTERN_1, true);

test_util.expectArraysEqual(result, [['ĵs'], ['ĵp']]);
});

it('works with periods', () => {
const result = regexSplit(
['brown.', 'black.'], SPLIT_PATTERN_1, true);

test_util.expectArraysEqual(result, [['brown', '.'], ['black', '.']]);
});
});

describe('splitStringsForBpe', () => {
it ('splits with unsplittable tokens', async () => {
const inputs = tensor(['sp']);
const unsplittableTokens = ['s', 'p'];

const result = await splitStringsForBpe(inputs, unsplittableTokens);

expect(result.length).toBe(1);
expectTensorsClose(result[0], tensor(['s', 'p']));
});

it ('splits with no unsplittable tokens', async () => {
const inputs = tensor(['brown.', 'black.']);

const result = await splitStringsForBpe(inputs);

expect(result.length).toBe(2);
expectTensorsClose(result[0], tensor(['brown', '.']));
expectTensorsClose(result[1], tensor(['black', '.']));
});
});