Skip to content

Commit fd09e7e

Browse files
committed
refactor: replace simplifySchema with direct schema question functions
- Updated multiple files to remove the use of simplifySchema, replacing it with direct access to schema properties. - This change enhances performance and simplifies the codebase by eliminating unnecessary schema simplification calls. - Adjusted relevant logic in AiTask, FsFolderKvRepository, KvViaTabularRepository, BaseSqlTabularRepository, Dataflow, and Workflow classes to reflect this update.
1 parent 3fca7ae commit fd09e7e

File tree

8 files changed

+180
-51
lines changed

8 files changed

+180
-51
lines changed

packages/ai/src/task/base/AiTask.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import {
1717
type TaskInput,
1818
type TaskOutput,
1919
} from "@ellmers/task-graph";
20-
import { simplifySchema } from "@ellmers/util";
20+
import { schemaSemantic } from "@ellmers/util";
2121
import { type TSchema } from "@sinclair/typebox";
2222
import { AiJob } from "../../job/AiJob";
2323
import { getGlobalModelRepository } from "../../model/ModelRegistry";
@@ -87,9 +87,8 @@ export class AiTask<
8787
*/
8888
async validateInput(input: Input): Promise<boolean> {
8989
// TODO(str): this is very inefficient, we should cache the results, including intermediate results
90-
const inputSchemaProperties = simplifySchema(this.inputSchema).properties;
91-
const modelTaskProperties = Object.entries<TSchema>(inputSchemaProperties).filter(
92-
([key, value]) => value.semantic?.startsWith("model:")
90+
const modelTaskProperties = Object.entries<TSchema>(this.inputSchema.properties).filter(
91+
([key, schema]) => schemaSemantic(schema)?.startsWith("model:")
9392
);
9493
if (modelTaskProperties.length > 0) {
9594
const taskModels = await getGlobalModelRepository().findModelsByTask(this.type);
@@ -103,8 +102,8 @@ export class AiTask<
103102
}
104103
}
105104
}
106-
const modelPlainProperties = Object.entries<TSchema>(inputSchemaProperties).filter(
107-
([key, value]) => value.semantic === "model"
105+
const modelPlainProperties = Object.entries<TSchema>(this.inputSchema.properties).filter(
106+
([key, schema]) => schemaSemantic(schema) === "model"
108107
);
109108
if (modelPlainProperties.length > 0) {
110109
for (const [key, propSchema] of modelPlainProperties) {
@@ -125,9 +124,8 @@ export class AiTask<
125124
// if all of them are stripped, then the task will fail in validateInput
126125
async narrowInput(input: Input): Promise<Input> {
127126
// TODO(str): this is very inefficient, we should cache the results, including intermediate results
128-
const inputSchemaProperties = simplifySchema(this.inputSchema).properties;
129-
const modelTaskProperties = Object.entries<TSchema>(inputSchemaProperties).filter(
130-
([key, value]) => value.semantic?.startsWith("model:")
127+
const modelTaskProperties = Object.entries<TSchema>(this.inputSchema.properties).filter(
128+
([key, schema]) => schemaSemantic(schema)?.startsWith("model:")
131129
);
132130
if (modelTaskProperties.length > 0) {
133131
const taskModels = await getGlobalModelRepository().findModelsByTask(this.type);

packages/storage/src/kv/FsFolderKvRepository.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// * Licensed under the Apache License, Version 2.0 (the "License"); *
66
// *******************************************************************************
77

8-
import { createServiceToken, simplifySchema, TypeBlob } from "@ellmers/util";
8+
import { createServiceToken, TypeBlob } from "@ellmers/util";
99
import { TSchema, Type } from "@sinclair/typebox";
1010
import { mkdir, readFile, rmdir, unlink, writeFile } from "fs/promises";
1111
import path from "path";
@@ -52,7 +52,7 @@ export class FsFolderKvRepository<
5252
let content: string;
5353
if (value === null) {
5454
content = "";
55-
} else if (simplifySchema(this.valueSchema).type === "object") {
55+
} else if (this.valueSchema.type === "object") {
5656
content = JSON.stringify(value);
5757
} else if (typeof value === "object") {
5858
// Handle 'json' type schema from tests
@@ -74,7 +74,7 @@ export class FsFolderKvRepository<
7474
*/
7575
public async get(key: Key): Promise<Value | undefined> {
7676
const localPath = path.join(this.folderPath, this.pathWriter(key));
77-
const typeDef = simplifySchema(this.valueSchema);
77+
const typeDef = this.valueSchema;
7878
try {
7979
const encoding = typeDef.contentEncoding === "blob" ? "binary" : "utf-8";
8080
const content = (await readFile(localPath, { encoding })).trim();

packages/storage/src/kv/KvViaTabularRepository.ts

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// * Licensed under the Apache License, Version 2.0 (the "License"); *
66
// *******************************************************************************
77

8-
import { simplifySchema } from "@ellmers/util";
98
import { Static } from "@sinclair/typebox";
109
import type { TabularRepository } from "../tabular/TabularRepository";
1110
import { DefaultKeyValueKey, DefaultKeyValueSchema } from "./IKvRepository";
@@ -39,7 +38,7 @@ export abstract class KvViaTabularRepository<
3938
public async put(key: Key, value: Value): Promise<void> {
4039
// Handle objects that need to be JSON-stringified, TODO(str): should put in the type
4140
const shouldStringify = !["number", "boolean", "string", "blob"].includes(
42-
simplifySchema(this.valueSchema).type
41+
this.valueSchema.type
4342
);
4443

4544
if (shouldStringify) {
@@ -58,9 +57,7 @@ export abstract class KvViaTabularRepository<
5857
public async get(key: Key): Promise<Value | undefined> {
5958
const result = await this.tabularRepository.get({ key });
6059
if (result) {
61-
const shouldParse = !["number", "boolean", "string", "blob"].includes(
62-
simplifySchema(this.valueSchema).type
63-
);
60+
const shouldParse = !["number", "boolean", "string", "blob"].includes(this.valueSchema.type);
6461

6562
if (shouldParse) {
6663
try {
@@ -96,9 +93,7 @@ export abstract class KvViaTabularRepository<
9693
({
9794
key: value.key,
9895
value: (() => {
99-
const shouldParse = !["number", "boolean", "string"].includes(
100-
simplifySchema(this.valueSchema).type
101-
);
96+
const shouldParse = !["number", "boolean", "string"].includes(this.valueSchema.type);
10297

10398
if (shouldParse && typeof value.value === "string") {
10499
try {

packages/storage/src/tabular/BaseSqlTabularRepository.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// * Licensed under the Apache License, Version 2.0 (the "License"); *
66
// *******************************************************************************
77

8-
import { simplifySchema } from "@ellmers/util";
98
import { Static, TObject, TSchema } from "@sinclair/typebox";
109
import { ExtractPrimaryKey, ExtractValue, ValueOptionType } from "./ITabularRepository";
1110
import { TabularRepository } from "./TabularRepository";

packages/task-graph/src/task-graph/Dataflow.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
// * Licensed under the Apache License, Version 2.0 (the "License"); *
66
// *******************************************************************************
77

8-
import { areSemanticallyCompatible, EventEmitter, simplifySchema } from "@ellmers/util";
8+
import { areSemanticallyCompatible, EventEmitter } from "@ellmers/util";
9+
import { Type } from "@sinclair/typebox";
910
import { TaskError } from "../task/TaskError";
1011
import { DataflowJson } from "../task/TaskJSON";
1112
import { Provenance, TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes";
@@ -16,7 +17,6 @@ import {
1617
DataflowEvents,
1718
} from "./DataflowEvents";
1819
import { TaskGraph } from "./TaskGraph";
19-
import { Type } from "@sinclair/typebox";
2020

2121
export type DataflowIdType = `${string}[${string}] ==> ${string}[${string}]`;
2222

@@ -117,11 +117,11 @@ export class Dataflow {
117117
const targetSchemaProperty =
118118
DATAFLOW_ALL_PORTS === dataflow.targetTaskPortId
119119
? Type.Any()
120-
: simplifySchema(targetSchema.properties[dataflow.targetTaskPortId]);
120+
: targetSchema.properties[dataflow.targetTaskPortId];
121121
const sourceSchemaProperty =
122122
DATAFLOW_ALL_PORTS === dataflow.sourceTaskPortId
123123
? Type.Any()
124-
: simplifySchema(sourceSchema.properties[dataflow.sourceTaskPortId]);
124+
: sourceSchema.properties[dataflow.sourceTaskPortId];
125125

126126
const semanticallyCompatible = areSemanticallyCompatible(
127127
sourceSchemaProperty,

packages/task-graph/src/task-graph/Workflow.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// * Licensed under the Apache License, Version 2.0 (the "License"); *
66
// *******************************************************************************
77

8-
import { EventEmitter, simplifySchema, type EventParameters } from "@ellmers/util";
8+
import { EventEmitter, type EventParameters } from "@ellmers/util";
99
import { TObject, TSchema } from "@sinclair/typebox";
1010
import { TaskOutputRepository } from "../storage/TaskOutputRepository";
1111
import { GraphAsTask } from "../task/GraphAsTask";
@@ -18,7 +18,7 @@ import { getLastTask, parallel, pipe, PipeFunction, Taskish } from "./Conversion
1818
import { Dataflow, DATAFLOW_ALL_PORTS } from "./Dataflow";
1919
import { IWorkflow } from "./IWorkflow";
2020
import { TaskGraph } from "./TaskGraph";
21-
import { CompoundMergeStrategy, GraphResultMap } from "./TaskGraphRunner";
21+
import { CompoundMergeStrategy } from "./TaskGraphRunner";
2222

2323
// Type definitions for the workflow
2424
export type CreateWorkflow<I extends TaskIO, O extends TaskIO, C extends TaskConfig> = (
@@ -137,8 +137,8 @@ export class Workflow<Input extends TaskIO = TaskIO, Output extends TaskIO = Tas
137137
if (parent && this.graph.getTargetDataflows(parent.config.id).length === 0) {
138138
// Find matches between parent outputs and task inputs based on valueType
139139
const matches = new Map<string, string>();
140-
const sourceSchema = simplifySchema(parent.outputSchema) as TObject;
141-
const targetSchema = simplifySchema(task.inputSchema) as TObject;
140+
const sourceSchema = parent.outputSchema as TObject;
141+
const targetSchema = task.inputSchema as TObject;
142142

143143
const makeMatch = (
144144
comparator: (
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import { OptionalKind, Type } from "@sinclair/typebox";
2+
import { describe, expect, test } from "bun:test";
3+
import { simplifySchema } from "@ellmers/util";
4+
5+
describe("simplifySchema", () => {
6+
test("should throw error for undefined schema", () => {
7+
expect(() => simplifySchema(undefined as any)).toThrow("Schema is undefined");
8+
});
9+
10+
test("should return Any schema as is", () => {
11+
const schema = Type.Any();
12+
expect(simplifySchema(schema)).toEqual(schema);
13+
});
14+
15+
test("should simplify union of base type and array of same type", () => {
16+
const schema = Type.Union([Type.String(), Type.Array(Type.String())]);
17+
const result = simplifySchema(schema);
18+
expect(result).toEqual(Type.String({ isArray: true }));
19+
});
20+
21+
test("should preserve annotations when simplifying union", () => {
22+
const schema1 = Type.Union([Type.String(), Type.Array(Type.String())], {
23+
title: "MyString",
24+
description: "A string or array of strings",
25+
});
26+
const result1 = simplifySchema(schema1);
27+
expect(result1).toEqual(
28+
Type.String({
29+
title: "MyString",
30+
description: "A string or array of strings",
31+
isArray: true,
32+
})
33+
);
34+
const schema2 = Type.Union([
35+
Type.String({
36+
title: "MyString",
37+
description: "A string or array of strings",
38+
}),
39+
Type.Array(Type.String()),
40+
]);
41+
const result2 = simplifySchema(schema2);
42+
expect(result2).toEqual(
43+
Type.String({
44+
title: "MyString",
45+
description: "A string or array of strings",
46+
isArray: true,
47+
})
48+
);
49+
});
50+
51+
test("should handle nullable types", () => {
52+
const schema = Type.Union([Type.String(), Type.Null()]);
53+
const result = simplifySchema(schema);
54+
expect(result).toEqual(
55+
Type.String({
56+
isNullable: true,
57+
default: null,
58+
})
59+
);
60+
});
61+
62+
test("should recursively simplify object properties", () => {
63+
const schema = Type.Object({
64+
name: Type.Union([Type.String(), Type.Array(Type.String())]),
65+
age: Type.Union([Type.Number(), Type.Null()]),
66+
});
67+
const result = simplifySchema(schema);
68+
expect(result).toEqual(
69+
Type.Object({
70+
name: Type.String({ isArray: true }),
71+
age: Type.Number({ isNullable: true, default: null }),
72+
})
73+
);
74+
});
75+
76+
test("should recursively simplify array items", () => {
77+
const schema = Type.Array(Type.Union([Type.String(), Type.Array(Type.String())]));
78+
const result = simplifySchema(schema);
79+
expect(result).toEqual(Type.Array(Type.String({ isArray: true })));
80+
});
81+
82+
test("should preserve optional flag and default values", () => {
83+
const schema = Type.Object({
84+
name: Type.Optional(Type.String({ default: "John" })),
85+
});
86+
const result = simplifySchema(schema);
87+
expect(result).toEqual(
88+
Type.Object({
89+
name: Type.String({
90+
[OptionalKind]: "Optional",
91+
optional: true,
92+
isNullable: true,
93+
default: "John",
94+
}),
95+
})
96+
);
97+
});
98+
99+
test("should handle complex nested structures", () => {
100+
const schema = Type.Object({
101+
user: Type.Object({
102+
name: Type.Union([Type.String(), Type.Array(Type.String())]),
103+
addresses: Type.Array(Type.Union([Type.String(), Type.Array(Type.String())])),
104+
}),
105+
});
106+
const result = simplifySchema(schema);
107+
expect(result).toEqual(
108+
Type.Object({
109+
user: Type.Object({
110+
name: Type.String({ isArray: true }),
111+
addresses: Type.Array(Type.String({ isArray: true })),
112+
}),
113+
})
114+
);
115+
});
116+
});

0 commit comments

Comments
 (0)