diff --git a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-corrector.ts b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-corrector.ts index d5b09f8f487..e449ead8dda 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-corrector.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-corrector.ts @@ -14,9 +14,10 @@ import { ContextToken } from "./context-token.js"; import { CorrectionSearchable, PathResult } from "./correction-searchable.js"; import { ContextTokenization } from "./context-tokenization.js"; import { QuotientNodeFinalizer } from "./quotient-node-finalizer.js"; -import { TokenizationResultMapping } from "./tokenization-result-mapping.js"; +import { TokenizationResult, TokenizationResultMapping } from "./tokenization-result-mapping.js"; import { EDIT_DISTANCE_COST_SCALE } from "./distance-modeler.js"; import { MAX_EDIT_THRESHOLD_FACTOR } from "./search-quotient-spur.js"; +import { TokenResultMapping } from "./token-result-mapping.js"; // PathResult needs to be generic: // - a result for correcting a single Token - "TokenResult"? @@ -46,7 +47,7 @@ export type TokenResult = { * all correctable tokens, generating corrections for the full represented * range. */ -export class TokenizationCorrector implements CorrectionSearchable, TokenizationResultMapping> { +export class TokenizationCorrector implements CorrectionSearchable { public readonly tokenization: ContextTokenization; private readonly tailCorrectionLength: number; @@ -65,6 +66,7 @@ export class TokenizationCorrector implements CorrectionSearchable boolean + filterClosure: (token: ContextToken, index?: number) => boolean ) { this.tokenization = tokenization; this.tailCorrectionLength = tailCorrectionLength; @@ -175,16 +177,22 @@ export class TokenizationCorrector implements CorrectionSearchable { // New issue: this mangles the space IDs! We almost certainly need some // sort of proper map to the source token. const searchModule = new QuotientNodeFinalizer(token.searchModule, index == orderedTokens.length - 1); this.tokenLookupMap.set(searchModule.spaceId, token); - const passesFilter = filterClosure(token); + // Index within the token subset being examined. + const passesFilter = filterClosure(token, index); modelsCorrectables ||= passesFilter; if(!passesFilter) { this._uncorrectables.push(searchModule); - } else if(index == tailCorrectionLength - 1) { + return; + } + + this.matchableTokenCount++; + if(index == tailCorrectionLength - 1) { // The sole assignment case for this field. It may only be assigned for // the final token, and only if its text is of a form considered // correctable by the filter. @@ -249,6 +257,10 @@ export class TokenizationCorrector implements CorrectionSearchable r instanceof TokenResultMapping).length; + } + // The actual method used to iteratively search for tokenization-level corrections. handleNextNode(): PathResult { // Notable states: @@ -272,11 +284,17 @@ export class TokenizationCorrector implements CorrectionSearchable 0) { + return { + 'type': 'complete', + cost: this.lastTotalCost, + mapping: results + }; + } else { + return { type: 'none' }; + } } } @@ -284,7 +302,6 @@ export class TokenizationCorrector implements CorrectionSearchable { if(correctableToUpdate != this._predictable) { // Lock the 'correctable' token now that either a valid correction for @@ -298,6 +315,7 @@ export class TokenizationCorrector implements CorrectionSearchable c == undefined) != -1) { + // If any token lacks a matching lookup value, abort. + if([...this.tokenLookupMap.keys()].find((k) => !this._generatedTokenResults.has(k))) { return { type: 'intermediate', cost: tokenizationCost }; } + const correctionResults = this.collateResults(); // Determine the proper return type and construct the proper return object accordingly. // @@ -373,11 +392,18 @@ export class TokenizationCorrector implements CorrectionSearchable 0) { + return { + type: 'complete', + cost: tokenizationCost, + mapping: correctionResults + }; + } else { + return { + type: 'none' + } + } } else { return { type: 'none' diff --git a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts index 32e0fb48fce..c6fd8db93f3 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts @@ -1,13 +1,23 @@ import { CorrectionResultMapping } from "./correction-result-mapping.js"; import { TokenizationCorrector, TokenResult } from './tokenization-corrector.js'; -export class TokenizationResultMapping implements CorrectionResultMapping> { +export interface TokenizationResult { + tokenCorrections: ReadonlyArray, + totalEditCount: number, + totalEditableCodepoints: number +} + +export class TokenizationResultMapping implements CorrectionResultMapping { readonly matchingSpace: TokenizationCorrector; - readonly matchedResult: ReadonlyArray; + readonly matchedResult: TokenizationResult; constructor(tokenization: TokenResult[], corrector: TokenizationCorrector) { this.matchingSpace = corrector; - this.matchedResult = tokenization; + this.matchedResult = { + tokenCorrections: tokenization, + totalEditCount: tokenization.reduce((accum, curr) => accum + curr.knownCost, 0), + totalEditableCodepoints: 0 //corrector. + } } get spaceId(): number { @@ -22,7 +32,7 @@ export class TokenizationResultMapping implements CorrectionResultMapping accum + curr.knownCost, 0); // } // /** @@ -40,6 +50,6 @@ export class TokenizationResultMapping implements CorrectionResultMapping total + curr.totalCost, 0); + return this.matchedResult.tokenCorrections.reduce((total, curr) => total + curr.totalCost, 0); } } \ No newline at end of file diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index 936837127ae..6955e669f32 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -547,8 +547,8 @@ export function buildCorrectionSequence( const orderedTokens = tokenizationCorrection.matchingSpace?.orderedTokens; const tokens: PredictionParameters['tokens'] = []; - for(let i = 0; i < tokenizationCorrection.matchedResult.length; i++) { - const correction = tokenizationCorrection.matchedResult[i]; + for(let i = 0; i < tokenizationCorrection.matchedResult.tokenCorrections.length; i++) { + const correction = tokenizationCorrection.matchedResult.tokenCorrections[i]; /* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost * the exponent to ensure only VERY nearby corrections have a chance of winning, and only if * there are significantly more likely words. We only need this to allow very minor fat-finger @@ -823,9 +823,18 @@ export function predictFromCorrectionSequence( const predictionComponents = correctionTokens.map((correctionToken, i) => { const correctionTransform = correctionToken.correction.sample; - const predictions = lexicalModel.predict(correctionTransform, currentContext); + let predictions = lexicalModel.predict(correctionTransform, currentContext); const transitionId = correctionTransform.id; + // Ensure codepointLength == prediction codepoint length if i does not match the tail! + // Filter out cases that do not conform to this condition. + if(i != correctionTokens.length - 1) { + predictions = predictions.filter((p) => { + const codepointLength = KMWString.length(correctionToken.correction.sample.insert); + return KMWString.length(p.sample.transform.insert) == codepointLength; + }); + } + // Failsafe: if there are no matching predictions, create a fake prediction // matching the original text. if(predictions.length != 0) { diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/tokenization-corrector.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/tokenization-corrector.tests.ts index a53c3a4a4b7..804cf43d40d 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/tokenization-corrector.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/tokenization-corrector.tests.ts @@ -29,7 +29,8 @@ import { SubstitutionQuotientSpur, TokenizationCorrector, TokenResult, - TokenizationResultMapping + TokenizationResultMapping, + TokenizationResult } from '@keymanapp/lm-worker/test-index'; import Distribution = LexicalModelTypes.Distribution; @@ -302,7 +303,7 @@ describe('TokenizationCorrector', () => { assert.equal(searchResult.type, 'complete'); if(searchResult.type == 'complete') { const mapping = searchResult.mapping; - const tokenResults = mapping.matchedResult; + const tokenResults = mapping.matchedResult.tokenCorrections; assert.isNotNaN(searchResult.cost); assert.equal(searchResult.cost, searchResult.mapping.totalCost); assert.equal(tokenResults.length, 1); @@ -327,7 +328,7 @@ describe('TokenizationCorrector', () => { assert.equal(searchResult.type, 'none'); }); - it('finds a default correction for a single correctable token without a model match', () => { + it('returns no result when a single correctable token lacks a model match', () => { const fixture = buildFixture_therefore(); const theref = fixture.theref.tail; @@ -371,23 +372,6 @@ describe('TokenizationCorrector', () => { searchResult = instance.handleNextNode(); } while(searchResult.type == 'intermediate'); - assert.equal(searchResult.type, 'complete'); - if(searchResult.type == 'complete') { - const mapping = searchResult.mapping; - const tokenResults = mapping.matchedResult; - assert.isNotNaN(searchResult.cost); - assert.equal(searchResult.cost, searchResult.mapping.totalCost); - assert.equal(tokenResults.length, 1); - assert.sameOrderedMembers(tokenResults.map((r) => r.matchString), ['therefxyz']); - - // Now that an entry has been found, verify the corrector's state. - assert.isNotOk(instance.predictableToken); // should become an uncorrectable. - assert.isTrue(instance.generatedTokenResults.has(therefxyz)); - assert.equal(instance.generatedTokenResults.get(therefxyz), tokenResults[0]); - } - - // There should be no further possible suggestions. - searchResult = instance.handleNextNode(); assert.equal(searchResult.type, 'none'); }); @@ -411,7 +395,7 @@ describe('TokenizationCorrector', () => { let firstResults: ReadonlyArray; if(searchResult.type == 'complete') { const mapping = searchResult.mapping; - const tokenResults = mapping.matchedResult; + const tokenResults = mapping.matchedResult.tokenCorrections; firstResults = tokenResults; assert.isNotNaN(searchResult.cost); assert.equal(searchResult.cost, searchResult.mapping.totalCost); @@ -434,7 +418,7 @@ describe('TokenizationCorrector', () => { searchResult = instance.handleNextNode(); if(searchResult.type == 'complete') { const mapping = searchResult.mapping; - const tokenResults = mapping.matchedResult; + const tokenResults = mapping.matchedResult.tokenCorrections; // Verify that the first (bound) token is not altered further. // It should receive no further correction attempts. @@ -445,7 +429,7 @@ describe('TokenizationCorrector', () => { } while(searchResult.type != 'none'); }); - it('immediately returns a single result when the only represented token is uncorrectable', () => { + it('immediately returns with no result when the only represented token is uncorrectable', () => { const fixture = buildFixture_terminalWhitespace(); const tokenization = fixture.spaceOnly; @@ -457,13 +441,7 @@ describe('TokenizationCorrector', () => { ); const searchResult = instance.handleNextNode(); - assert.equal(searchResult.type, 'complete'); - if(searchResult.type == 'complete') { - assert.equal(searchResult.mapping.matchedResult[0].matchString, ' '); - } - - const nilResult = instance.handleNextNode(); - assert.equal(nilResult.type, 'none'); + assert.equal(searchResult.type, 'none'); }); it('returns a single result when the final token is uncorrectable', () => { @@ -484,8 +462,8 @@ describe('TokenizationCorrector', () => { assert.equal(searchResult.type, 'complete'); if(searchResult.type == 'complete') { - assert.equal(searchResult.mapping.matchedResult[0].matchString, 'space'); - assert.equal(searchResult.mapping.matchedResult[1].matchString, ' '); + assert.equal(searchResult.mapping.matchedResult.tokenCorrections[0].matchString, 'space'); + assert.equal(searchResult.mapping.matchedResult.tokenCorrections[1].matchString, ' '); } const nilResult = instance.handleNextNode(); @@ -502,20 +480,20 @@ describe('TokenizationCorrector', () => { let haveSeenSingleTokenCorrection = false; let haveSeenThreeTokenCorrection = false; for await(let phraseMatch of getBestMatches< - ReadonlyArray, + TokenizationResult, TokenizationResultMapping, TokenizationCorrector >(correctors, buildTestTimer())) { - if(phraseMatch.matchedResult.length == 1) { + if(phraseMatch.matchedResult.tokenCorrections.length == 1) { if(!haveSeenSingleTokenCorrection) { - assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['theref' /* -ore */]); + assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['theref' /* -ore */]); } haveSeenSingleTokenCorrection = true; - } else if(phraseMatch.matchedResult.length == 3) { + } else if(phraseMatch.matchedResult.tokenCorrections.length == 3) { if(!haveSeenThreeTokenCorrection) { - assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]); + assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]); } haveSeenThreeTokenCorrection = true; } diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-correction-sequence.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-correction-sequence.tests.ts index 0e9d02f8f72..eb4f7ac5d9f 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-correction-sequence.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-correction-sequence.tests.ts @@ -542,8 +542,9 @@ describe('predictFromCorrectionSequence', () => { ]; const expected_prediction_p = dummied_suggestion_sequences - .map((dist) => { - return dist[0] + .map((dist, i) => { + // There is no valid 'g' entry corresponding to token index 0. + return i == 0 ? null : dist[0] }).reduce((accum, curr) => { return accum * (curr ? curr.p : Math.exp(-EDIT_DISTANCE_COST_SCALE)) }, 1); @@ -551,11 +552,11 @@ describe('predictFromCorrectionSequence', () => { const expected_predictions: Suggestion[] = [ { transform: { - insert: 'golden', + insert: 'g', deleteLeft: 0, id: transitionID }, - displayAs: 'golden', + displayAs: 'g', transformId: transitionID }, { transform: { @@ -589,7 +590,7 @@ describe('predictFromCorrectionSequence', () => { predictions.forEach((entry) => assert.equal(entry.metadata.probabilities.correction, parameters.tokens.reduce((accum, curr) => accum * curr.correction.p, 1))); predictions.sort(tupleDisplayOrderSort); - assert.deepEqual(predictions[0].components.map((c) => c.prediction.transform.insert), ['golden', ' ', 'apple']); + assert.deepEqual(predictions[0].components.map((c) => c.prediction.transform.insert), ['g', ' ', 'apple']); assert.sameDeepOrderedMembers(predictions[0].components.map((entry) => entry.prediction), expected_predictions); assert.approximately(predictions[0].metadata.probabilities.prediction, expected_prediction_p, 0.00001);