Skip to content

Commit d5d057d

Browse files
authored
fix(lib): Fix splitTest not correctly translating to test data length (#17)
* refactor(lib): Liberate splitTestData and improve its tests
* fix(lib): Fix splitTest not correctly translating to test data length
1 parent e9646b3 commit d5d057d

File tree

3 files changed

+105
-18
lines changed

3 files changed

+105
-18
lines changed

src/loadCsv.ts

Lines changed: 1 addition & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5,27 +5,10 @@ import { shuffle } from 'shuffle-seed';
55

66
import { CsvReadOptions, CsvTable } from './loadCsv.models';
77
import filterColumns from './filterColumns';
8+
import splitTestData from './splitTestData';
89

910
// Fixed seed string. Presumably the fallback seed given to shuffle-seed's
// `shuffle` when the caller supplies none; its use is outside this excerpt — confirm.
const defaultShuffleSeed = 'mncv9340ur';
1011

11-
/**
 * Pre-fix splitter shown on the removed ("-") side of this diff.
 * Cuts `features` and `labels` at one index, `length`: rows before it are
 * returned as `features`/`labels`, rows from it onward as
 * `testFeatures`/`testLabels`.
 *
 * @param features - Feature rows to split.
 * @param labels - Label rows, sliced at the same index as `features`.
 * @param splitTest - A number is clamped to [0, features.length - 1] and used
 *   as the cut index; any non-number value cuts at the midpoint.
 * @returns The four slices described above.
 */
const splitTestData = (
12-
features: CsvTable,
13-
labels: CsvTable,
14-
splitTest: boolean | number
15-
) => {
16-
// NOTE: `length` is the number of leading (non-test) rows, so a numeric
// `splitTest` sizes the features/labels part rather than the test part.
const length =
17-
typeof splitTest === 'number'
18-
? Math.max(0, Math.min(splitTest, features.length - 1))
19-
: Math.floor(features.length / 2);
20-
21-
return {
22-
testFeatures: features.slice(length),
23-
testLabels: labels.slice(length),
24-
features: features.slice(0, length),
25-
labels: labels.slice(0, length),
26-
};
27-
};
28-
2912
const loadCsv = (filename: string, options: CsvReadOptions) => {
3013
const {
3114
featureColumns,

src/splitTestData.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import { CsvTable } from './loadCsv.models';
2+
3+
/**
 * Partitions table rows into a leading portion and a trailing test portion.
 *
 * The trailing `testRowCount` rows become the test data; everything before
 * them is returned as the regular data. Both tables are cut at the same
 * index, which is derived from `features.length` only.
 *
 * @param features - Feature rows to partition.
 * @param labels - Label rows, cut at the same index as `features`.
 * @param splitTest - A number requests that many test rows (clamped to
 *   between 0 and the number of feature rows); `true` requests half of the
 *   rows, rounded down.
 * @returns `features`/`labels` (rows before the cut) and
 *   `testFeatures`/`testLabels` (rows from the cut onward).
 */
const splitTestData = (
  features: CsvTable,
  labels: CsvTable,
  splitTest: true | number
) => {
  const rowCount = features.length;

  let testRowCount: number;
  if (typeof splitTest === 'number') {
    // Keep the requested amount within [0, rowCount].
    testRowCount = Math.max(0, Math.min(splitTest, rowCount));
  } else {
    testRowCount = Math.floor(rowCount / 2);
  }

  // Test rows are taken from the end, so the cut sits this far from the start.
  const cutIndex = rowCount - testRowCount;

  const leadingFeatures = features.slice(0, cutIndex);
  const leadingLabels = labels.slice(0, cutIndex);
  const trailingFeatures = features.slice(cutIndex);
  const trailingLabels = labels.slice(cutIndex);

  return {
    features: leadingFeatures,
    labels: leadingLabels,
    testFeatures: trailingFeatures,
    testLabels: trailingLabels,
  };
};
22+
23+
export default splitTestData;

tests/splitTestData.test.ts

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import splitTestData from '../src/splitTestData';
2+
3+
// Shared fixture: four feature rows and four single-value label rows.
const tables = {
  features: [[1, 2], [3, 4], [5, 6], [7, 8]],
  labels: [[9], [10], [11], [12]],
};
12+
13+
// Passing `true` should cut the fixture into two halves of two rows each.
test('Default splitting, splits in half', () => {
  const split = splitTestData(tables.features, tables.labels, true);

  // Leading half stays in the regular data.
  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
  ]);
  expect(split.labels).toMatchObject([[9], [10]]);

  // Trailing half becomes the test data.
  expect(split.testFeatures).toMatchObject([
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[11], [12]]);
});
30+
31+
// A numeric argument is the number of rows taken from the end as test data.
test('Splitting a fixed amount works', () => {
  const split = splitTestData(tables.features, tables.labels, 1);

  // The first three rows remain regular data.
  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
  ]);
  expect(split.labels).toMatchObject([[9], [10], [11]]);

  // Only the last row is test data.
  expect(split.testFeatures).toMatchObject([[7, 8]]);
  expect(split.testLabels).toMatchObject([[12]]);
});
46+
47+
// Asking for more test rows than exist should move every row into the test data.
test('Splitting more than row length splits all rows into test data', () => {
  const oversizedAmount = tables.features.length * 2;
  const split = splitTestData(tables.features, tables.labels, oversizedAmount);

  // Nothing is left as regular data.
  expect(split.features).toMatchObject([]);
  expect(split.labels).toMatchObject([]);

  // The whole fixture ends up as test data.
  expect(split.testFeatures).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[9], [10], [11], [12]]);
});
63+
64+
// Zero and negative amounts should both leave the test data empty.
test('Splitting less than or equal to 0 places all rows into normal data', () => {
  for (const splitLength of [0, -1]) {
    const split = splitTestData(tables.features, tables.labels, splitLength);

    // Every row stays in the regular data.
    expect(split.features).toMatchObject([
      [1, 2],
      [3, 4],
      [5, 6],
      [7, 8],
    ]);
    expect(split.labels).toMatchObject([[9], [10], [11], [12]]);

    // No rows are moved into the test data.
    expect(split.testFeatures).toMatchObject([]);
    expect(split.testLabels).toMatchObject([]);
  }
});

0 commit comments

Comments
 (0)