From ca3fc8fc2f136d9964cf49d4c8b80574ab08f6c3 Mon Sep 17 00:00:00 2001 From: Hao19911125 <99091644+Hao19911125@users.noreply.github.com> Date: Sun, 5 Apr 2026 15:24:01 +0800 Subject: [PATCH] Harden legacy extraction and recall JSON parsing --- extractor.js | 135 ++++++++++++++++++++++++++++++++++++- retriever.js | 55 +++++++++++++-- tests/p0-regressions.mjs | 57 ++++++++++++++++ tests/retrieval-config.mjs | 48 +++++++++++++ 4 files changed, 287 insertions(+), 8 deletions(-) diff --git a/extractor.js b/extractor.js index 3a4f517..e1e064e 100644 --- a/extractor.js +++ b/extractor.js @@ -67,6 +67,136 @@ function throwIfAborted(signal) { } } +function resolveExtractionOperations(result, schema = []) { + const candidates = [ + { source: "operations", value: result?.operations }, + { source: "nodes", value: result?.nodes }, + { source: "memories", value: result?.memories }, + { source: "root", value: result }, + ]; + + for (const candidate of candidates) { + if (!Array.isArray(candidate.value)) { + continue; + } + + const normalized = normalizeExtractionOperations(candidate.value, schema); + if (normalized?.legacyCount > 0) { + console.info("[ST-BME] 兼容旧版扁平提取输出", { + source: candidate.source, + normalizedCount: normalized.legacyCount, + totalCount: normalized.operations.length, + }); + } + return normalized.operations; + } + + return null; +} + +function normalizeExtractionOperations(operations, schema = []) { + if (!Array.isArray(operations)) { + return null; + } + + let legacyCount = 0; + const normalizedOperations = operations.map((operation) => { + const normalized = normalizeExtractionOperation(operation, schema); + if (normalized?.__legacyCompat) { + legacyCount += 1; + delete normalized.__legacyCompat; + } + return normalized; + }); + + return { + operations: normalizedOperations, + legacyCount, + }; +} + +function normalizeExtractionOperation(operation, schema = []) { + if (!operation || typeof operation !== "object") { + return operation; + } + + const normalized = { ...operation }; + const normalizedAction = normalizeOperationAction(normalized); + if (normalizedAction) { + normalized.action = normalizedAction; + } + + const typeDef = schema.find((entry) => entry?.id === normalized.type); + const normalizedFields = extractOperationFields(normalized, typeDef); + if ( + normalized.action === "create" || + normalized.action === "update" || + (!normalized.action && Object.keys(normalizedFields).length > 0) + ) { + normalized.fields = normalizedFields; + } + + if ( + !normalized.action && + typeDef && + !normalized.nodeId && + Object.keys(normalizedFields).length > 0 + ) { + normalized.action = "create"; + normalized.__legacyCompat = true; + } + + if ( + (normalized.action === "update" || normalized.action === "delete") && + !normalized.nodeId && + typeof normalized.id === "string" && + normalized.id.trim() + ) { + normalized.nodeId = normalized.id.trim(); + } + + if ( + normalized.action === "create" && + !normalized.ref && + typeof normalized.id === "string" && + normalized.id.trim() + ) { + normalized.ref = normalized.id.trim(); + } + + return normalized; +} + +function normalizeOperationAction(operation = {}) { + const candidate = operation.action ?? operation.op ?? operation.operation; + return typeof candidate === "string" && candidate.trim() + ? candidate.trim() + : ""; +} + +function extractOperationFields(operation = {}, typeDef = null) { + const fields = { + ...(operation.fields && typeof operation.fields === "object" + ? operation.fields + : {}), + }; + + const columnNames = Array.isArray(typeDef?.columns) + ? typeDef.columns + .map((column) => String(column?.name || "").trim()) + .filter(Boolean) + : []; + + for (const fieldName of columnNames) { + if (fields[fieldName] !== undefined || operation[fieldName] === undefined) { + continue; + } + fields[fieldName] = operation[fieldName]; + } + + return fields; +} + /** * 对未处理的对话楼层执行记忆提取 * @@ -203,7 +333,8 @@ export async function extractMemories({ }); throwIfAborted(signal); - if (!result || !Array.isArray(result.operations)) { + const operations = resolveExtractionOperations(result, schema); + if (!result || !Array.isArray(operations)) { console.warn("[ST-BME] 提取 LLM 未返回有效操作"); return { success: false, @@ -222,7 +353,7 @@ export async function extractMemories({ const refMap = new Map(); const operationErrors = []; - for (const op of result.operations) { + for (const op of operations) { try { switch (op.action) { case "create": { diff --git a/retriever.js b/retriever.js index 7290c57..471f21e 100644 --- a/retriever.js +++ b/retriever.js @@ -97,6 +97,44 @@ function buildRecallFallbackReason(llmResult) { } } +function resolveRecallSelectedIds(result) { + if (Array.isArray(result)) { + return result; + } + + const visited = new Set(); + const queue = [{ value: result, depth: 0 }]; + while (queue.length > 0) { + const current = queue.shift(); + const value = current?.value; + const depth = Number(current?.depth) || 0; + if (!value || typeof value !== "object" || visited.has(value) || depth > 1) { + continue; + } + visited.add(value); + + const directCandidates = [ + value.selected_ids, + value.selectedIds, + value.node_ids, + value.nodeIds, + value.ids, + ]; + for (const candidate of directCandidates) { + if (Array.isArray(candidate)) { + return candidate; + } + } + + queue.push({ value: value.data, depth: depth + 1 }); + queue.push({ value: value.result, depth: depth + 1 }); + queue.push({ value: value.payload, depth: depth + 1 }); + queue.push({ value: value.output, depth: depth + 1 }); + } + + return null; +} + function isAbortError(error) { return error?.name === "AbortError"; } @@ -1515,21 +1553,22 @@ async function llmRecall( returnFailureDetails: true, }); const result = llmResult?.ok ? llmResult.data : null; + const selectedIds = resolveRecallSelectedIds(result); - if (result?.selected_ids && Array.isArray(result.selected_ids)) { + if (Array.isArray(selectedIds)) { // 校验 ID 有效性 const validIds = uniqueNodeIds( - result.selected_ids.filter((id) => + selectedIds.filter((id) => candidates.some((c) => c.nodeId === id), ), ).slice(0, maxNodes); - if (validIds.length > 0 || result.selected_ids.length === 0) { + if (validIds.length > 0 || selectedIds.length === 0) { return { selectedNodeIds: validIds, status: "llm", reason: - validIds.length < result.selected_ids.length + validIds.length < selectedIds.length ? "LLM 返回了部分无效或超限 ID,已自动裁剪" : "LLM 精排完成", }; @@ -1538,7 +1577,7 @@ async function llmRecall( // LLM 失败时回退到纯评分排序 const fallbackReason = llmResult?.ok - ? Array.isArray(result?.selected_ids) + ? Array.isArray(selectedIds) ? "LLM 返回的候选 ID 无效,已回退到评分排序" : "LLM 返回了无法识别的 JSON 结构,已回退到评分排序" : buildRecallFallbackReason(llmResult); @@ -1546,7 +1585,11 @@ async function llmRecall( selectedNodeIds: candidates.slice(0, maxNodes).map((c) => c.nodeId), status: "fallback", reason: fallbackReason, - fallbackType: llmResult?.ok ? "invalid-candidate" : llmResult?.errorType || "unknown", + fallbackType: llmResult?.ok + ? Array.isArray(selectedIds) + ? "invalid-candidate" + : "invalid-structure" + : llmResult?.errorType || "unknown", }; } diff --git a/tests/p0-regressions.mjs b/tests/p0-regressions.mjs index 6349604..28e3902 100644 --- a/tests/p0-regressions.mjs +++ b/tests/p0-regressions.mjs @@ -2105,6 +2105,62 @@ async function testExtractorFailsOnUnknownOperation() { } } +async function testExtractorSupportsLegacyFlatNodeOperations() { + const graph = createEmptyGraph(); + const restoreOverrides = pushTestOverrides({ + llm: { + async callLLMForJSON() { + return { + operations: [ + { + type: "event", + id: "evt-legacy", + title: "夜间喂食", + summary: "角色完成了一次深夜喂食。", + participants: "悟岳, 访客", + status: "resolved", + importance: 6, + }, + { + type: "character", + id: "char-legacy", + name: "悟岳", + state: "放松下来", + }, + ], + }; + }, + }, + }); + + try { + const result = await extractMemories({ + graph, + messages: [{ seq: 7, role: "assistant", content: "测试旧版扁平提取输出" }], + startSeq: 7, + endSeq: 7, + schema, + embeddingConfig: null, + settings: {}, + }); + + assert.equal(result.success, true); + assert.equal(result.newNodes, 2); + assert.equal(graph.lastProcessedSeq, 7); + + const eventNode = graph.nodes.find((node) => node.type === "event"); + const characterNode = graph.nodes.find((node) => node.type === "character"); + assert.ok(eventNode); + assert.ok(characterNode); + assert.equal(eventNode.fields?.title, "夜间喂食"); + assert.equal(eventNode.fields?.summary, "角色完成了一次深夜喂食。"); + assert.equal(characterNode.fields?.name, "悟岳"); + assert.equal(characterNode.fields?.state, "放松下来"); + } finally { + restoreOverrides(); + } +} + async function testConsolidatorMergeUpdatesSeqRange() { const graph = createEmptyGraph(); const target = createNode({ @@ -5023,6 +5079,7 @@ async function testLlmOutputRegexCleansResponseBeforeJsonParse() { await testCompressorMigratesEdgesToCompressedNode(); await testVectorIndexKeepsDirtyOnDirectPartialEmbeddingFailure(); await testExtractorFailsOnUnknownOperation(); +await testExtractorSupportsLegacyFlatNodeOperations(); await testConsolidatorMergeUpdatesSeqRange(); await testConsolidatorMergeFallbackKeepsNodeWhenTargetMissing(); await testBatchJournalVectorDeltaCapturesRecoveryFields(); diff --git a/tests/retrieval-config.mjs b/tests/retrieval-config.mjs index ec36163..6a0f9b5 100644 --- a/tests/retrieval-config.mjs +++ b/tests/retrieval-config.mjs @@ -329,6 +329,54 @@ assert.equal(state.llmOptions[0].returnFailureDetails, true); assert.equal(state.llmOptions[0].maxRetries, 2); assert.equal(state.llmOptions[0].maxCompletionTokens, 512); +state.vectorCalls.length = 0; +state.diffusionCalls.length = 0; +state.llmCalls.length = 0; +state.llmOptions.length = 0; +state.llmCandidateCount = 0; +state.llmResponse = { selectedIds: ["rule-1"] }; +const llmCamelCaseResult = await retrieve({ + graph, + userMessage: "换个 JSON 键名也应该兼容", + recentMessages: [], + embeddingConfig: {}, + schema, + options: { + topK: 4, + maxRecallNodes: 2, + enableVectorPrefilter: true, + enableGraphDiffusion: false, + enableLLMRecall: true, + llmCandidatePool: 2, + }, +}); +assert.deepEqual(Array.from(llmCamelCaseResult.selectedNodeIds), ["rule-1"]); +assert.equal(llmCamelCaseResult.meta.retrieval.llm.status, "llm"); + +state.vectorCalls.length = 0; +state.diffusionCalls.length = 0; +state.llmCalls.length = 0; +state.llmOptions.length = 0; +state.llmCandidateCount = 0; +state.llmResponse = { data: { selected_ids: ["rule-2"] } }; +const llmNestedResult = await retrieve({ + graph, + userMessage: "嵌套 JSON 结构也应该兼容", + recentMessages: [], + embeddingConfig: {}, + schema, + options: { + topK: 4, + maxRecallNodes: 2, + enableVectorPrefilter: true, + enableGraphDiffusion: false, + enableLLMRecall: true, + llmCandidatePool: 2, + }, +}); +assert.deepEqual(Array.from(llmNestedResult.selectedNodeIds), ["rule-2"]); +assert.equal(llmNestedResult.meta.retrieval.llm.status, "llm"); + state.vectorCalls.length = 0; state.diffusionCalls.length = 0; state.llmCalls.length = 0;