Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import { describe, expect, it } from 'vitest'
import {
createHeuristicTopicScorer,
createLLMTopicClassifier,
decideTemporaryTopicTransition
} from '../temporary-topic'

// Verifies the pure-lexical scorer: shared roadmap/deadline vocabulary should
// rank the main topic above an unrelated temporary topic, with no API key.
describe('createHeuristicTopicScorer', () => {
  it('scores topic overlap without requiring an external API key', async () => {
    const scorer = createHeuristicTopicScorer()

    const result = await scorer(
      'Show the project roadmap and client deadline',
      'We discussed the project roadmap and the client deadline yesterday',
      'Dinner plans and movie night'
    )

    // Main-topic relevance must beat the unrelated temp topic and clear 0.5.
    expect(result.relMain).toBeGreaterThan(result.relTemp)
    expect(result.relMain).toBeGreaterThan(0.5)
  })
})

// With an empty API key the classifier must fall back to the heuristic scorer
// instead of attempting a network call to the LLM provider.
describe('createLLMTopicClassifier', () => {
  it('falls back to the heuristic scorer when no API key is configured', async () => {
    const scorer = createLLMTopicClassifier({ apiKey: '' })
    const result = await scorer(
      'Back to the sprint roadmap and engineering deadline',
      'Project roadmap, sprint goals, engineering deadline',
      'Movie night, dinner reservation, popcorn flavors'
    )

    // Heuristic overlap should favor the matching main-topic reference.
    expect(result.relMain).toBeGreaterThan(result.relTemp)
  })
})

// End-to-end transition decisions driven by the heuristic scorer: entering a
// temporary topic from MAIN mode and returning to the main thread from TEMP.
describe('decideTemporaryTopicTransition with heuristic scorer', () => {
  it('enters a temporary topic when the query no longer matches the main thread', async () => {
    const scorer = createHeuristicTopicScorer()

    // Query about movies shares no tokens with the engineering main topic.
    const transition = await decideTemporaryTopicTransition({
      mode: 'MAIN',
      query: 'What movie should we watch tonight?',
      mainTopicReference: 'Project roadmap, sprint goals, engineering deadline'
    }, scorer)

    expect(transition.decision).toBe('enter-temp')
  })

  it('exits a temporary topic when the query returns to the main thread', async () => {
    const scorer = createHeuristicTopicScorer()

    // Query overlaps the main reference far more than the temp reference.
    const transition = await decideTemporaryTopicTransition({
      mode: 'TEMP',
      query: 'Back to the sprint roadmap and engineering deadline',
      mainTopicReference: 'Project roadmap, sprint goals, engineering deadline',
      tempTopicReference: 'Movie night, dinner reservation, popcorn flavors'
    }, scorer)

    expect(transition.decision).toBe('exit-temp')
  })
})
1 change: 1 addition & 0 deletions src/main/services/agent/context/layered/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export {
decideTemporaryTopicTransition,
createLLMTopicScorer,
createLLMTopicClassifier,
createHeuristicTopicScorer,
DEFAULT_TEMPORARY_TOPIC_THRESHOLDS
} from './temporary-topic'
export type {
Expand Down
34 changes: 34 additions & 0 deletions src/main/services/agent/context/layered/temporary-topic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,36 @@ export type TopicScorer = (
tempTopicReference: string
) => Promise<TopicRelevanceScores>

/**
 * Tokenize free text into a set of lowercase alphanumeric tokens.
 *
 * Tokens shorter than 3 characters are dropped so stop-words and noise
 * ("a", "to", "is") do not inflate overlap scores.
 */
function buildTokenSet(value: string): Set<string> {
  return new Set(
    normalizeWhitespace(value)
      .toLowerCase()
      // The input is already lowercased, so the original's case-insensitive
      // flag was redundant; splitting on non-alphanumerics also guarantees
      // tokens carry no whitespace, making the per-token trim a no-op.
      .split(/[^a-z0-9]+/)
      .filter((token) => token.length >= 3)
  )
}

/**
 * Fraction of the query's tokens that also appear in the topic reference.
 *
 * Returns 0 when either side yields no usable tokens; otherwise the ratio
 * is clamped into the valid score range by clampScore.
 */
function estimateTopicOverlap(query: string, topicReference: string): number {
  const queryTokens = buildTokenSet(query)
  const topicTokens = buildTokenSet(topicReference)
  if (queryTokens.size === 0 || topicTokens.size === 0) return 0

  const sharedCount = Array.from(queryTokens).filter((token) =>
    topicTokens.has(token)
  ).length

  return clampScore(sharedCount / queryTokens.size)
}

/**
 * Build a TopicScorer that needs no network access or API key: relevance is
 * estimated purely from lexical token overlap between the query and each
 * topic reference string.
 */
export function createHeuristicTopicScorer(): TopicScorer {
  return async (query, mainTopicReference, tempTopicReference) => {
    const relMain = estimateTopicOverlap(query, mainTopicReference)
    const relTemp = estimateTopicOverlap(query, tempTopicReference)
    return { relMain, relTemp }
  }
}

// ============================================
// LLM Topic Scorer
// ============================================
Expand Down Expand Up @@ -207,6 +237,10 @@ function classificationToScores(decision: TemporaryTopicDecision): TopicRelevanc
}

