Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import { ContextToken } from "./context-token.js";
import { CorrectionSearchable, PathResult } from "./correction-searchable.js";
import { ContextTokenization } from "./context-tokenization.js";
import { QuotientNodeFinalizer } from "./quotient-node-finalizer.js";
import { TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { TokenizationResult, TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { EDIT_DISTANCE_COST_SCALE } from "./distance-modeler.js";
import { MAX_EDIT_THRESHOLD_FACTOR } from "./search-quotient-spur.js";
import { TokenResultMapping } from "./token-result-mapping.js";

// PathResult needs to be generic:
// - a result for correcting a single Token - "TokenResult"?
Expand Down Expand Up @@ -46,7 +47,7 @@ export type TokenResult = {
* all correctable tokens, generating corrections for the full represented
* range.
*/
export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray<TokenResult>, TokenizationResultMapping> {
export class TokenizationCorrector implements CorrectionSearchable<TokenizationResult, TokenizationResultMapping> {
public readonly tokenization: ContextTokenization;
private readonly tailCorrectionLength: number;

Expand All @@ -65,6 +66,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
private lastTotalCost: number;
private handleHasBeenCalled: boolean = false;
private predictableMatchFound: boolean = false;
private matchableTokenCount = 0;

get currentCost(): number {
const correctable = this.selectionQueue.peek();
Expand Down Expand Up @@ -156,7 +158,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
constructor(
tokenization: ContextTokenization,
tailCorrectionLength: number,
filterClosure: (token: ContextToken) => boolean
filterClosure: (token: ContextToken, index?: number) => boolean
) {
this.tokenization = tokenization;
this.tailCorrectionLength = tailCorrectionLength;
Expand All @@ -175,16 +177,22 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.tokenLookupMap = new Map();
let modelsCorrectables = false;

// 0 index: the first index in range to be modeled, as split off from the main tokenization.
orderedTokens.forEach((token, index) => {
// New issue: this mangles the space IDs! We almost certainly need some
// sort of proper map to the source token.
const searchModule = new QuotientNodeFinalizer(token.searchModule, index == orderedTokens.length - 1);
this.tokenLookupMap.set(searchModule.spaceId, token);
const passesFilter = filterClosure(token);
// Index within the token subset being examined.
const passesFilter = filterClosure(token, index);
modelsCorrectables ||= passesFilter;
if(!passesFilter) {
this._uncorrectables.push(searchModule);
} else if(index == tailCorrectionLength - 1) {
return;
}

this.matchableTokenCount++;
if(index == tailCorrectionLength - 1) {
// The sole assignment case for this field. It may only be assigned for
// the final token, and only if its text is of a form considered
// correctable by the filter.
Expand Down Expand Up @@ -249,6 +257,10 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
return new TokenizationResultMapping(results, this);
}

/**
 * Tallies how many of the currently generated token results are
 * `TokenResultMapping` instances.
 */
private get matchedTokenCount() {
  let matched = 0;
  for(const generated of this._generatedTokenResults.values()) {
    if(generated instanceof TokenResultMapping) {
      matched++;
    }
  }
  return matched;
}

// The actual method used to iteratively search for tokenization-level corrections.
handleNextNode(): PathResult<TokenizationResultMapping> {
// Notable states:
Expand All @@ -272,19 +284,24 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.handleHasBeenCalled = true;
const results = this.collateResults();
this._previousResults.push(results);
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};

// If no matchables exist, there's no prediction to do; don't make a return.
if(this.matchedTokenCount > 0) {
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};
} else {
return { type: 'none' };
}
}
}

this.handleHasBeenCalled = true;

const correctableToUpdate = this.selectionQueue.dequeue();
const tokenResult = correctableToUpdate?.handleNextNode();

const delistCorrectable = () => {
if(correctableToUpdate != this._predictable) {
// Lock the 'correctable' token now that either a valid correction for
Expand All @@ -298,6 +315,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
}

const correctionIsThePredictable = correctableToUpdate == this._predictable;

if(tokenResult.type == 'none') {
if(!correctionIsThePredictable || !this.predictableMatchFound) {
// Transition the node from 'correctable' to 'uncorrectable' - we were
Expand Down Expand Up @@ -359,25 +377,33 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.selectionQueue.enqueue(this._predictable);
}

const correctionResults = this.collateResults();
if(correctionResults.matchedResult.findIndex((c) => c == undefined) != -1) {
// If any token lacks a matching lookup value, abort.
if([...this.tokenLookupMap.keys()].find((k) => !this._generatedTokenResults.has(k))) {
return {
type: 'intermediate',
cost: tokenizationCost
};
}
const correctionResults = this.collateResults();

// Determine the proper return type and construct the proper return object accordingly.
//
// If there was no result obtained from the predictable and a result was previously found,
// that indicates no further predictions may be found.
if(tokenResult.type != 'none' || !correctionIsThePredictable || !this.predictableMatchFound) {
this._previousResults.push(correctionResults);
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};

if(this.matchedTokenCount > 0) {
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};
} else {
return {
type: 'none'
}
}
} else {
return {
type: 'none'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import { CorrectionResultMapping } from "./correction-result-mapping.js";
import { TokenizationCorrector, TokenResult } from './tokenization-corrector.js';

export class TokenizationResultMapping implements CorrectionResultMapping<ReadonlyArray<TokenResult>> {
/**
 * Aggregate payload carried by a `TokenizationResultMapping`: the per-token
 * correction results plus totals derived from them.
 */
export interface TokenizationResult {
/** The per-token results, as supplied to `TokenizationResultMapping`'s constructor. */
tokenCorrections: ReadonlyArray<TokenResult>,
/** Sum of `knownCost` across `tokenCorrections`, as computed by `TokenizationResultMapping`'s constructor. */
totalEditCount: number,
/**
 * NOTE(review): the only visible writer sets this to a constant 0 as a
 * placeholder; the intended meaning is not established here - confirm before
 * relying on it.
 */
totalEditableCodepoints: number
}

export class TokenizationResultMapping implements CorrectionResultMapping<TokenizationResult> {
readonly matchingSpace: TokenizationCorrector;
readonly matchedResult: ReadonlyArray<TokenResult>;
readonly matchedResult: TokenizationResult;

/**
 * Builds the mapping for one set of per-token correction results.
 *
 * @param tokenization The per-token results; stored as
 *                     `matchedResult.tokenCorrections`.
 * @param corrector    The corrector these results belong to; stored as
 *                     `matchingSpace`.
 */
constructor(tokenization: TokenResult[], corrector: TokenizationCorrector) {
  this.matchingSpace = corrector;
  // Assigned exactly once, in its final object form; a bare-array assignment
  // here would be overwritten and does not match the `TokenizationResult` shape.
  this.matchedResult = {
    tokenCorrections: tokenization,
    // Total of the known cost reported by each per-token result.
    totalEditCount: tokenization.reduce((accum, curr) => accum + curr.knownCost, 0),
    // TODO(review): placeholder value - nothing derives this from `corrector` yet.
    totalEditableCodepoints: 0
  };
}

get spaceId(): number {
Expand All @@ -22,7 +32,7 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
// * `totalCost`.)
// */
// get knownCost(): number {
// return this.node.editCount;
// return this.matchedResult.tokenCorrections.reduce((accum, curr) => accum + curr.knownCost, 0);
// }

// /**
Expand All @@ -40,6 +50,6 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
* to the resulting output.
*/
get totalCost(): number {
  // `matchedResult` is an object wrapping the per-token results, so the sum
  // must run over its `tokenCorrections` array rather than over
  // `matchedResult` itself.
  return this.matchedResult.tokenCorrections.reduce((total, curr) => total + curr.totalCost, 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,8 @@ export function buildCorrectionSequence(
const orderedTokens = tokenizationCorrection.matchingSpace?.orderedTokens;
const tokens: PredictionParameters['tokens'] = [];

for(let i = 0; i < tokenizationCorrection.matchedResult.length; i++) {
const correction = tokenizationCorrection.matchedResult[i];
for(let i = 0; i < tokenizationCorrection.matchedResult.tokenCorrections.length; i++) {
const correction = tokenizationCorrection.matchedResult.tokenCorrections[i];
/* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost
* the exponent to ensure only VERY nearby corrections have a chance of winning, and only if
* there are significantly more likely words. We only need this to allow very minor fat-finger
Expand Down Expand Up @@ -823,9 +823,18 @@ export function predictFromCorrectionSequence(

const predictionComponents = correctionTokens.map((correctionToken, i) => {
const correctionTransform = correctionToken.correction.sample;
const predictions = lexicalModel.predict(correctionTransform, currentContext);
let predictions = lexicalModel.predict(correctionTransform, currentContext);
const transitionId = correctionTransform.id;

// Ensure codepointLength == prediction codepoint length if i does not match the tail!
// Filter out cases that do not conform to this condition.
if(i != correctionTokens.length - 1) {
predictions = predictions.filter((p) => {
const codepointLength = KMWString.length(correctionToken.correction.sample.insert);
return KMWString.length(p.sample.transform.insert) == codepointLength;
});
}

// Failsafe: if there are no matching predictions, create a fake prediction
// matching the original text.
if(predictions.length != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import {
SubstitutionQuotientSpur,
TokenizationCorrector,
TokenResult,
TokenizationResultMapping
TokenizationResultMapping,
TokenizationResult
} from '@keymanapp/lm-worker/test-index';

import Distribution = LexicalModelTypes.Distribution;
Expand Down Expand Up @@ -302,7 +303,7 @@ describe('TokenizationCorrector', () => {
assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
assert.equal(tokenResults.length, 1);
Expand All @@ -327,7 +328,7 @@ describe('TokenizationCorrector', () => {
assert.equal(searchResult.type, 'none');
});

it('finds a default correction for a single correctable token without a model match', () => {
it('returns no result when a single correctable token lacks a model match', () => {
const fixture = buildFixture_therefore();

const theref = fixture.theref.tail;
Expand Down Expand Up @@ -371,23 +372,6 @@ describe('TokenizationCorrector', () => {
searchResult = instance.handleNextNode();
} while(searchResult.type == 'intermediate');

assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
assert.equal(tokenResults.length, 1);
assert.sameOrderedMembers(tokenResults.map((r) => r.matchString), ['therefxyz']);

// Now that an entry has been found, verify the corrector's state.
assert.isNotOk(instance.predictableToken); // should become an uncorrectable.
assert.isTrue(instance.generatedTokenResults.has(therefxyz));
assert.equal(instance.generatedTokenResults.get(therefxyz), tokenResults[0]);
}

// There should be no further possible suggestions.
searchResult = instance.handleNextNode();
assert.equal(searchResult.type, 'none');
});

Expand All @@ -411,7 +395,7 @@ describe('TokenizationCorrector', () => {
let firstResults: ReadonlyArray<TokenResult>;
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;
firstResults = tokenResults;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
Expand All @@ -434,7 +418,7 @@ describe('TokenizationCorrector', () => {
searchResult = instance.handleNextNode();
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;

// Verify that the first (bound) token is not altered further.
// It should receive no further correction attempts.
Expand All @@ -445,7 +429,7 @@ describe('TokenizationCorrector', () => {
} while(searchResult.type != 'none');
});

it('immediately returns a single result when the only represented token is uncorrectable', () => {
it('immediately returns with no result when the only represented token is uncorrectable', () => {
const fixture = buildFixture_terminalWhitespace();

const tokenization = fixture.spaceOnly;
Expand All @@ -457,13 +441,7 @@ describe('TokenizationCorrector', () => {
);

const searchResult = instance.handleNextNode();
assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
assert.equal(searchResult.mapping.matchedResult[0].matchString, ' ');
}

const nilResult = instance.handleNextNode();
assert.equal(nilResult.type, 'none');
assert.equal(searchResult.type, 'none');
});

it('returns a single result when the final token is uncorrectable', () => {
Expand All @@ -484,8 +462,8 @@ describe('TokenizationCorrector', () => {

assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
assert.equal(searchResult.mapping.matchedResult[0].matchString, 'space');
assert.equal(searchResult.mapping.matchedResult[1].matchString, ' ');
assert.equal(searchResult.mapping.matchedResult.tokenCorrections[0].matchString, 'space');
assert.equal(searchResult.mapping.matchedResult.tokenCorrections[1].matchString, ' ');
}

const nilResult = instance.handleNextNode();
Expand All @@ -502,20 +480,20 @@ describe('TokenizationCorrector', () => {
let haveSeenSingleTokenCorrection = false;
let haveSeenThreeTokenCorrection = false;
for await(let phraseMatch of getBestMatches<
ReadonlyArray<TokenResult>,
TokenizationResult,
TokenizationResultMapping,
TokenizationCorrector
>(correctors, buildTestTimer())) {

if(phraseMatch.matchedResult.length == 1) {
if(phraseMatch.matchedResult.tokenCorrections.length == 1) {
if(!haveSeenSingleTokenCorrection) {
assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['theref' /* -ore */]);
assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['theref' /* -ore */]);
}

haveSeenSingleTokenCorrection = true;
} else if(phraseMatch.matchedResult.length == 3) {
} else if(phraseMatch.matchedResult.tokenCorrections.length == 3) {
if(!haveSeenThreeTokenCorrection) {
assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]);
assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]);
}
haveSeenThreeTokenCorrection = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,20 +542,21 @@ describe('predictFromCorrectionSequence', () => {
];

const expected_prediction_p = dummied_suggestion_sequences
.map((dist) => {
return dist[0]
.map((dist, i) => {
// There is no valid 'g' entry corresponding to token index 0.
return i == 0 ? null : dist[0]
}).reduce((accum, curr) => {
return accum * (curr ? curr.p : Math.exp(-EDIT_DISTANCE_COST_SCALE))
}, 1);

const expected_predictions: Suggestion[] = [
{
transform: {
insert: 'golden',
insert: 'g',
deleteLeft: 0,
id: transitionID
},
displayAs: 'golden',
displayAs: 'g',
transformId: transitionID
}, {
transform: {
Expand Down Expand Up @@ -589,7 +590,7 @@ describe('predictFromCorrectionSequence', () => {
predictions.forEach((entry) => assert.equal(entry.metadata.probabilities.correction, parameters.tokens.reduce((accum, curr) => accum * curr.correction.p, 1)));
predictions.sort(tupleDisplayOrderSort);

assert.deepEqual(predictions[0].components.map((c) => c.prediction.transform.insert), ['golden', ' ', 'apple']);
assert.deepEqual(predictions[0].components.map((c) => c.prediction.transform.insert), ['g', ' ', 'apple']);
assert.sameDeepOrderedMembers(predictions[0].components.map((entry) => entry.prediction), expected_predictions);

assert.approximately(predictions[0].metadata.probabilities.prediction, expected_prediction_p, 0.00001);
Expand Down
Loading