
Commit 0637715

refactor: replace LocalOAIEngine with LocalAIEngine
Migrate the legacy `LocalOAIEngine` abstraction to a new, generic `LocalAIEngine` base.

- Move all OpenAI type definitions to `LocalAIEngineTypes` and expose them via the engine index.
- Remove the old `LocalOAIEngine` files (and its test) and replace them with `LocalAIEngine`, which defines the abstract API for local inference providers.
- Update the `llamacpp-extension` to extend `LocalAIEngine` instead of `AIEngine`.
- Adjust exports to include the new engine and types, and clean up unused imports.
- Add a new test suite for `LocalAIEngine` ensuring the abstract contract is satisfied.
1 parent 6afeed8 commit 0637715
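
The new `LocalAIEngine.ts` itself is not included in this excerpt. For orientation, here is a rough sketch of the shape the base class likely takes, inferred from the abstract members deleted from `AIEngine` below and from the new test suite at the end of this commit; imports, names, and signatures are assumptions, not the actual file.

```ts
// Sketch only: LocalAIEngine.ts is not part of this excerpt. Shape inferred
// from the members removed from AIEngine and exercised by the new tests.
import { AIEngine } from './AIEngine'
import {
  modelInfo,
  SessionInfo,
  UnloadResult,
  ImportOptions,
  chatCompletionRequest,
  chatCompletion,
  chatCompletionChunk,
} from './LocalAIEngineTypes'

export abstract class LocalAIEngine extends AIEngine {
  /** Provider identifier, e.g. "llama.cpp" */
  abstract provider: string

  override async onLoad(): Promise<void> {
    await super.onLoad() // the new test suite asserts this delegation
  }

  // Model lifecycle contract, moved here from AIEngine
  abstract get(modelId: string): Promise<modelInfo | undefined>
  abstract list(): Promise<modelInfo[]>
  abstract load(modelId: string, settings?: any): Promise<SessionInfo>
  abstract unload(sessionId: string): Promise<UnloadResult>
  abstract chat(
    opts: chatCompletionRequest,
    abortController?: AbortController
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
  abstract delete(modelId: string): Promise<void>
  abstract update(modelId: string, model: Partial<modelInfo>): Promise<void>
  abstract import(modelId: string, opts: ImportOptions): Promise<void>
  abstract abortImport(modelId: string): Promise<void>
  abstract getLoadedModels(): Promise<string[]>
  abstract isToolSupported(modelId: string): Promise<boolean>
}
```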

File tree

8 files changed: +376 −454 lines changed
Lines changed: 0 additions & 276 deletions
@@ -1,223 +1,6 @@
 import { BaseExtension } from '../../extension'
 import { EngineManager } from './EngineManager'

-/* AIEngine class types */
-
-export interface chatCompletionRequestMessage {
-  role: 'system' | 'user' | 'assistant' | 'tool'
-  content: string | null | Content[] // Content can be a string OR an array of content parts
-  reasoning?: string | null // Some models return reasoning in completed responses
-  reasoning_content?: string | null // Some models return reasoning in completed responses
-  name?: string
-  tool_calls?: any[] // Simplified tool_call_id?: string
-}
-
-export interface Content {
-  type: 'text' | 'image_url' | 'input_audio'
-  text?: string
-  image_url?: string
-  input_audio?: InputAudio
-}
-
-export interface InputAudio {
-  data: string // Base64 encoded audio data
-  format: 'mp3' | 'wav' | 'ogg' | 'flac' // Add more formats as needed/llama-server seems to support mp3
-}
-
-export interface ToolFunction {
-  name: string // Required: a-z, A-Z, 0-9, _, -, max length 64
-  description?: string
-  parameters?: Record<string, unknown> // JSON Schema object
-  strict?: boolean | null // Defaults to false
-}
-
-export interface Tool {
-  type: 'function' // Currently, only 'function' is supported
-  function: ToolFunction
-}
-
-export interface ToolCallOptions {
-  tools?: Tool[]
-}
-
-// A specific tool choice to force the model to call
-export interface ToolCallSpec {
-  type: 'function'
-  function: {
-    name: string
-  }
-}
-
-// tool_choice may be one of several modes or a specific call
-export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec
-
-export interface chatCompletionRequest {
-  model: string // Model ID, though for local it might be implicit via sessionInfo
-  messages: chatCompletionRequestMessage[]
-  thread_id?: string // Thread/conversation ID for context tracking
-  return_progress?: boolean
-  tools?: Tool[]
-  tool_choice?: ToolChoice
-  // Core sampling parameters
-  temperature?: number | null
-  dynatemp_range?: number | null
-  dynatemp_exponent?: number | null
-  top_k?: number | null
-  top_p?: number | null
-  min_p?: number | null
-  typical_p?: number | null
-  repeat_penalty?: number | null
-  repeat_last_n?: number | null
-  presence_penalty?: number | null
-  frequency_penalty?: number | null
-  dry_multiplier?: number | null
-  dry_base?: number | null
-  dry_allowed_length?: number | null
-  dry_penalty_last_n?: number | null
-  dry_sequence_breakers?: string[] | null
-  xtc_probability?: number | null
-  xtc_threshold?: number | null
-  mirostat?: number | null // 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
-  mirostat_tau?: number | null
-  mirostat_eta?: number | null
-
-  n_predict?: number | null
-  n_indent?: number | null
-  n_keep?: number | null
-  stream?: boolean | null
-  stop?: string | string[] | null
-  seed?: number | null // RNG seed
-
-  // Advanced sampling
-  logit_bias?: { [key: string]: number } | null
-  n_probs?: number | null
-  min_keep?: number | null
-  t_max_predict_ms?: number | null
-  image_data?: Array<{ data: string; id: number }> | null
-
-  // Internal/optimization parameters
-  id_slot?: number | null
-  cache_prompt?: boolean | null
-  return_tokens?: boolean | null
-  samplers?: string[] | null
-  timings_per_token?: boolean | null
-  post_sampling_probs?: boolean | null
-  chat_template_kwargs?: chat_template_kdict | null
-}
-
-export interface chat_template_kdict {
-  enable_thinking: false
-}
-
-export interface chatCompletionChunkChoiceDelta {
-  content?: string | null
-  role?: 'system' | 'user' | 'assistant' | 'tool'
-  tool_calls?: any[] // Simplified
-}
-
-export interface chatCompletionChunkChoice {
-  index: number
-  delta: chatCompletionChunkChoiceDelta
-  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
-}
-
-export interface chatCompletionPromptProgress {
-  cache: number
-  processed: number
-  time_ms: number
-  total: number
-}
-
-export interface chatCompletionChunk {
-  id: string
-  object: 'chat.completion.chunk'
-  created: number
-  model: string
-  choices: chatCompletionChunkChoice[]
-  system_fingerprint?: string
-  prompt_progress?: chatCompletionPromptProgress
-}
-
-export interface chatCompletionChoice {
-  index: number
-  message: chatCompletionRequestMessage // Response message
-  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
-  logprobs?: any // Simplified
-}
-
-export interface chatCompletion {
-  id: string
-  object: 'chat.completion'
-  created: number
-  model: string // Model ID used
-  choices: chatCompletionChoice[]
-  usage?: {
-    prompt_tokens: number
-    completion_tokens: number
-    total_tokens: number
-  }
-  system_fingerprint?: string
-}
-// --- End OpenAI types ---
-
-// Shared model metadata
-export interface modelInfo {
-  id: string // e.g. "qwen3-4B" or "org/model/quant"
-  name: string // human-readable, e.g., "Qwen3 4B Q4_0"
-  quant_type?: string // q4_0 (optional as it might be part of ID or name)
-  providerId: string // e.g. "llama.cpp"
-  port: number
-  sizeBytes: number
-  tags?: string[]
-  path?: string // Absolute path to the model file, if applicable
-  // Additional provider-specific metadata can be added here
-  [key: string]: any
-}
-
-// 1. /list
-export type listResult = modelInfo[]
-
-export interface SessionInfo {
-  pid: number // opaque handle for unload/chat
-  port: number // llama-server output port (corrected from portid)
-  model_id: string // name of the model
-  model_path: string // path of the loaded model
-  is_embedding: boolean
-  api_key: string
-  mmproj_path?: string
-}
-
-export interface UnloadResult {
-  success: boolean
-  error?: string
-}
-
-// 5. /chat
-export interface chatOptions {
-  providerId: string
-  sessionId: string
-  /** Full OpenAI ChatCompletionRequest payload */
-  payload: chatCompletionRequest
-}
-// Output for /chat will be Promise<ChatCompletion> for non-streaming
-// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming
-
-// 7. /import
-export interface ImportOptions {
-  modelPath: string
-  mmprojPath?: string
-  modelSha256?: string
-  modelSize?: number
-  mmprojSha256?: string
-  mmprojSize?: number
-}
-
-export interface importResult {
-  success: boolean
-  modelInfo?: modelInfo
-  error?: string
-}
-
 /**
  * Base AIEngine
  * Applicable to all AI Engines
@@ -240,63 +23,4 @@ export abstract class AIEngine extends BaseExtension {
   registerEngine() {
     EngineManager.instance().register(this)
   }
-
-  /**
-   * Gets model info
-   * @param modelId
-   */
-  abstract get(modelId: string): Promise<modelInfo | undefined>
-
-  /**
-   * Lists available models
-   */
-  abstract list(): Promise<modelInfo[]>
-
-  /**
-   * Loads a model into memory
-   */
-  abstract load(modelId: string, settings?: any): Promise<SessionInfo>
-
-  /**
-   * Unloads a model from memory
-   */
-  abstract unload(sessionId: string): Promise<UnloadResult>
-
-  /**
-   * Sends a chat request to the model
-   */
-  abstract chat(
-    opts: chatCompletionRequest,
-    abortController?: AbortController
-  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
-
-  /**
-   * Deletes a model
-   */
-  abstract delete(modelId: string): Promise<void>
-
-  /**
-   * Updates a model
-   */
-  abstract update(modelId: string, model: Partial<modelInfo>): Promise<void>
-  /**
-   * Imports a model
-   */
-  abstract import(modelId: string, opts: ImportOptions): Promise<void>
-
-  /**
-   * Aborts an ongoing model import
-   */
-  abstract abortImport(modelId: string): Promise<void>
-
-  /**
-   * Get currently loaded models
-   */
-  abstract getLoadedModels(): Promise<string[]>
-
-  /**
-   * Check if a tool is supported by the model
-   * @param modelId
-   */
-  abstract isToolSupported(modelId: string): Promise<boolean>
 }
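
Per the commit message, the OpenAI-style types deleted above move into `LocalAIEngineTypes` and are exposed via the engine index. The barrel file itself is not shown in this excerpt; a plausible sketch, assuming the file names given in the commit message:

```ts
// engines/index.ts: hypothetical barrel file; names assumed from the
// commit message, not taken from the diff.
export * from './AIEngine'
export * from './EngineManager'
export * from './LocalAIEngine'
export * from './LocalAIEngineTypes'
```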
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { LocalAIEngine } from './LocalAIEngine'
+
+class TestLocalAIEngine extends LocalAIEngine {
+  provider = 'test-provider'
+
+  async onUnload(): Promise<void> {}
+
+  async get() {
+    return undefined
+  }
+  async list() {
+    return []
+  }
+  async load() {
+    return {} as any
+  }
+  async unload() {
+    return {} as any
+  }
+  async chat() {
+    return {} as any
+  }
+  async delete() {}
+  async update() {}
+  async import() {}
+  async abortImport() {}
+  async getLoadedModels() {
+    return []
+  }
+  async isToolSupported() {
+    return false
+  }
+}
+
+describe('LocalAIEngine', () => {
+  let engine: TestLocalAIEngine
+
+  beforeEach(() => {
+    engine = new TestLocalAIEngine('', '')
+    vi.clearAllMocks()
+  })
+
+  describe('onLoad', () => {
+    it('should call super.onLoad', async () => {
+      const superOnLoadSpy = vi.spyOn(
+        Object.getPrototypeOf(Object.getPrototypeOf(engine)),
+        'onLoad'
+      )
+
+      await engine.onLoad()
+
+      expect(superOnLoadSpy).toHaveBeenCalled()
+    })
+  })
+
+  describe('abstract requirements', () => {
+    it('should implement provider', () => {
+      expect(engine.provider).toBe('test-provider')
+    })
+
+    it('should implement abstract methods', async () => {
+      expect(await engine.get('id')).toBeUndefined()
+      expect(await engine.list()).toEqual([])
+      expect(await engine.getLoadedModels()).toEqual([])
+      expect(await engine.isToolSupported('id')).toBe(false)
+    })
+  })
+})
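
A note on the spy setup in the `onLoad` test: `Object.getPrototypeOf` is applied twice, which walks past `TestLocalAIEngine.prototype` up to `LocalAIEngine.prototype`, so the spy patches the base class's `onLoad`. A small illustration of how the lookup resolves:

```ts
// Illustration of the double prototype hop used by the spy above.
const engine = new TestLocalAIEngine('', '')
const testProto = Object.getPrototypeOf(engine)    // TestLocalAIEngine.prototype
const baseProto = Object.getPrototypeOf(testProto) // LocalAIEngine.prototype
// vi.spyOn(baseProto, 'onLoad') replaces LocalAIEngine.prototype.onLoad;
// TestLocalAIEngine does not override onLoad, so engine.onLoad() resolves
// up the chain and hits the spied method.
```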