export function createLLMTopicClassifier(options: LLMTopicScorerOptions): TopicScorer {
if (!options.apiKey.trim()) {
return createHeuristicTopicScorer()
}

const client = new Anthropic({
apiKey: options.apiKey
})
Expand Down
21 changes: 2 additions & 19 deletions src/main/services/proactive.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { slackBotService } from '../apps/slack/bot.service'
import { whatsappBotService } from '../apps/whatsapp/bot.service'
import { lineBotService } from '../apps/line/bot.service'
import { localChatService } from '../apps/local'
import { executeMemuMemory as executeSharedMemuMemory } from '../tools/memu.executor'
import type { AgentResponse } from '../types'

/**
Expand Down Expand Up @@ -200,25 +201,7 @@ class ProactiveService {
* This tool use main user/agent ids to retrieve memory from the main service.
*/
/**
 * Execute memu_memory for proactive runs by delegating to the shared
 * executor, which resolves memU config and calls the retrieve endpoint.
 *
 * NOTE(review): the scraped diff span contained both the superseded inline
 * fetch implementation and this delegation; only the delegation survives.
 */
private async executeMemuMemory(query: string): Promise<{ success: boolean; data?: unknown; error?: string }> {
  return executeSharedMemuMemory(query)
}

/**
Expand Down
55 changes: 55 additions & 0 deletions src/main/tools/__tests__/memu.executor.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'

// vi.hoisted ensures the mock function exists before the hoisted vi.mock
// factory below runs, so each test can stub settings independently.
const { loadSettingsMock } = vi.hoisted(() => ({
  loadSettingsMock: vi.fn()
}))

// Replace the settings module so executeMemuMemory reads the stubbed config.
vi.mock('../../config/settings.config', () => ({
  loadSettings: loadSettingsMock
}))

import { executeMemuMemory } from '../memu.executor'

describe('executeMemuMemory', () => {
  // Reset mock call history and any stubbed globals (fetch) between tests.
  beforeEach(() => {
    vi.clearAllMocks()
    vi.unstubAllGlobals()
  })

  it('returns an error when remote memU config is incomplete', async () => {
    // Only the API key is set; base URL / user / agent ids are blank.
    loadSettingsMock.mockResolvedValue({
      memuBaseUrl: '',
      memuApiKey: 'token-only',
      memuUserId: '',
      memuAgentId: ''
    })

    const result = await executeMemuMemory('trip plans')

    expect(result).toEqual({
      success: false,
      error: 'memU is not fully configured'
    })
  })

  it('returns the remote error when retrieve returns a non-OK response', async () => {
    loadSettingsMock.mockResolvedValue({
      memuBaseUrl: 'https://memu.example',
      memuApiKey: 'token',
      memuUserId: 'user-1',
      memuAgentId: 'agent-1'
    })
    // Stub fetch with a 503 whose JSON body carries a server-side message;
    // the executor should surface that message rather than the raw status.
    vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
      ok: false,
      status: 503,
      json: async () => ({ message: 'service unavailable' })
    }))

    const result = await executeMemuMemory('roadmap')

    expect(result).toEqual({
      success: false,
      error: 'service unavailable'
    })
  })
})
30 changes: 29 additions & 1 deletion src/main/tools/memu.executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,25 @@ async function getMemuConfig(): Promise<MemuConfig> {
}
}

/**
 * True only when every field needed to call the remote memU API
 * (apiKey, baseUrl, userId, agentId) is present and non-blank.
 */
function hasRemoteConfig(config: MemuConfig): boolean {
  const required = [config.apiKey, config.baseUrl, config.userId, config.agentId]
  return required.every((field) => Boolean(field && field.trim()))
}

/**
* Execute memu_memory: retrieve memory by query from the Memu API.
*/
export async function executeMemuMemory(query: string): Promise<ToolResult> {
try {
const memuConfig = await getMemuConfig()
if (!hasRemoteConfig(memuConfig)) {
return { success: false, error: 'memU is not fully configured' }
}

const response = await fetch(`${memuConfig.baseUrl}/api/v3/memory/retrieve`, {
method: 'POST',
headers: {
Expand All @@ -40,7 +53,22 @@ export async function executeMemuMemory(query: string): Promise<ToolResult> {
query
})
})
const result = await response.json()

let result: unknown = null
try {
result = await response.json()
} catch {
result = null
}

if (!response.ok) {
const message =
typeof result === 'object' && result && 'message' in result && typeof (result as { message?: unknown }).message === 'string'
? (result as { message: string }).message
: `memU retrieve failed with HTTP ${response.status}`
return { success: false, error: message }
}

return { success: true, data: result }
} catch (error) {
return { success: false, error: error instanceof Error ? error.message : String(error) }
Expand Down