From 879a6abca120345b18ee2051a5eec0e348ff4c2f Mon Sep 17 00:00:00 2001 From: Elina Date: Sun, 19 Apr 2026 16:26:13 +0800 Subject: [PATCH] Harden memu retrieval and topic scoring fallbacks --- .../layered/__tests__/temporary-topic.test.ts | 61 +++++++++++++++++++ .../services/agent/context/layered/index.ts | 1 + .../agent/context/layered/temporary-topic.ts | 34 +++++++++++ src/main/services/proactive.service.ts | 21 +------ .../tools/__tests__/memu.executor.test.ts | 55 +++++++++++++++++ src/main/tools/memu.executor.ts | 30 ++++++++- 6 files changed, 182 insertions(+), 20 deletions(-) create mode 100644 src/main/services/agent/context/layered/__tests__/temporary-topic.test.ts create mode 100644 src/main/tools/__tests__/memu.executor.test.ts diff --git a/src/main/services/agent/context/layered/__tests__/temporary-topic.test.ts b/src/main/services/agent/context/layered/__tests__/temporary-topic.test.ts new file mode 100644 index 0000000..200be1e --- /dev/null +++ b/src/main/services/agent/context/layered/__tests__/temporary-topic.test.ts @@ -0,0 +1,61 @@ +import { describe, expect, it } from 'vitest' +import { + createHeuristicTopicScorer, + createLLMTopicClassifier, + decideTemporaryTopicTransition +} from '../temporary-topic' + +describe('createHeuristicTopicScorer', () => { + it('scores topic overlap without requiring an external API key', async () => { + const scorer = createHeuristicTopicScorer() + + const result = await scorer( + 'Show the project roadmap and client deadline', + 'We discussed the project roadmap and the client deadline yesterday', + 'Dinner plans and movie night' + ) + + expect(result.relMain).toBeGreaterThan(result.relTemp) + expect(result.relMain).toBeGreaterThan(0.5) + }) +}) + +describe('createLLMTopicClassifier', () => { + it('falls back to the heuristic scorer when no API key is configured', async () => { + const scorer = createLLMTopicClassifier({ apiKey: '' }) + const result = await scorer( + 'Back to the sprint roadmap and engineering deadline', + 'Project roadmap, sprint goals, engineering deadline', + 'Movie night, dinner reservation, popcorn flavors' + ) + + expect(result.relMain).toBeGreaterThan(result.relTemp) + }) +}) + +describe('decideTemporaryTopicTransition with heuristic scorer', () => { + it('enters a temporary topic when the query no longer matches the main thread', async () => { + const scorer = createHeuristicTopicScorer() + + const transition = await decideTemporaryTopicTransition({ + mode: 'MAIN', + query: 'What movie should we watch tonight?', + mainTopicReference: 'Project roadmap, sprint goals, engineering deadline' + }, scorer) + + expect(transition.decision).toBe('enter-temp') + }) + + it('exits a temporary topic when the query returns to the main thread', async () => { + const scorer = createHeuristicTopicScorer() + + const transition = await decideTemporaryTopicTransition({ + mode: 'TEMP', + query: 'Back to the sprint roadmap and engineering deadline', + mainTopicReference: 'Project roadmap, sprint goals, engineering deadline', + tempTopicReference: 'Movie night, dinner reservation, popcorn flavors' + }, scorer) + + expect(transition.decision).toBe('exit-temp') + }) +}) diff --git a/src/main/services/agent/context/layered/index.ts b/src/main/services/agent/context/layered/index.ts index 2c9e3ab..4267b1e 100644 --- a/src/main/services/agent/context/layered/index.ts +++ b/src/main/services/agent/context/layered/index.ts @@ -11,6 +11,7 @@ export { decideTemporaryTopicTransition, createLLMTopicScorer, createLLMTopicClassifier, + createHeuristicTopicScorer, DEFAULT_TEMPORARY_TOPIC_THRESHOLDS } from './temporary-topic' export type { diff --git a/src/main/services/agent/context/layered/temporary-topic.ts b/src/main/services/agent/context/layered/temporary-topic.ts index 6d69b37..bc88c95 100644 --- a/src/main/services/agent/context/layered/temporary-topic.ts +++ b/src/main/services/agent/context/layered/temporary-topic.ts @@ -55,6 +55,36 @@ export type TopicScorer = ( tempTopicReference: string ) => Promise +function buildTokenSet(value: string): Set { + return new Set( + normalizeWhitespace(value) + .toLowerCase() + .split(/[^a-z0-9]+/i) + .map((token) => token.trim()) + .filter((token) => token.length >= 3) + ) +} + +function estimateTopicOverlap(query: string, topicReference: string): number { + const queryTokens = buildTokenSet(query) + const topicTokens = buildTokenSet(topicReference) + if (queryTokens.size === 0 || topicTokens.size === 0) return 0 + + let overlap = 0 + for (const token of queryTokens) { + if (topicTokens.has(token)) overlap += 1 + } + + return clampScore(overlap / queryTokens.size) +} + +export function createHeuristicTopicScorer(): TopicScorer { + return async (query, mainTopicReference, tempTopicReference) => ({ + relMain: estimateTopicOverlap(query, mainTopicReference), + relTemp: estimateTopicOverlap(query, tempTopicReference) + }) +} + // ============================================ // LLM Topic Scorer // ============================================ @@ -207,6 +237,10 @@ function classificationToScores(decision: TemporaryTopicDecision): TopicRelevanc } export function createLLMTopicClassifier(options: LLMTopicScorerOptions): TopicScorer { + if (!options.apiKey.trim()) { + return createHeuristicTopicScorer() + } + const client = new Anthropic({ apiKey: options.apiKey }) diff --git a/src/main/services/proactive.service.ts b/src/main/services/proactive.service.ts index 9df4f18..cef7f62 100644 --- a/src/main/services/proactive.service.ts +++ b/src/main/services/proactive.service.ts @@ -19,6 +19,7 @@ import { slackBotService } from '../apps/slack/bot.service' import { whatsappBotService } from '../apps/whatsapp/bot.service' import { lineBotService } from '../apps/line/bot.service' import { localChatService } from '../apps/local' +import { executeMemuMemory as executeSharedMemuMemory } from '../tools/memu.executor' import type { AgentResponse } from '../types' /** @@ -200,25 +201,7 @@ class ProactiveService { * This tool use main user/agent ids to retrieve memory from the main service. */ private async executeMemuMemory(query: string): Promise<{ success: boolean; data?: unknown; error?: string }> { - try { - const memuConfig = await this.getMemuConfig() - const response = await fetch(`${memuConfig.baseUrl}/api/v3/memory/retrieve`, { - method: 'POST', - headers: { - 'Authorization': `Bearer ${memuConfig.apiKey}`, - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - user_id: memuConfig.userId, - agent_id: memuConfig.agentId, - query - }) - }) - const result = await response.json() - return { success: true, data: result } - } catch (error) { - return { success: false, error: error instanceof Error ? error.message : String(error) } - } + return executeSharedMemuMemory(query) } /** diff --git a/src/main/tools/__tests__/memu.executor.test.ts b/src/main/tools/__tests__/memu.executor.test.ts new file mode 100644 index 0000000..581d44c --- /dev/null +++ b/src/main/tools/__tests__/memu.executor.test.ts @@ -0,0 +1,55 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { loadSettingsMock } = vi.hoisted(() => ({ + loadSettingsMock: vi.fn() +})) + +vi.mock('../../config/settings.config', () => ({ + loadSettings: loadSettingsMock +})) + +import { executeMemuMemory } from '../memu.executor' + +describe('executeMemuMemory', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.unstubAllGlobals() + }) + + it('returns an error when remote memU config is incomplete', async () => { + loadSettingsMock.mockResolvedValue({ + memuBaseUrl: '', + memuApiKey: 'token-only', + memuUserId: '', + memuAgentId: '' + }) + + const result = await executeMemuMemory('trip plans') + + expect(result).toEqual({ + success: false, + error: 'memU is not fully configured' + }) + }) + + it('returns the remote error when retrieve returns a non-OK response', async () => { + loadSettingsMock.mockResolvedValue({ + memuBaseUrl: 'https://memu.example', + memuApiKey: 'token', + memuUserId: 'user-1', + memuAgentId: 'agent-1' + }) + vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ + ok: false, + status: 503, + json: async () => ({ message: 'service unavailable' }) + })) + + const result = await executeMemuMemory('roadmap') + + expect(result).toEqual({ + success: false, + error: 'service unavailable' + }) + }) +}) diff --git a/src/main/tools/memu.executor.ts b/src/main/tools/memu.executor.ts index f3921b5..8c4c35e 100644 --- a/src/main/tools/memu.executor.ts +++ b/src/main/tools/memu.executor.ts @@ -22,12 +22,25 @@ async function getMemuConfig(): Promise { } } +function hasRemoteConfig(config: MemuConfig): boolean { + return !!( + config.apiKey && config.apiKey.trim() && + config.baseUrl && config.baseUrl.trim() && + config.userId && config.userId.trim() && + config.agentId && config.agentId.trim() + ) +} + /** * Execute memu_memory: retrieve memory by query from the Memu API. */ export async function executeMemuMemory(query: string): Promise { try { const memuConfig = await getMemuConfig() + if (!hasRemoteConfig(memuConfig)) { + return { success: false, error: 'memU is not fully configured' } + } + const response = await fetch(`${memuConfig.baseUrl}/api/v3/memory/retrieve`, { method: 'POST', headers: { @@ -40,7 +53,22 @@ export async function executeMemuMemory(query: string): Promise { query }) }) - const result = await response.json() + + let result: unknown = null + try { + result = await response.json() + } catch { + result = null + } + + if (!response.ok) { + const message = + typeof result === 'object' && result && 'message' in result && typeof (result as { message?: unknown }).message === 'string' + ? (result as { message: string }).message + : `memU retrieve failed with HTTP ${response.status}` + return { success: false, error: message } + } + return { success: true, data: result } } catch (error) { return { success: false, error: error instanceof Error ? error.message : String(error) }