From dc5051f2efe1cb44664c3b89ac38ae50e389722f Mon Sep 17 00:00:00 2001 From: Youzini-afk <13153778771cx@gmail.com> Date: Sun, 12 Apr 2026 14:59:22 +0800 Subject: [PATCH] feat: shared ranking core + prompt node references; recall reuses shared core for base query/vector/diffusion; remove retriever-local duplicate helpers; add regression tests --- maintenance/extractor.js | 113 +++- prompting/prompt-node-references.js | 93 +++ retrieval/retriever.js | 495 +++----------- retrieval/shared-ranking.js | 752 +++++++++++++++++++++ tests/extractor-phase3-layered-context.mjs | 93 ++- tests/prompt-node-references.mjs | 54 ++ tests/retrieval-config.mjs | 482 +++++++++++++ tests/shared-ranking.mjs | 175 +++++ 8 files changed, 1855 insertions(+), 402 deletions(-) create mode 100644 prompting/prompt-node-references.js create mode 100644 retrieval/shared-ranking.js create mode 100644 tests/prompt-node-references.mjs create mode 100644 tests/shared-ranking.mjs diff --git a/maintenance/extractor.js b/maintenance/extractor.js index c5bcfe2..6a5ccd0 100644 --- a/maintenance/extractor.js +++ b/maintenance/extractor.js @@ -16,7 +16,7 @@ import { updateNode, } from "../graph/graph.js"; import { callLLMForJSON } from "../llm/llm.js"; -import { ensureEventTitle, getNodeDisplayName } from "../graph/node-labels.js"; +import { ensureEventTitle } from "../graph/node-labels.js"; import { normalizeMemoryScope, isObjectiveScope, @@ -41,8 +41,10 @@ import { buildTaskLlmPayload, buildTaskPrompt, } from "../prompting/prompt-builder.js"; +import { createPromptNodeReferenceMap } from "../prompting/prompt-node-references.js"; import { RELATION_TYPES } from "../graph/schema.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; +import { rankNodesForTaskContext } from "../retrieval/shared-ranking.js"; import { getSTContextForPrompt, getSTContextSnapshot } from "../host/st-context.js"; import { buildExtractionInputContext } from "./extraction-context.js"; import { @@ -148,6 +150,62 @@ function resolveExtractPromptStructuredMode(settings) { return "both"; } +function formatExtractRankingMessage(message = {}) { + const role = String(message?.role || "assistant").trim().toLowerCase() === "user" + ? "user" + : "assistant"; + const content = String(message?.content || "").trim(); + if (!content) return ""; + return `[${role}]: ${content}`; +} + +function buildExtractRankingQueryText(messages = []) { + const normalizedMessages = Array.isArray(messages) ? messages : []; + const targetLines = normalizedMessages + .filter((message) => message?.isContextOnly !== true) + .map((message) => formatExtractRankingMessage(message)) + .filter(Boolean); + if (targetLines.length > 0) { + return targetLines.join("\n"); + } + return normalizedMessages + .map((message) => formatExtractRankingMessage(message)) + .filter(Boolean) + .join("\n"); +} + +function buildExtractRelevantNodeReferenceMap(scoredNodes = [], schema = [], maxCount = 6) { + const typeLabelById = new Map( + (Array.isArray(schema) ? schema : []).map((typeDef) => [ + String(typeDef?.id || "").trim(), + String(typeDef?.label || typeDef?.id || "").trim(), + ]), + ); + const relevantNodes = (Array.isArray(scoredNodes) ? scoredNodes : []) + .filter((entry) => + entry?.node && + !entry.node.archived && + ((Number(entry?.vectorScore) || 0) > 0 || + (Number(entry?.graphScore) || 0) > 0 || + (Number(entry?.lexicalScore) || 0) > 0), + ) + .slice(0, Math.max(1, maxCount)); + return createPromptNodeReferenceMap(relevantNodes, { + prefix: "G", + maxLength: 28, + buildMeta: ({ entry, node }) => ({ + typeLabel: + typeLabelById.get(String(node?.type || "").trim()) || + String(node?.type || "节点").trim() || + "节点", + score: + Math.round( + (Number(entry?.weightedScore ?? entry?.finalScore) || 0) * 1000, + ) / 1000, + }), + }); +} + function isAbortError(error) { return error?.name === "AbortError"; } @@ -900,8 +958,31 @@ export async function extractMemories({ ? dialogueText : structuredMessages; + const extractGraphRankingQuery = buildExtractRankingQueryText(structuredMessages); + const extractGraphRanking = + graph?.nodes?.some((node) => !node?.archived) && extractGraphRankingQuery + ? await rankNodesForTaskContext({ + graph, + userMessage: extractGraphRankingQuery, + recentMessages: [], + embeddingConfig, + signal, + options: { + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + }, + }) + : null; + const extractGraphRelevantNodes = buildExtractRelevantNodeReferenceMap( + extractGraphRanking?.scoredNodes, + schema, + ); + // 构建当前图概览(让 LLM 知道已有哪些节点,避免重复) - const graphOverview = buildGraphOverview(graph, schema); + const graphOverview = buildGraphOverview(graph, schema, extractGraphRelevantNodes); // 构建 Schema 描述 const schemaDescription = buildSchemaDescription(schema); @@ -924,6 +1005,14 @@ export async function extractMemories({ `storyTimeContext=${storyTimeContext ? "present" : "none"}, ` + `worldbookMode=${String(settings?.extractWorldbookMode || "active")}`, ); + if (extractGraphRanking) { + debugLog( + `[ST-BME][extract-graph] relevantNodes=${extractGraphRelevantNodes.references.length}, ` + + `vectorMergedHits=${Number(extractGraphRanking?.diagnostics?.vectorMergedHits || 0)}, ` + + `diffusionHits=${Number(extractGraphRanking?.diagnostics?.diffusionHits || 0)}, ` + + `lexicalBoostedNodes=${Number(extractGraphRanking?.diagnostics?.lexicalBoostedNodes || 0)}`, + ); + } const extractWorldbookMode = String(settings?.extractWorldbookMode || "active").trim().toLowerCase(); const promptBuild = await buildTaskPrompt(settings, "extract", { @@ -1753,21 +1842,31 @@ async function generateNodeEmbeddings(graph, embeddingConfig, signal) { /** * 构建图谱概览文本(给 LLM 看) */ -function buildGraphOverview(graph, schema) { +function buildGraphOverview(graph, schema, relevantReferenceMap = null) { const activeNodes = graph.nodes .filter((n) => !n.archived) .sort((a, b) => (a.seq || 0) - (b.seq || 0)); if (activeNodes.length === 0) return ""; const lines = []; + lines.push("### 图谱节点统计"); for (const typeDef of schema) { const nodesOfType = activeNodes.filter((n) => n.type === typeDef.id); if (nodesOfType.length === 0) continue; - lines.push(`### ${typeDef.label} (${nodesOfType.length} 个节点)`); - for (const node of nodesOfType.slice(-10)) { - // 只展示最近 10 个 - lines.push(` - [${node.id}] ${getNodeDisplayName(node)}`); + lines.push(` - ${typeDef.label}: ${nodesOfType.length}`); + } + + const references = Array.isArray(relevantReferenceMap?.references) + ? relevantReferenceMap.references + : []; + if (references.length > 0) { + lines.push("", "### 与当前提取片段最相关的既有节点"); + for (const reference of references) { + const typeLabel = String(reference?.meta?.typeLabel || reference?.meta?.type || "节点").trim() || "节点"; + const label = String(reference?.meta?.label || "—").trim() || "—"; + const score = Number(reference?.meta?.score || 0).toFixed(3); + lines.push(` - [${reference.key}|${typeLabel}] ${label} (score=${score})`); } } diff --git a/prompting/prompt-node-references.js b/prompting/prompt-node-references.js new file mode 100644 index 0000000..26961fd --- /dev/null +++ b/prompting/prompt-node-references.js @@ -0,0 +1,93 @@ +import { truncateNodeLabel } from "../graph/node-labels.js"; + +function normalizePromptNodeText(value) { + return String(value ?? "") + .replace(/\s+/g, " ") + .trim(); +} + +function resolvePromptNode(value = {}) { + if (value?.node && typeof value.node === "object") { + return value.node; + } + return value && typeof value === "object" ? value : {}; +} + +export function resolvePromptNodeId(value = {}) { + const node = resolvePromptNode(value); + return String(value?.nodeId || node?.id || "").trim(); +} + +export function getPromptNodeLabel(value = {}, { maxLength = 32 } = {}) { + const node = resolvePromptNode(value); + const fallbackId = typeof node?.id === "string" ? node.id.slice(0, 8) : ""; + const rawLabel = normalizePromptNodeText( + node?.fields?.title || + node?.fields?.name || + node?.fields?.summary || + node?.fields?.insight || + node?.fields?.belief || + node?.name || + fallbackId || + "—", + ); + return truncateNodeLabel(rawLabel || "—", maxLength); +} + +export function createPromptNodeReferenceMap( + entries = [], + { + prefix = "N", + maxLength = 32, + buildMeta = null, + } = {}, +) { + const keyToNodeId = {}; + const keyToMeta = {}; + const nodeIdToKey = {}; + const references = []; + + for (const [index, entry] of (Array.isArray(entries) ? entries : []).entries()) { + const node = resolvePromptNode(entry); + const nodeId = resolvePromptNodeId(entry); + if (!nodeId || nodeIdToKey[nodeId]) { + continue; + } + + const key = `${String(prefix || "N").trim() || "N"}${references.length + 1}`; + const label = getPromptNodeLabel(node, { maxLength }); + const extraMeta = typeof buildMeta === "function" + ? buildMeta({ + entry, + node, + nodeId, + key, + index, + label, + }) + : {}; + + keyToNodeId[key] = nodeId; + nodeIdToKey[nodeId] = key; + keyToMeta[key] = { + nodeId, + type: String(node?.type || ""), + label, + ...(extraMeta && typeof extraMeta === "object" ? extraMeta : {}), + }; + references.push({ + key, + nodeId, + node, + meta: keyToMeta[key], + }); + } + + return { + prefix: String(prefix || "N").trim() || "N", + references, + keyToNodeId, + keyToMeta, + nodeIdToKey, + }; +} diff --git a/retrieval/retriever.js b/retrieval/retriever.js index 17d03be..abe37a9 100644 --- a/retrieval/retriever.js +++ b/retrieval/retriever.js @@ -23,9 +23,7 @@ import { collectSupplementalAnchorNodeIds, createCooccurrenceIndex, isEligibleAnchorNode, - mergeVectorResults, runResidualRecall, - splitIntentSegments, } from "./retrieval-enhancer.js"; import { MEMORY_SCOPE_BUCKETS, @@ -36,6 +34,7 @@ import { normalizeMemoryScope, resolveScopeBucketWeight, } from "../graph/memory-scope.js"; +import { rankNodesForTaskContext } from "./shared-ranking.js"; import { computeKnowledgeGateForNode, listKnowledgeOwners, @@ -54,8 +53,8 @@ import { } from "../graph/story-timeline.js"; import { getActiveSummaryEntries } from "../graph/summary-state.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; +import { createPromptNodeReferenceMap } from "../prompting/prompt-node-references.js"; import { getSTContextForPrompt } from "../host/st-context.js"; -import { findSimilarNodesByText, validateVectorConfig } from "../vector/vector-index.js"; function createAbortError(message = "操作已终止") { const error = new Error(message); @@ -241,14 +240,6 @@ function normalizeQueryText(value, maxLength = 400) { return normalized.slice(0, Math.max(1, maxLength)); } -function createTextPreview(text, maxLength = 120) { - const normalized = normalizeQueryText(text, maxLength + 4); - if (!normalized) return ""; - return normalized.length > maxLength - ? `${normalized.slice(0, maxLength)}...` - : normalized; -} - function normalizeRecallSelectionList(values = [], maxLength = 64) { const normalized = []; const seen = new Set(); @@ -262,262 +253,23 @@ function normalizeRecallSelectionList(values = [], maxLength = 64) { return normalized; } -function getRecallCandidateLabel(node = {}) { - return String( - node?.fields?.title || - node?.fields?.name || - node?.fields?.summary || - node?.fields?.insight || - node?.fields?.belief || - node?.id || - "", - ).trim(); -} - function createRecallCandidateKeyMaps(candidates = []) { - const candidateKeyToNodeId = {}; - const candidateKeyToCandidateMeta = {}; - const nodeIdToCandidateKey = {}; - - for (const [index, candidate] of (Array.isArray(candidates) ? candidates : []).entries()) { - const node = candidate?.node || {}; - const nodeId = String(candidate?.nodeId || node?.id || "").trim(); - if (!nodeId) continue; - const candidateKey = `R${index + 1}`; - candidateKeyToNodeId[candidateKey] = nodeId; - nodeIdToCandidateKey[nodeId] = candidateKey; - candidateKeyToCandidateMeta[candidateKey] = { - nodeId, - type: String(node?.type || ""), - label: getRecallCandidateLabel(node), - scopeBucket: String(candidate?.scopeBucket || ""), - temporalBucket: String(candidate?.temporalBucket || ""), + const referenceMap = createPromptNodeReferenceMap(candidates, { + prefix: "R", + maxLength: 80, + buildMeta: ({ entry }) => ({ + scopeBucket: String(entry?.scopeBucket || ""), + temporalBucket: String(entry?.temporalBucket || ""), score: Math.round( - (Number(candidate?.weightedScore ?? candidate?.finalScore) || 0) * 1000, + (Number(entry?.weightedScore ?? entry?.finalScore) || 0) * 1000, ) / 1000, - }; - } - + }), + }); return { - candidateKeyToNodeId, - candidateKeyToCandidateMeta, - nodeIdToCandidateKey, - }; -} - -function roundBlendWeight(value) { - return Math.round((Number(value) || 0) * 1000) / 1000; -} - -function uniqueStrings(values = [], maxLength = 400) { - const result = []; - const seen = new Set(); - - for (const value of values) { - const text = normalizeQueryText(value, maxLength); - const key = text.toLowerCase(); - if (!text || seen.has(key)) continue; - seen.add(key); - result.push(text); - } - - return result; -} - -function parseRecallContextLine(line = "") { - const raw = String(line ?? "").trim(); - if (!raw) return null; - - const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); - if (bracketMatch) { - const role = String(bracketMatch[1] || "").toLowerCase(); - const text = normalizeQueryText(bracketMatch[2] || ""); - return text ? { role, text } : null; - } - - const plainMatch = raw.match( - /^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i, - ); - if (!plainMatch) return null; - - const roleToken = String(plainMatch[1] || "").toLowerCase(); - const role = - roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" - ? "assistant" - : "user"; - const text = normalizeQueryText(plainMatch[2] || ""); - return text ? { role, text } : null; -} - -function buildContextQueryBlend( - userMessage, - recentMessages = [], - { - enabled = true, - assistantWeight = 0.2, - previousUserWeight = 0.1, - maxTextLength = 400, - } = {}, -) { - const currentText = normalizeQueryText(userMessage, maxTextLength); - const normalizedAssistantWeight = clampRange(assistantWeight, 0.2, 0, 1); - const normalizedPreviousUserWeight = clampRange( - previousUserWeight, - 0.1, - 0, - 1, - ); - const currentWeight = Math.max( - 0, - 1 - normalizedAssistantWeight - normalizedPreviousUserWeight, - ); - - let assistantText = ""; - let previousUserText = ""; - const parsedMessages = Array.isArray(recentMessages) - ? recentMessages.map((line) => parseRecallContextLine(line)).filter(Boolean) - : []; - - for (let index = parsedMessages.length - 1; index >= 0; index--) { - const item = parsedMessages[index]; - if (!assistantText && item.role === "assistant") { - assistantText = normalizeQueryText(item.text, maxTextLength); - } - if ( - !previousUserText && - item.role === "user" && - normalizeQueryText(item.text, maxTextLength).toLowerCase() !== - currentText.toLowerCase() - ) { - previousUserText = normalizeQueryText(item.text, maxTextLength); - } - if (assistantText && previousUserText) break; - } - - const rawParts = [ - { - kind: "currentUser", - label: "当前用户消息", - text: currentText, - weight: enabled ? currentWeight : 1, - }, - ]; - - if (enabled && assistantText) { - rawParts.push({ - kind: "assistantContext", - label: "最近 assistant 回复", - text: assistantText, - weight: normalizedAssistantWeight, - }); - } - - if (enabled && previousUserText) { - rawParts.push({ - kind: "previousUser", - label: "上一条 user 消息", - text: previousUserText, - weight: normalizedPreviousUserWeight, - }); - } - - const dedupedParts = []; - const seen = new Set(); - for (const part of rawParts) { - const text = normalizeQueryText(part.text, maxTextLength); - const key = text.toLowerCase(); - if (!text || seen.has(key)) continue; - seen.add(key); - dedupedParts.push({ - ...part, - text, - }); - } - - if (dedupedParts.length === 0) { - return { - active: false, - parts: [], - currentText: "", - assistantText: "", - previousUserText: "", - combinedText: "", - }; - } - - const totalWeight = dedupedParts.reduce( - (sum, part) => sum + Math.max(0, Number(part.weight) || 0), - 0, - ); - const normalizedParts = dedupedParts.map((part) => ({ - ...part, - weight: - totalWeight > 0 - ? roundBlendWeight((Math.max(0, Number(part.weight) || 0) || 0) / totalWeight) - : roundBlendWeight(1 / dedupedParts.length), - })); - const combinedText = - normalizedParts.length <= 1 - ? normalizedParts[0]?.text || "" - : normalizedParts - .map((part) => `${part.label}:\n${part.text}`) - .join("\n\n"); - - return { - active: enabled && normalizedParts.length > 1, - parts: normalizedParts, - currentText: currentText || normalizedParts[0]?.text || "", - assistantText, - previousUserText, - combinedText, - }; -} - -function buildVectorQueryPlan( - blendPlan, - { enableMultiIntent = true, maxSegments = 4 } = {}, -) { - const plan = []; - let currentSegments = []; - - for (const part of blendPlan?.parts || []) { - let queries = [part.text]; - if (part.kind === "currentUser" && enableMultiIntent) { - currentSegments = splitIntentSegments(part.text, { maxSegments }); - queries = uniqueStrings([ - part.text, - ...currentSegments.filter((item) => item !== part.text), - ]); - } else { - queries = uniqueStrings([part.text]); - } - - plan.push({ - kind: part.kind, - label: part.label, - weight: part.weight, - queries, - }); - } - - return { - plan, - currentSegments, - }; -} - -function buildLexicalQuerySources( - userMessage, - { enableMultiIntent = true, maxSegments = 4 } = {}, -) { - const currentText = normalizeQueryText(userMessage, 400); - const segments = enableMultiIntent - ? splitIntentSegments(currentText, { maxSegments }) - : []; - return { - sources: uniqueStrings([currentText, ...segments]), - segments, + candidateKeyToNodeId: referenceMap.keyToNodeId, + candidateKeyToCandidateMeta: referenceMap.keyToMeta, + nodeIdToCandidateKey: referenceMap.nodeIdToKey, }; } @@ -722,13 +474,6 @@ function buildVisibilityTopHits(scoredNodes = [], maxCount = 6) { })); } -function scaleVectorResults(results = [], weight = 1) { - return (Array.isArray(results) ? results : []).map((item) => ({ - ...item, - score: (Number(item?.score) || 0) * Math.max(0, Number(weight) || 0), - })); -} - function pickActiveRegion(graph, optionValue = "") { const direct = String(optionValue || "").trim(); if (direct) return direct; @@ -1462,7 +1207,6 @@ export async function retrieve({ normalizedMaxRecallNodes, llmCandidatePool, ); - const vectorValidation = validateVectorConfig(embeddingConfig); const retrievalMeta = createRetrievalMeta(enableLLMRecall); retrievalMeta.activeRegion = activeRegion; retrievalMeta.activeRegionSource = activeRegionContext.source || ""; @@ -1490,29 +1234,6 @@ export async function retrieve({ retrievalMeta.knowledgeGateMode = enableCognitiveMemory ? "anchored-soft-visibility" : "disabled"; - const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, { - enabled: enableContextQueryBlend, - assistantWeight: contextAssistantWeight, - previousUserWeight: contextPreviousUserWeight, - }); - retrievalMeta.queryBlendActive = contextQueryBlend.active; - retrievalMeta.queryBlendParts = (contextQueryBlend.parts || []).map((part) => ({ - kind: part.kind, - label: part.label, - weight: part.weight, - text: createTextPreview(part.text), - length: part.text.length, - })); - retrievalMeta.queryBlendWeights = Object.fromEntries( - (contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]), - ); - const lexicalQuery = buildLexicalQuerySources( - contextQueryBlend.currentText || userMessage, - { - enableMultiIntent, - maxSegments: multiIntentMaxSegments, - }, - ); debugLog( `[ST-BME] 检索开始: ${nodeCount} 个活跃节点${enableVisibility ? " (认知边界已启用)" : ""}`, ); @@ -1567,49 +1288,85 @@ export async function retrieve({ }, }); } - - const vectorStartedAt = nowMs(); - if (enableVectorPrefilter && vectorValidation.valid) { - debugLog("[ST-BME] 第1层: 向量预筛"); - const queryPlan = buildVectorQueryPlan(contextQueryBlend, { + const sharedRanking = await rankNodesForTaskContext({ + graph, + userMessage, + recentMessages, + embeddingConfig, + signal, + options: { + topK: normalizedTopK, + diffusionTopK: normalizedDiffusionTopK, + enableVectorPrefilter, + enableGraphDiffusion, + enableContextQueryBlend, enableMultiIntent, - maxSegments: multiIntentMaxSegments, - }); - const groups = []; - - retrievalMeta.segmentsUsed = queryPlan.currentSegments; - for (const part of queryPlan.plan) { - for (const queryText of part.queries) { - const results = await vectorPreFilter( - graph, - queryText, - activeNodes, - embeddingConfig, - normalizedTopK, - signal, - ); - groups.push(scaleVectorResults(results, part.weight || 1)); - } - } - - const merged = mergeVectorResults( - groups, - Math.max(normalizedTopK * 2, 24), - ); - retrievalMeta.vectorHits = merged.rawHitCount; - retrievalMeta.vectorMergedHits = merged.results.length; - vectorResults = merged.results; - } else if (enableVectorPrefilter) { - pushSkipReason(retrievalMeta, "vector-config-invalid"); - } - retrievalMeta.timings.vector = roundMs(nowMs() - vectorStartedAt); - - exactEntityAnchors.push( - ...extractEntityAnchors( - contextQueryBlend.currentText || userMessage, + multiIntentMaxSegments, + contextAssistantWeight, + contextPreviousUserWeight, + teleportAlpha, + enableTemporalLinks, + temporalLinkStrength, + enableLexicalBoost, + lexicalWeight, + weights, activeNodes, - ), + }, + }); + const contextQueryBlend = sharedRanking.contextQueryBlend; + const lexicalQuery = sharedRanking.lexicalQuery; + retrievalMeta.queryBlendActive = Boolean( + sharedRanking?.diagnostics?.queryBlendActive, ); + retrievalMeta.queryBlendParts = Array.isArray( + sharedRanking?.diagnostics?.queryBlendParts, + ) + ? [...sharedRanking.diagnostics.queryBlendParts] + : []; + retrievalMeta.queryBlendWeights = { + ...(sharedRanking?.diagnostics?.queryBlendWeights || {}), + }; + retrievalMeta.segmentsUsed = Array.isArray(sharedRanking?.diagnostics?.segmentsUsed) + ? [...sharedRanking.diagnostics.segmentsUsed] + : []; + retrievalMeta.vectorHits = Number(sharedRanking?.diagnostics?.vectorHits || 0); + retrievalMeta.vectorMergedHits = Number( + sharedRanking?.diagnostics?.vectorMergedHits || 0, + ); + retrievalMeta.seedCount = Number(sharedRanking?.diagnostics?.seedCount || 0); + retrievalMeta.diffusionHits = Number( + sharedRanking?.diagnostics?.diffusionHits || 0, + ); + retrievalMeta.lexicalBoostedNodes = Number( + sharedRanking?.diagnostics?.lexicalBoostedNodes || 0, + ); + retrievalMeta.temporalSyntheticEdgeCount = Number( + sharedRanking?.diagnostics?.temporalSyntheticEdgeCount || 0, + ); + retrievalMeta.teleportAlpha = Number( + sharedRanking?.diagnostics?.teleportAlpha || teleportAlpha, + ); + retrievalMeta.lexicalTopHits = Array.isArray( + sharedRanking?.diagnostics?.lexicalTopHits, + ) + ? [...sharedRanking.diagnostics.lexicalTopHits] + : []; + retrievalMeta.timings.vector = Number( + sharedRanking?.diagnostics?.timings?.vector || 0, + ); + retrievalMeta.timings.diffusion = Number( + sharedRanking?.diagnostics?.timings?.diffusion || 0, + ); + for (const reason of sharedRanking?.diagnostics?.skipReasons || []) { + pushSkipReason(retrievalMeta, reason); + } + vectorResults = Array.isArray(sharedRanking?.vectorResults) + ? [...sharedRanking.vectorResults] + : []; + diffusionResults = Array.isArray(sharedRanking?.diffusionResults) + ? [...sharedRanking.diffusionResults] + : []; + exactEntityAnchors.push(...(sharedRanking?.exactEntityAnchors || [])); supplementalAnchorNodeIds = collectSupplementalAnchorNodeIds( graph, vectorResults, @@ -1650,7 +1407,7 @@ export async function retrieve({ retrievalMeta.timings.residual = roundMs(nowMs() - residualStartedAt); const diffusionStartedAt = nowMs(); - if (enableGraphDiffusion) { + if (enableGraphDiffusion && (enableCrossRecall || residualResult.triggered)) { debugLog("[ST-BME] 第2层: PEDSA 图扩散"); const seeds = [ ...vectorResults.map((v) => ({ id: v.nodeId, energy: v.score })), @@ -1705,9 +1462,11 @@ export async function retrieve({ return node && !node.archived; }); } + retrievalMeta.diffusionHits = diffusionResults.length; + } + if (enableGraphDiffusion && (enableCrossRecall || residualResult.triggered)) { + retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); } - retrievalMeta.diffusionHits = diffusionResults.length; - retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); debugLog("[ST-BME] 第3层: 混合评分"); @@ -2259,7 +2018,9 @@ export async function retrieve({ retrievalMeta.timings.total = roundMs(nowMs() - startedAt); return buildResult(graph, selectedNodeIds, schema, { - retrieval: retrievalMeta, + retrieval: { + ...retrievalMeta, + }, scopeContext: { enableScopedMemory, enablePovMemory, @@ -2295,62 +2056,6 @@ export async function retrieve({ }); } -/** - * 向量预筛选 - */ -async function vectorPreFilter( - graph, - userMessage, - activeNodes, - embeddingConfig, - topK, - signal, -) { - try { - return await findSimilarNodesByText( - graph, - userMessage, - embeddingConfig, - topK, - activeNodes, - signal, - ); - } catch (e) { - if (isAbortError(e)) { - throw e; - } - console.error("[ST-BME] 向量预筛失败:", e); - return []; - } -} - -/** - * 实体锚点提取 - * 从用户消息中提取名词/实体,匹配图中的节点名称 - */ -function extractEntityAnchors(userMessage, activeNodes) { - const anchors = []; - const seen = new Set(); - - for (const node of activeNodes) { - const candidates = [node.fields?.name, node.fields?.title] - .filter((value) => typeof value === "string") - .map((value) => value.trim()) - .filter((value) => value.length >= 2); - - for (const candidate of candidates) { - if (!userMessage.includes(candidate)) continue; - const key = `${node.id}:${candidate}`; - if (seen.has(key)) continue; - seen.add(key); - anchors.push({ nodeId: node.id, entity: candidate }); - break; - } - } - - return anchors; -} - function buildResidualBasisNodes( graph, exactEntityAnchors, @@ -2463,7 +2168,9 @@ async function llmRecall( const fieldsStr = Object.entries(node.fields) .map(([k, v]) => `${k}: ${v}`) .join(", "); - const candidateKey = `R${index + 1}`; + const candidateKey = + nodeIdToCandidateKey[String(c?.nodeId || node?.id || "").trim()] || + `R${index + 1}`; return `[${candidateKey}] 类型=${typeLabel}, 作用域=${describeMemoryScope(node.scope)}, 时间=${storyTimeLabel || "未标注"}, 时间桶=${String(c.temporalBucket || STORY_TEMPORAL_BUCKETS.UNDATED)}, 召回桶=${describeScopeBucket(c.scopeBucket)}, 认知=${String(c.knowledgeMode || "unknown")}, 可见性=${(Number(c.knowledgeVisibilityScore) || 0).toFixed(3)}, ${fieldsStr} (评分=${(c.weightedScore ?? c.finalScore).toFixed(3)})`; }) .join("\n"); diff --git a/retrieval/shared-ranking.js b/retrieval/shared-ranking.js new file mode 100644 index 0000000..6e1fba3 --- /dev/null +++ b/retrieval/shared-ranking.js @@ -0,0 +1,752 @@ +import { buildTemporalAdjacencyMap, getActiveNodes, getNode } from "../graph/graph.js"; +import { findSimilarNodesByText, validateVectorConfig } from "../vector/vector-index.js"; +import { hybridScore } from "./dynamics.js"; +import { diffuseAndRank } from "./diffusion.js"; +import { mergeVectorResults, splitIntentSegments } from "./retrieval-enhancer.js"; + +function nowMs() { + if (typeof performance?.now === "function") { + return performance.now(); + } + return Date.now(); +} + +function roundMs(value) { + return Math.round((Number(value) || 0) * 10) / 10; +} + +export function clampPositiveInt(value, fallback, min = 1) { + const parsed = Math.floor(Number(value)); + return Number.isFinite(parsed) && parsed >= min ? parsed : fallback; +} + +export function clampRange(value, fallback, min = 0, max = 1) { + const parsed = Number(value); + if (!Number.isFinite(parsed)) return fallback; + return Math.max(min, Math.min(max, parsed)); +} + +export function normalizeQueryText(value, maxLength = 400) { + const normalized = String(value ?? "") + .replace(/\r\n/g, "\n") + .replace(/\s+/g, " ") + .trim(); + if (!normalized) return ""; + return normalized.slice(0, Math.max(1, maxLength)); +} + +export function createTextPreview(text, maxLength = 120) { + const normalized = normalizeQueryText(text, maxLength + 4); + if (!normalized) return ""; + return normalized.length > maxLength + ? `${normalized.slice(0, maxLength)}...` + : normalized; +} + +function uniqueStrings(values = [], maxLength = 400) { + const result = []; + const seen = new Set(); + + for (const value of values) { + const text = normalizeQueryText(value, maxLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + result.push(text); + } + + return result; +} + +function parseContextLine(line = "") { + const raw = String(line ?? "").trim(); + if (!raw) return null; + + const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); + if (bracketMatch) { + const role = String(bracketMatch[1] || "").toLowerCase(); + const text = normalizeQueryText(bracketMatch[2] || ""); + return text ? { role, text } : null; + } + + const plainMatch = raw.match(/^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i); + if (!plainMatch) return null; + + const roleToken = String(plainMatch[1] || "").toLowerCase(); + const role = + roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" + ? "assistant" + : "user"; + const text = normalizeQueryText(plainMatch[2] || ""); + return text ? { role, text } : null; +} + +export function buildContextQueryBlend( + userMessage, + recentMessages = [], + { + enabled = true, + assistantWeight = 0.2, + previousUserWeight = 0.1, + maxTextLength = 400, + } = {}, +) { + const currentText = normalizeQueryText(userMessage, maxTextLength); + const normalizedAssistantWeight = clampRange(assistantWeight, 0.2, 0, 1); + const normalizedPreviousUserWeight = clampRange( + previousUserWeight, + 0.1, + 0, + 1, + ); + const currentWeight = Math.max( + 0, + 1 - normalizedAssistantWeight - normalizedPreviousUserWeight, + ); + + let assistantText = ""; + let previousUserText = ""; + const parsedMessages = Array.isArray(recentMessages) + ? recentMessages.map((line) => parseContextLine(line)).filter(Boolean) + : []; + + for (let index = parsedMessages.length - 1; index >= 0; index -= 1) { + const item = parsedMessages[index]; + if (!assistantText && item.role === "assistant") { + assistantText = normalizeQueryText(item.text, maxTextLength); + } + if ( + !previousUserText && + item.role === "user" && + normalizeQueryText(item.text, maxTextLength).toLowerCase() !== + currentText.toLowerCase() + ) { + previousUserText = normalizeQueryText(item.text, maxTextLength); + } + if (assistantText && previousUserText) break; + } + + const rawParts = [ + { + kind: "currentUser", + label: "当前用户消息", + text: currentText, + weight: enabled ? currentWeight : 1, + }, + ]; + + if (enabled && assistantText) { + rawParts.push({ + kind: "assistantContext", + label: "最近 assistant 回复", + text: assistantText, + weight: normalizedAssistantWeight, + }); + } + + if (enabled && previousUserText) { + rawParts.push({ + kind: "previousUser", + label: "上一条 user 消息", + text: previousUserText, + weight: normalizedPreviousUserWeight, + }); + } + + const dedupedParts = []; + const seen = new Set(); + for (const part of rawParts) { + const text = normalizeQueryText(part.text, maxTextLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + dedupedParts.push({ + ...part, + text, + }); + } + + if (dedupedParts.length === 0) { + return { + active: false, + parts: [], + currentText: "", + assistantText: "", + previousUserText: "", + combinedText: "", + }; + } + + const totalWeight = dedupedParts.reduce( + (sum, part) => sum + Math.max(0, Number(part.weight) || 0), + 0, + ); + const normalizedParts = dedupedParts.map((part) => ({ + ...part, + weight: + totalWeight > 0 + ? Math.round( + ((Math.max(0, Number(part.weight) || 0) || 0) / totalWeight) * 1000, + ) / 1000 + : Math.round((1 / dedupedParts.length) * 1000) / 1000, + })); + const combinedText = + normalizedParts.length <= 1 + ? normalizedParts[0]?.text || "" + : normalizedParts + .map((part) => `${part.label}:\n${part.text}`) + .join("\n\n"); + + return { + active: enabled && normalizedParts.length > 1, + parts: normalizedParts, + currentText: currentText || normalizedParts[0]?.text || "", + assistantText, + previousUserText, + combinedText, + }; +} + +export function buildVectorQueryPlan( + blendPlan, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const plan = []; + let currentSegments = []; + + for (const part of blendPlan?.parts || []) { + let queries = [part.text]; + if (part.kind === "currentUser" && enableMultiIntent) { + currentSegments = splitIntentSegments(part.text, { maxSegments }); + queries = uniqueStrings([ + part.text, + ...currentSegments.filter((item) => item !== part.text), + ]); + } else { + queries = uniqueStrings([part.text]); + } + + plan.push({ + kind: part.kind, + label: part.label, + weight: part.weight, + queries, + }); + } + + return { + plan, + currentSegments, + }; +} + +export function buildLexicalQuerySources( + userMessage, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const currentText = normalizeQueryText(userMessage, 400); + const segments = enableMultiIntent + ? splitIntentSegments(currentText, { maxSegments }) + : []; + return { + sources: uniqueStrings([currentText, ...segments]), + segments, + }; +} + +function normalizeLexicalText(value = "") { + return normalizeQueryText(value, 600).toLowerCase(); +} + +function buildLexicalUnits(text = "") { + const normalized = normalizeLexicalText(text); + if (!normalized) return []; + + const rawTokens = normalized.match(/[a-z0-9]+|[\u4e00-\u9fff]+/g) || []; + const units = []; + + for (const token of rawTokens) { + if (token.length >= 2) { + units.push(token); + } + if (/[\u4e00-\u9fff]/.test(token) && token.length > 2) { + for (let index = 0; index < token.length - 1; index += 1) { + units.push(token.slice(index, index + 2)); + } + } + } + + return [...new Set(units)]; +} + +function computeTokenOverlapScore(sourceUnits = [], targetUnits = []) { + if (!sourceUnits.length || !targetUnits.length) return 0; + const targetSet = new Set(targetUnits); + let overlap = 0; + for (const unit of sourceUnits) { + if (targetSet.has(unit)) { + overlap += 1; + } + } + return overlap / Math.max(1, sourceUnits.length); +} + +function scoreFieldMatch( + fieldText, + querySources = [], + { exact = 1, includes = 0.9, overlap = 0.6 } = {}, +) { + const normalizedField = normalizeLexicalText(fieldText); + if (!normalizedField) return 0; + + const fieldUnits = buildLexicalUnits(normalizedField); + let best = 0; + + for (const sourceText of querySources) { + const normalizedSource = normalizeLexicalText(sourceText); + if (!normalizedSource) continue; + + if (normalizedSource === normalizedField) { + best = Math.max(best, exact); + continue; + } + + if ( + Math.min(normalizedSource.length, normalizedField.length) >= 2 && + (normalizedSource.includes(normalizedField) || + normalizedField.includes(normalizedSource)) + ) { + best = Math.max(best, includes); + } + + const overlapScore = computeTokenOverlapScore( + buildLexicalUnits(normalizedSource), + fieldUnits, + ); + best = Math.max(best, overlapScore * overlap); + } + + return Math.min(1, best); +} + +function collectNodeLexicalTexts(node, fieldNames = []) { + const values = []; + for (const fieldName of fieldNames) { + const value = node?.fields?.[fieldName]; + if (typeof value === "string" && value.trim()) { + values.push(value.trim()); + } else if (Array.isArray(value)) { + for (const item of value) { + if (typeof item === "string" && item.trim()) { + values.push(item.trim()); + } + } + } + } + return values; +} + +export function computeLexicalScore(node, querySources = []) { + if (!node || !Array.isArray(querySources) || querySources.length === 0) { + return 0; + } + + const primaryTexts = collectNodeLexicalTexts(node, ["name", "title"]); + const secondaryTexts = collectNodeLexicalTexts(node, [ + "summary", + "insight", + "state", + "traits", + "participants", + "status", + ]); + const combinedText = [...primaryTexts, ...secondaryTexts].join(" "); + + const primaryScore = primaryTexts.reduce( + (best, value) => + Math.max( + best, + scoreFieldMatch(value, querySources, { + exact: 1, + includes: 0.92, + overlap: 0.72, + }), + ), + 0, + ); + const secondaryScore = secondaryTexts.reduce( + (best, value) => + Math.max( + best, + scoreFieldMatch(value, querySources, { + exact: 0.82, + includes: 0.68, + overlap: 0.52, + }), + ), + 0, + ); + const tokenScore = scoreFieldMatch(combinedText, querySources, { + exact: 0.65, + includes: 0.55, + overlap: 0.45, + }); + + if (primaryScore <= 0 && secondaryScore <= 0 && tokenScore <= 0) { + return 0; + } + + return Math.min( + 1, + Math.max( + primaryScore, + secondaryScore * 0.82, + tokenScore * 0.7, + primaryScore * 0.75 + secondaryScore * 0.35 + tokenScore * 0.2, + ), + ); +} + +export function scaleVectorResults(results = [], weight = 1) { + return (Array.isArray(results) ? results : []).map((item) => ({ + ...item, + score: (Number(item?.score) || 0) * Math.max(0, Number(weight) || 0), + })); +} + +function isAbortError(error) { + return error?.name === "AbortError"; +} + +export async function vectorPreFilter( + graph, + userMessage, + activeNodes, + embeddingConfig, + topK, + signal, +) { + try { + return await findSimilarNodesByText( + graph, + userMessage, + embeddingConfig, + topK, + activeNodes, + signal, + ); + } catch (error) { + if (isAbortError(error)) { + throw error; + } + console.error("[ST-BME] 向量预筛失败:", error); + return []; + } +} + +export function extractEntityAnchors(userMessage, activeNodes) { + const anchors = []; + const seen = new Set(); + + for (const node of Array.isArray(activeNodes) ? activeNodes : []) { + const candidates = [node?.fields?.name, node?.fields?.title] + .filter((value) => typeof value === "string") + .map((value) => value.trim()) + .filter((value) => value.length >= 2); + + for (const candidate of candidates) { + if (!String(userMessage || "").includes(candidate)) continue; + const key = `${node.id}:${candidate}`; + if (seen.has(key)) continue; + seen.add(key); + anchors.push({ nodeId: node.id, entity: candidate }); + break; + } + } + + return anchors; +} + +function buildLexicalTopHits(scoredNodes = [], maxCount = 5) { + return scoredNodes + .filter((item) => (Number(item?.lexicalScore) || 0) > 0) + .sort((a, b) => { + const lexicalDelta = + (Number(b?.lexicalScore) || 0) - (Number(a?.lexicalScore) || 0); + if (lexicalDelta !== 0) return lexicalDelta; + return (Number(b?.finalScore) || 0) - (Number(a?.finalScore) || 0); + }) + .slice(0, Math.max(1, maxCount)) + .map((item) => ({ + nodeId: item.nodeId, + type: item.node?.type || "", + label: + item.node?.fields?.name || + item.node?.fields?.title || + item.node?.fields?.summary || + item.nodeId, + lexicalScore: Math.round((Number(item.lexicalScore) || 0) * 1000) / 1000, + finalScore: Math.round((Number(item.finalScore) || 0) * 1000) / 1000, + })); +} + +export async function rankNodesForTaskContext({ + graph, + userMessage, + recentMessages = [], + embeddingConfig, + signal = undefined, + options = {}, +} = {}) { + const topK = clampPositiveInt(options.topK, 20); + const diffusionTopK = clampPositiveInt(options.diffusionTopK, 100); + const enableVectorPrefilter = options.enableVectorPrefilter ?? true; + const enableGraphDiffusion = options.enableGraphDiffusion ?? true; + const enableContextQueryBlend = options.enableContextQueryBlend ?? true; + const enableMultiIntent = options.enableMultiIntent ?? true; + const multiIntentMaxSegments = clampPositiveInt( + options.multiIntentMaxSegments, + 4, + ); + const contextAssistantWeight = clampRange( + options.contextAssistantWeight, + 0.2, + 0, + 1, + ); + const contextPreviousUserWeight = clampRange( + options.contextPreviousUserWeight, + 0.1, + 0, + 1, + ); + const enableLexicalBoost = options.enableLexicalBoost ?? true; + const lexicalWeight = clampRange(options.lexicalWeight, 0.18, 0, 10); + const teleportAlpha = clampRange(options.teleportAlpha, 0.15); + const enableTemporalLinks = options.enableTemporalLinks ?? true; + const temporalLinkStrength = clampRange( + options.temporalLinkStrength, + 0.2, + 0, + 1, + ); + const maxTextLength = clampPositiveInt(options.maxTextLength, 400, 32); + const weights = options.weights ?? {}; + const activeNodes = Array.isArray(options.activeNodes) + ? options.activeNodes.filter((node) => node && !node.archived) + : getActiveNodes(graph).filter((node) => node && !node.archived); + const vectorValidation = validateVectorConfig(embeddingConfig); + const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, { + enabled: enableContextQueryBlend, + assistantWeight: contextAssistantWeight, + previousUserWeight: contextPreviousUserWeight, + maxTextLength, + }); + const queryPlan = buildVectorQueryPlan(contextQueryBlend, { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }); + const lexicalQuery = buildLexicalQuerySources( + contextQueryBlend.currentText || userMessage, + { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }, + ); + const diagnostics = { + queryBlendActive: contextQueryBlend.active, + queryBlendParts: (contextQueryBlend.parts || []).map((part) => ({ + kind: part.kind, + label: part.label, + weight: part.weight, + text: createTextPreview(part.text), + length: part.text.length, + })), + queryBlendWeights: Object.fromEntries( + (contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]), + ), + segmentsUsed: [...(queryPlan.currentSegments || [])], + vectorValidation, + vectorHits: 0, + vectorMergedHits: 0, + seedCount: 0, + diffusionHits: 0, + temporalSyntheticEdgeCount: 0, + teleportAlpha, + lexicalBoostedNodes: 0, + lexicalTopHits: [], + skipReasons: [], + timings: { + vector: 0, + diffusion: 0, + }, + }; + + if (!graph || activeNodes.length === 0) { + return { + activeNodes, + contextQueryBlend, + queryPlan, + lexicalQuery, + vectorResults: [], + exactEntityAnchors: [], + diffusionResults: [], + scoredNodes: [], + diagnostics, + }; + } + + let vectorResults = []; + const vectorStartedAt = nowMs(); + if (enableVectorPrefilter && vectorValidation.valid) { + const groups = []; + for (const part of queryPlan.plan) { + for (const queryText of part.queries) { + const results = await vectorPreFilter( + graph, + queryText, + activeNodes, + embeddingConfig, + topK, + signal, + ); + groups.push(scaleVectorResults(results, part.weight || 1)); + } + } + + const merged = mergeVectorResults(groups, Math.max(topK * 2, 24)); + diagnostics.vectorHits = merged.rawHitCount; + diagnostics.vectorMergedHits = merged.results.length; + vectorResults = merged.results; + } else if (enableVectorPrefilter) { + diagnostics.skipReasons.push("vector-config-invalid"); + } + diagnostics.timings.vector = roundMs(nowMs() - vectorStartedAt); + + const exactEntityAnchors = extractEntityAnchors( + contextQueryBlend.currentText || userMessage, + activeNodes, + ); + + let diffusionResults = []; + const diffusionStartedAt = nowMs(); + if (enableGraphDiffusion) { + const seeds = [ + ...vectorResults.map((item) => ({ id: item.nodeId, energy: item.score })), + ...exactEntityAnchors.map((item) => ({ id: item.nodeId, energy: 2.0 })), + ]; + const seedMap = new Map(); + for (const seed of seeds) { + const existing = seedMap.get(seed.id) || 0; + if (seed.energy > existing) { + seedMap.set(seed.id, seed.energy); + } + } + const uniqueSeeds = [...seedMap.entries()].map(([id, energy]) => ({ + id, + energy, + })); + diagnostics.seedCount = uniqueSeeds.length; + + if (uniqueSeeds.length > 0) { + const adjacencyMap = buildTemporalAdjacencyMap(graph, { + includeTemporalLinks: enableTemporalLinks, + temporalLinkStrength, + }); + diagnostics.temporalSyntheticEdgeCount = + Number(adjacencyMap?.syntheticEdgeCount) || 0; + diffusionResults = diffuseAndRank(adjacencyMap, uniqueSeeds, { + maxSteps: 2, + decayFactor: 0.6, + topK: diffusionTopK, + teleportAlpha, + }).filter((item) => { + const node = getNode(graph, item.nodeId); + return node && !node.archived; + }); + } + } + diagnostics.diffusionHits = diffusionResults.length; + diagnostics.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); + + const scoreMap = new Map(); + for (const item of vectorResults) { + const entry = scoreMap.get(item.nodeId) || { graphScore: 0, vectorScore: 0 }; + entry.vectorScore = item.score; + scoreMap.set(item.nodeId, entry); + } + for (const item of diffusionResults) { + const entry = scoreMap.get(item.nodeId) || { graphScore: 0, vectorScore: 0 }; + entry.graphScore = item.energy; + scoreMap.set(item.nodeId, entry); + } + if (scoreMap.size === 0) { + for (const node of activeNodes) { + if (!scoreMap.has(node.id)) { + scoreMap.set(node.id, { graphScore: 0, vectorScore: 0 }); + } + } + } + + const scoredNodes = []; + for (const [nodeId, scores] of scoreMap.entries()) { + const node = getNode(graph, nodeId); + if (!node || node.archived) continue; + const lexicalScore = enableLexicalBoost + ? computeLexicalScore(node, lexicalQuery.sources) + : 0; + const finalScore = hybridScore( + { + graphScore: scores.graphScore, + vectorScore: scores.vectorScore, + lexicalScore, + importance: node.importance, + createdTime: node.createdTime, + }, + { + ...weights, + lexicalWeight: enableLexicalBoost ? lexicalWeight : 0, + }, + ); + + scoredNodes.push({ + nodeId, + node, + graphScore: scores.graphScore, + vectorScore: scores.vectorScore, + lexicalScore, + finalScore, + weightedScore: finalScore, + }); + } + + scoredNodes.sort((left, right) => { + const weightedDelta = + (Number(right.weightedScore) || 0) - (Number(left.weightedScore) || 0); + if (weightedDelta !== 0) return weightedDelta; + const finalDelta = + (Number(right.finalScore) || 0) - (Number(left.finalScore) || 0); + if (finalDelta !== 0) return finalDelta; + const lexicalDelta = + (Number(right.lexicalScore) || 0) - (Number(left.lexicalScore) || 0); + if (lexicalDelta !== 0) return lexicalDelta; + return String(left.nodeId).localeCompare(String(right.nodeId)); + }); + + diagnostics.lexicalBoostedNodes = scoredNodes.filter( + (item) => (Number(item.lexicalScore) || 0) > 0, + ).length; + diagnostics.lexicalTopHits = buildLexicalTopHits(scoredNodes); + + return { + activeNodes, + contextQueryBlend, + queryPlan, + lexicalQuery, + vectorResults, + exactEntityAnchors, + diffusionResults, + scoredNodes, + diagnostics, + }; +} diff --git a/tests/extractor-phase3-layered-context.mjs b/tests/extractor-phase3-layered-context.mjs index 8eade4b..2b9709e 100644 --- a/tests/extractor-phase3-layered-context.mjs +++ b/tests/extractor-phase3-layered-context.mjs @@ -62,7 +62,7 @@ installResolveHooks([ }, ]); -const { createEmptyGraph, addNode, createNode } = await import("../graph/graph.js"); +const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import("../graph/graph.js"); const { DEFAULT_NODE_SCHEMA } = await import("../graph/schema.js"); const { extractMemories } = await import("../maintenance/extractor.js"); const { appendSummaryEntry } = await import("../graph/summary-state.js"); @@ -466,6 +466,97 @@ function collectAllPromptContent(captured) { } // ── Test 7: new settings exist in defaults ── +{ + const graph = createEmptyGraph(); + const confessionNode = addNode( + graph, + createNode({ + type: "event", + seq: 3, + importance: 8, + fields: { + title: "中文告白", + summary: "她认真地要求你再说一遍喜欢她。", + }, + }), + ); + const relationshipNode = addNode( + graph, + createNode({ + type: "thread", + seq: 4, + importance: 7, + fields: { + title: "感情升温", + summary: "两人的关系在这次告白后快速拉近。", + }, + }), + ); + addEdge( + graph, + createEdge({ + fromId: confessionNode.id, + toId: relationshipNode.id, + relation: "supports", + strength: 0.9, + }), + ); + + let captured = null; + const restore = setTestOverrides({ + llm: { + async callLLMForJSON(payload) { + captured = payload; + return { operations: [], cognitionUpdates: [], regionUpdates: {} }; + }, + }, + }); + + try { + const result = await extractMemories({ + graph, + messages: [ + { + seq: 10, + role: "user", + content: "中文告白之后,她还是很害羞。", + name: "玩家", + speaker: "玩家", + }, + { + seq: 11, + role: "assistant", + content: "这次中文告白让你们的感情升温了。", + name: "艾琳", + speaker: "艾琳", + }, + ], + startSeq: 10, + endSeq: 11, + schema: DEFAULT_NODE_SCHEMA, + embeddingConfig: null, + settings: { ...defaultSettings }, + }); + + assert.equal(result.success, true); + assert.ok(captured); + + const graphStatsBlock = (Array.isArray(captured.promptMessages) ? captured.promptMessages : []).find( + (m) => m.sourceKey === "graphStats", + ); + assert.ok(graphStatsBlock, "graphStats block should exist"); + const graphStatsContent = String(graphStatsBlock.content || ""); + assert.match(graphStatsContent, /### 图谱节点统计/); + assert.match(graphStatsContent, /事件: 1/); + assert.match(graphStatsContent, /主线: 1/); + assert.match(graphStatsContent, /\[G1\|事件\] 中文告白/); + assert.doesNotMatch(graphStatsContent, new RegExp(confessionNode.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"))); + } finally { + restore(); + } +} + +// ── Test 8: new settings exist in defaults ── { assert.equal(defaultSettings.extractRecentMessageCap, 0); assert.equal(defaultSettings.extractPromptStructuredMode, "both"); diff --git a/tests/prompt-node-references.mjs b/tests/prompt-node-references.mjs new file mode 100644 index 0000000..d9d0483 --- /dev/null +++ b/tests/prompt-node-references.mjs @@ -0,0 +1,54 @@ +import assert from "node:assert/strict"; + +const { + createPromptNodeReferenceMap, + getPromptNodeLabel, + resolvePromptNodeId, +} = await import("../prompting/prompt-node-references.js"); + +const rawNodeId = "550e8400-e29b-41d4-a716-446655440000"; +const map = createPromptNodeReferenceMap( + [ + { + nodeId: rawNodeId, + node: { + id: rawNodeId, + type: "event", + fields: { + title: "这是一个非常非常长的节点标题,用于测试提取提示里的标签截断行为", + }, + }, + score: 0.91, + }, + { + node: { + id: "node-2", + type: "thread", + fields: { + summary: "关系持续升温", + }, + }, + score: 0.77, + }, + ], + { + prefix: "G", + maxLength: 12, + buildMeta: ({ entry }) => ({ + score: entry.score, + }), + }, +); + +assert.deepEqual(Object.keys(map.keyToNodeId), ["G1", "G2"]); +assert.equal(map.keyToNodeId.G1, rawNodeId); +assert.equal(map.nodeIdToKey[rawNodeId], "G1"); +assert.equal(resolvePromptNodeId({ nodeId: rawNodeId }), rawNodeId); +assert.equal(resolvePromptNodeId({ node: { id: "node-2" } }), "node-2"); +assert.equal(getPromptNodeLabel({ id: "node-3", fields: { title: "短标题" } }), "短标题"); +assert.equal(map.keyToMeta.G1.score, 0.91); +assert.match(map.keyToMeta.G1.label, /^这是一个非常非常长的节…$/); +assert.equal(map.keyToMeta.G2.label, "关系持续升温"); +assert.equal(map.keyToMeta.G1.nodeId, rawNodeId); + +console.log("prompt-node-references tests passed"); diff --git a/tests/retrieval-config.mjs b/tests/retrieval-config.mjs index 1bcac6d..34bc20d 100644 --- a/tests/retrieval-config.mjs +++ b/tests/retrieval-config.mjs @@ -78,6 +78,485 @@ function createGraphHelpers(graph) { }; } +function getPromptNodeLabel(node = {}, { maxLength = 32 } = {}) { + const raw = String( + node?.fields?.title || + node?.fields?.name || + node?.fields?.summary || + node?.fields?.insight || + node?.fields?.belief || + node?.id || + "—", + ) + .replace(/\s+/g, " ") + .trim(); + if (!raw) return "—"; + if (!Number.isFinite(maxLength) || maxLength < 2 || raw.length <= maxLength) { + return raw; + } + return `${raw.slice(0, Math.max(1, maxLength - 1)).trimEnd()}…`; +} + +function createPromptNodeReferenceMap(entries = [], { prefix = "N", buildMeta = null } = {}) { + const keyToNodeId = {}; + const keyToMeta = {}; + const nodeIdToKey = {}; + const references = []; + for (const [index, entry] of (Array.isArray(entries) ? entries : []).entries()) { + const node = entry?.node || entry || {}; + const nodeId = String(entry?.nodeId || node?.id || "").trim(); + if (!nodeId || nodeIdToKey[nodeId]) continue; + const key = `${String(prefix || "N").trim() || "N"}${references.length + 1}`; + keyToNodeId[key] = nodeId; + nodeIdToKey[nodeId] = key; + keyToMeta[key] = { + nodeId, + type: String(node?.type || ""), + label: getPromptNodeLabel(node), + ...((typeof buildMeta === "function" + ? buildMeta({ entry, node, nodeId, key, index, label: getPromptNodeLabel(node) }) + : {}) || {}), + }; + references.push({ key, nodeId, node, meta: keyToMeta[key] }); + } + return { + keyToNodeId, + keyToMeta, + nodeIdToKey, + references, + }; +} + +function normalizeQueryText(value, maxLength = 400) { + const normalized = String(value ?? "") + .replace(/\r\n/g, "\n") + .replace(/\s+/g, " ") + .trim(); + if (!normalized) return ""; + return normalized.slice(0, Math.max(1, maxLength)); +} + +function splitIntentSegments(text, { maxSegments = 4, minLength = 1 } = {}) { + const raw = String(text || "").trim(); + if (!raw) return []; + const segments = raw + .split(/[,,。.;;!!??\n]+|(?:和|顺便|另外|还有|对了|然后|而且|并且|同时)/) + .map((item) => item.trim()) + .filter((item) => item.length >= minLength); + return uniqueStrings(segments).slice(0, Math.max(1, maxSegments)); +} + +function uniqueStrings(values = [], maxLength = 400) { + const result = []; + const seen = new Set(); + for (const value of values) { + const text = normalizeQueryText(value, maxLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + result.push(text); + } + return result; +} + +function mergeVectorResults(groups, limit) { + const merged = new Map(); + let rawHitCount = 0; + for (const group of groups) { + for (const item of group) { + rawHitCount += 1; + const existing = merged.get(item.nodeId); + if (!existing || item.score > existing.score) { + merged.set(item.nodeId, item); + } + } + } + return { + rawHitCount, + results: [...merged.values()].slice(0, limit), + }; +} + +function parseContextLine(line = "") { + const raw = String(line ?? "").trim(); + if (!raw) return null; + const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); + if (bracketMatch) { + const role = String(bracketMatch[1] || "").toLowerCase(); + const text = normalizeQueryText(bracketMatch[2] || ""); + return text ? { role, text } : null; + } + const plainMatch = raw.match(/^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i); + if (!plainMatch) return null; + const roleToken = String(plainMatch[1] || "").toLowerCase(); + const role = + roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" + ? "assistant" + : "user"; + const text = normalizeQueryText(plainMatch[2] || ""); + return text ? { role, text } : null; +} + +function buildContextQueryBlend( + userMessage, + recentMessages = [], + { + enabled = true, + assistantWeight = 0.2, + previousUserWeight = 0.1, + maxTextLength = 400, + } = {}, +) { + const currentText = normalizeQueryText(userMessage, maxTextLength); + let assistantText = ""; + let previousUserText = ""; + const parsedMessages = Array.isArray(recentMessages) + ? recentMessages.map((line) => parseContextLine(line)).filter(Boolean) + : []; + + for (let index = parsedMessages.length - 1; index >= 0; index -= 1) { + const item = parsedMessages[index]; + if (!assistantText && item.role === "assistant") { + assistantText = normalizeQueryText(item.text, maxTextLength); + } + if ( + !previousUserText && + item.role === "user" && + normalizeQueryText(item.text, maxTextLength).toLowerCase() !== + currentText.toLowerCase() + ) { + previousUserText = normalizeQueryText(item.text, maxTextLength); + } + if (assistantText && previousUserText) break; + } + + const currentWeight = Math.max( + 0, + 1 - Number(assistantWeight || 0) - Number(previousUserWeight || 0), + ); + const rawParts = [ + { + kind: "currentUser", + label: "当前用户消息", + text: currentText, + weight: enabled ? currentWeight : 1, + }, + ]; + if (enabled && assistantText) { + rawParts.push({ + kind: "assistantContext", + label: "最近 assistant 回复", + text: assistantText, + weight: Number(assistantWeight || 0), + }); + } + if (enabled && previousUserText) { + rawParts.push({ + kind: "previousUser", + label: "上一条 user 消息", + text: previousUserText, + weight: Number(previousUserWeight || 0), + }); + } + + const dedupedParts = []; + const seen = new Set(); + for (const part of rawParts) { + const text = normalizeQueryText(part.text, maxTextLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + dedupedParts.push({ ...part, text }); + } + + const totalWeight = dedupedParts.reduce( + (sum, part) => sum + Math.max(0, Number(part.weight) || 0), + 0, + ); + const parts = dedupedParts.map((part) => ({ + ...part, + weight: + totalWeight > 0 + ? Math.round((Math.max(0, Number(part.weight) || 0) / totalWeight) * 1000) / + 1000 + : Math.round((1 / Math.max(1, dedupedParts.length)) * 1000) / 1000, + })); + + return { + active: enabled && parts.length > 1, + parts, + currentText: currentText || parts[0]?.text || "", + assistantText, + previousUserText, + combinedText: + parts.length <= 1 + ? parts[0]?.text || "" + : parts.map((part) => `${part.label}:\n${part.text}`).join("\n\n"), + }; +} + +function buildVectorQueryPlan( + blendPlan, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const plan = []; + let currentSegments = []; + for (const part of blendPlan?.parts || []) { + let queries = [part.text]; + if (part.kind === "currentUser" && enableMultiIntent) { + currentSegments = splitIntentSegments(part.text, { maxSegments }); + queries = uniqueStrings([ + part.text, + ...currentSegments.filter((item) => item !== part.text), + ]); + } else { + queries = uniqueStrings([part.text]); + } + plan.push({ + kind: part.kind, + label: part.label, + weight: part.weight, + queries, + }); + } + return { + plan, + currentSegments, + }; +} + +function buildLexicalQuerySources( + userMessage, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const currentText = normalizeQueryText(userMessage, 400); + const segments = enableMultiIntent + ? splitIntentSegments(currentText, { maxSegments }) + : []; + return { + sources: uniqueStrings([currentText, ...segments]), + segments, + }; +} + +function computeLexicalScoreForShared(node, querySources = []) { + const haystack = String( + node?.fields?.name || node?.fields?.title || node?.fields?.summary || "", + ).toLowerCase(); + if (!haystack) return 0; + for (const sourceText of querySources) { + const normalizedSource = String(sourceText || "").toLowerCase(); + if (normalizedSource && haystack.includes(normalizedSource.split(/\s+/)[0])) { + return 1; + } + } + return 0; +} + +function extractEntityAnchors(userMessage, activeNodes = []) { + const anchors = []; + const seen = new Set(); + for (const node of activeNodes) { + const candidates = [node?.fields?.name, node?.fields?.title] + .filter((value) => typeof value === "string") + .map((value) => value.trim()) + .filter((value) => value.length >= 2); + for (const candidate of candidates) { + if (!String(userMessage || "").includes(candidate)) continue; + const key = `${node.id}:${candidate}`; + if (seen.has(key)) continue; + seen.add(key); + anchors.push({ nodeId: node.id, entity: candidate }); + break; + } + } + return anchors; +} + +async function rankNodesForTaskContext({ + graph, + userMessage, + recentMessages = [], + embeddingConfig, + options = {}, +} = {}) { + const activeNodes = Array.isArray(options.activeNodes) + ? options.activeNodes.filter((node) => node && !node.archived) + : (graph?.nodes || []).filter((node) => node && !node.archived); + const topK = Math.max(1, Math.floor(Number(options.topK) || 20)); + const diffusionTopK = Math.max(1, Math.floor(Number(options.diffusionTopK) || 100)); + const enableVectorPrefilter = options.enableVectorPrefilter ?? true; + const enableGraphDiffusion = options.enableGraphDiffusion ?? true; + const enableContextQueryBlend = options.enableContextQueryBlend ?? true; + const enableMultiIntent = options.enableMultiIntent ?? true; + const multiIntentMaxSegments = Math.max( + 1, + Math.floor(Number(options.multiIntentMaxSegments) || 4), + ); + const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, { + enabled: enableContextQueryBlend, + assistantWeight: Number(options.contextAssistantWeight ?? 0.2), + previousUserWeight: Number(options.contextPreviousUserWeight ?? 0.1), + maxTextLength: Number(options.maxTextLength || 400), + }); + const queryPlan = buildVectorQueryPlan(contextQueryBlend, { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }); + const lexicalQuery = buildLexicalQuerySources( + contextQueryBlend.currentText || userMessage, + { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }, + ); + const diagnostics = { + queryBlendActive: contextQueryBlend.active, + queryBlendParts: (contextQueryBlend.parts || []).map((part) => ({ + kind: part.kind, + label: part.label, + weight: part.weight, + text: part.text, + length: part.text.length, + })), + queryBlendWeights: Object.fromEntries( + (contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]), + ), + segmentsUsed: [...(queryPlan.currentSegments || [])], + vectorValidation: { valid: true }, + vectorHits: 0, + vectorMergedHits: 0, + seedCount: 0, + diffusionHits: 0, + temporalSyntheticEdgeCount: 0, + teleportAlpha: Number(options.teleportAlpha ?? 0.15) || 0.15, + lexicalBoostedNodes: 0, + lexicalTopHits: [], + skipReasons: [], + timings: { vector: 0, diffusion: 0 }, + }; + + let vectorResults = []; + if (enableVectorPrefilter) { + const groups = []; + for (const part of queryPlan.plan) { + for (const queryText of part.queries) { + state.vectorCalls.push({ topK, message: queryText }); + const results = [ + { nodeId: "rule-1", score: 0.9 }, + { nodeId: "rule-2", score: 0.8 }, + { nodeId: "rule-3", score: 0.7 }, + ].map((item) => ({ + ...item, + score: item.score * Math.max(0, Number(part.weight) || 0), + })); + groups.push(results); + } + } + const merged = mergeVectorResults(groups, Math.max(topK * 2, 24)); + diagnostics.vectorHits = merged.rawHitCount; + diagnostics.vectorMergedHits = merged.results.length; + vectorResults = merged.results; + } + + const exactEntityAnchors = extractEntityAnchors( + contextQueryBlend.currentText || userMessage, + activeNodes, + ); + let diffusionResults = []; + if (enableGraphDiffusion) { + const seedMap = new Map(); + for (const item of vectorResults) { + seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, item.score)); + } + for (const item of exactEntityAnchors) { + seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, 2.0)); + } + const uniqueSeeds = [...seedMap.entries()].map(([id, energy]) => ({ id, energy })); + diagnostics.seedCount = uniqueSeeds.length; + if (uniqueSeeds.length > 0) { + state.diffusionCalls.push({ + seeds: uniqueSeeds, + options: { + maxSteps: 2, + decayFactor: 0.6, + topK: diffusionTopK, + teleportAlpha: diagnostics.teleportAlpha, + }, + }); + diffusionResults = [ + { nodeId: "rule-2", energy: 1.2 }, + { nodeId: "rule-3", energy: 0.9 }, + ]; + } + } + diagnostics.diffusionHits = diffusionResults.length; + + const scoreMap = new Map(); + for (const item of vectorResults) { + scoreMap.set(item.nodeId, { + graphScore: scoreMap.get(item.nodeId)?.graphScore || 0, + vectorScore: item.score, + }); + } + for (const item of diffusionResults) { + scoreMap.set(item.nodeId, { + graphScore: item.energy, + vectorScore: scoreMap.get(item.nodeId)?.vectorScore || 0, + }); + } + if (scoreMap.size === 0) { + for (const node of activeNodes) { + scoreMap.set(node.id, { graphScore: 0, vectorScore: 0 }); + } + } + const scoredNodes = [...scoreMap.entries()].map(([nodeId, scores]) => { + const node = activeNodes.find((item) => item.id === nodeId) || null; + const lexicalScore = computeLexicalScoreForShared(node, lexicalQuery.sources); + return { + nodeId, + node, + graphScore: scores.graphScore, + vectorScore: scores.vectorScore, + lexicalScore, + finalScore: + Number(scores.graphScore || 0) + + Number(scores.vectorScore || 0) + + Number(lexicalScore || 0) + + Number(node?.importance || 0), + weightedScore: + Number(scores.graphScore || 0) + + Number(scores.vectorScore || 0) + + Number(lexicalScore || 0) + + Number(node?.importance || 0), + }; + }); + diagnostics.lexicalBoostedNodes = scoredNodes.filter( + (item) => (Number(item.lexicalScore) || 0) > 0, + ).length; + diagnostics.lexicalTopHits = scoredNodes + .filter((item) => (Number(item.lexicalScore) || 0) > 0) + .slice(0, 5) + .map((item) => ({ + nodeId: item.nodeId, + label: item.node?.fields?.name || item.node?.fields?.title || item.nodeId, + lexicalScore: item.lexicalScore, + finalScore: item.finalScore, + })); + + return { + activeNodes, + contextQueryBlend, + queryPlan, + lexicalQuery, + vectorResults, + exactEntityAnchors, + diffusionResults, + scoredNodes, + diagnostics, + }; +} + const schema = [{ id: "rule", label: "规则", alwaysInject: false }]; const state = { @@ -93,6 +572,9 @@ const graph = createGraph(); const helpers = createGraphHelpers(graph); const retrieve = await loadRetrieve({ ...helpers, + createPromptNodeReferenceMap, + getPromptNodeLabel, + rankNodesForTaskContext, STORY_TEMPORAL_BUCKETS: { CURRENT: "current", ADJACENT_PAST: "adjacentPast", diff --git a/tests/shared-ranking.mjs b/tests/shared-ranking.mjs new file mode 100644 index 0000000..3781175 --- /dev/null +++ b/tests/shared-ranking.mjs @@ -0,0 +1,175 @@ +import assert from "node:assert/strict"; +import { + installResolveHooks, + toDataModuleUrl, +} from "./helpers/register-hooks-compat.mjs"; + +const extensionsShimSource = [ + "export const extension_settings = {};", + "export function getContext() {", + " return {", + " chat: [],", + " chatMetadata: {},", + " extensionSettings: {},", + " powerUserSettings: {},", + " characters: {},", + " characterId: null,", + " name1: '玩家',", + " name2: '艾琳',", + " chatId: 'test-chat',", + " };", + "}", +].join("\n"); + +const scriptShimSource = [ + "export function getRequestHeaders() {", + " return {};", + "}", + "export function substituteParamsExtended(value) {", + " return String(value ?? '');", + "}", +].join("\n"); + +installResolveHooks([ + { + specifiers: [ + "../../../extensions.js", + "../../../../extensions.js", + "../../../../../extensions.js", + ], + url: toDataModuleUrl(extensionsShimSource), + }, + { + specifiers: [ + "../../../../script.js", + "../../../../../script.js", + ], + url: toDataModuleUrl(scriptShimSource), + }, +]); + +const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import( + "../graph/graph.js" +); +const { rankNodesForTaskContext } = await import("../retrieval/shared-ranking.js"); + +function setTestOverrides(overrides = {}) { + globalThis.__stBmeTestOverrides = overrides; + return () => { + delete globalThis.__stBmeTestOverrides; + }; +} + +const graph = createEmptyGraph(); +const confession = addNode( + graph, + createNode({ + type: "event", + seq: 10, + importance: 8, + fields: { + title: "中文告白", + summary: "她认真地说喜欢你,并要求你再说一遍。", + }, + }), +); +const dateEvent = addNode( + graph, + createNode({ + type: "event", + seq: 11, + importance: 4, + fields: { + title: "节日约会", + summary: "她们一起逛街吃饭。", + }, + }), +); +const relationship = addNode( + graph, + createNode({ + type: "thread", + seq: 12, + importance: 7, + fields: { + title: "感情升温", + summary: "两人的恋爱关系快速升温。", + }, + }), +); +confession.embedding = [1, 0.3, 0.1]; +dateEvent.embedding = [0.2, 0.9, 0.1]; +relationship.embedding = [0.8, 0.6, 0.2]; +addEdge( + graph, + createEdge({ + fromId: confession.id, + toId: relationship.id, + relation: "supports", + strength: 0.9, + }), +); + +const graphBefore = JSON.stringify(graph); +const restore = setTestOverrides({ + embedding: { + async embedText() { + return [1, 0.5, 0.25]; + }, + searchSimilar(_queryVec, candidates) { + assert.ok(candidates.some((item) => item.nodeId === confession.id)); + return [ + { nodeId: confession.id, score: 0.97 }, + { nodeId: dateEvent.id, score: 0.23 }, + ]; + }, + }, +}); + +try { + const config = { + mode: "direct", + source: "direct", + apiUrl: "https://example.com/v1", + apiKey: "", + model: "test-embedding", + }; + const first = await rankNodesForTaskContext({ + graph, + userMessage: "[user]: 中文告白后的关系进展", + embeddingConfig: config, + options: { + enableContextQueryBlend: false, + topK: 8, + diffusionTopK: 16, + }, + }); + const second = await rankNodesForTaskContext({ + graph, + userMessage: "[user]: 中文告白后的关系进展", + embeddingConfig: config, + options: { + enableContextQueryBlend: false, + topK: 8, + diffusionTopK: 16, + }, + }); + + assert.equal(JSON.stringify(graph), graphBefore, "shared ranking should be side-effect-free"); + assert.equal(first.scoredNodes[0]?.nodeId, confession.id); + assert.equal(second.scoredNodes[0]?.nodeId, confession.id); + assert.deepEqual( + first.scoredNodes.map((item) => item.nodeId), + second.scoredNodes.map((item) => item.nodeId), + "ranking order should stay deterministic under fixed inputs", + ); + const propagated = first.scoredNodes.find((item) => item.nodeId === relationship.id); + assert.ok(propagated, "diffusion should surface connected relationship node"); + assert.ok((Number(propagated?.graphScore) || 0) > 0, "connected node should receive graph diffusion score"); + assert.equal(first.diagnostics.vectorMergedHits, 2); + assert.ok(first.diagnostics.diffusionHits >= 1); +} finally { + restore(); +} + +console.log("shared-ranking tests passed");