diff --git a/index.js b/index.js index 442e56c..9d00f88 100644 --- a/index.js +++ b/index.js @@ -11126,6 +11126,8 @@ async function handleExtractionSuccess( await generateReflection({ graph: currentGraph, currentSeq: endIdx, + schema: getSchema(), + embeddingConfig: getEmbeddingConfig(), settings, signal, }); diff --git a/maintenance/chat-history.js b/maintenance/chat-history.js index 4595e9c..a0e2a14 100644 --- a/maintenance/chat-history.js +++ b/maintenance/chat-history.js @@ -286,6 +286,7 @@ export function buildExtractionMessages(chat, startIdx, endIdx, settings) { rawContent: String(msg?.mes ?? ""), name: String(msg?.name ?? "").trim(), speaker: String(msg?.name ?? "").trim(), + isContextOnly: index < startIdx, }); } diff --git a/maintenance/compressor.js b/maintenance/compressor.js index 073a2d2..8adc669 100644 --- a/maintenance/compressor.js +++ b/maintenance/compressor.js @@ -30,6 +30,7 @@ import { } from "../prompting/prompt-builder.js"; import { getSTContextForPrompt } from "../host/st-context.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; +import { buildTaskGraphStats } from "./task-graph-stats.js"; import { isDirectVectorConfig } from "../vector/vector-index.js"; function createAbortError(message = "操作已终止") { @@ -109,6 +110,30 @@ function normalizeCompressionFieldValue(value) { return String(value).trim(); } +function buildCompressionRankingQueryText(nodes = [], typeDef = {}) { + const typeLabel = String(typeDef?.label || typeDef?.id || "节点").trim() || "节点"; + const lines = (Array.isArray(nodes) ? nodes : []) + .map((node, index) => { + const fieldsText = Object.entries(node?.fields || {}) + .map(([key, value]) => { + const normalizedValue = normalizeCompressionFieldValue(value); + return normalizedValue ? `${key}: ${normalizedValue}` : ""; + }) + .filter(Boolean) + .join(" | "); + const storyTimeLabel = describeNodeStoryTime(node); + return [ + `${typeLabel}#${index + 1}`, + storyTimeLabel ? 
`剧情时间=${storyTimeLabel}` : "", + fieldsText, + ] + .filter(Boolean) + .join(" | "); + }) + .filter(Boolean); + return lines.length > 0 ? [`压缩批次 ${typeLabel}`, ...lines].join("\n") : ""; +} + function buildCompressionFallbackSummary(batch = []) { return batch .map((node) => @@ -187,6 +212,7 @@ export async function compressType({ graph, typeDef, embeddingConfig, + schema = [], force = false, customPrompt, signal, @@ -211,6 +237,7 @@ export async function compressType({ typeDef, level, embeddingConfig, + schema, force, customPrompt, signal, @@ -235,6 +262,7 @@ async function compressLevel({ typeDef, level, embeddingConfig, + schema = [], force, customPrompt, signal, @@ -271,6 +299,9 @@ async function compressLevel({ const summaryResult = await summarizeBatch( batch, typeDef, + graph, + embeddingConfig, + schema, customPrompt, signal, settings, @@ -476,6 +507,9 @@ function migrateBatchEdges(graph, batch, compressedNode) { async function summarizeBatch( nodes, typeDef, + graph, + embeddingConfig, + schema = [], customPrompt, signal, settings = {}, @@ -493,13 +527,35 @@ async function summarizeBatch( const instruction = typeDef.compression.instruction || "将以下节点压缩总结为一条精炼记录。"; + const excludedNodeIds = new Set( + (Array.isArray(nodes) ? nodes : []).map((node) => String(node?.id || "").trim()), + ); + const compressionGraphStats = await buildTaskGraphStats({ + graph, + schema: Array.isArray(schema) && schema.length > 0 ? 
schema : [typeDef], + userMessage: buildCompressionRankingQueryText(nodes, typeDef), + recentMessages: [], + embeddingConfig, + signal, + activeNodes: getActiveNodes(graph).filter( + (node) => !excludedNodeIds.has(String(node?.id || "").trim()), + ), + rankingOptions: { + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + }, + relevantHeading: "与当前压缩批次最相关的既有节点", + }); const compressPromptBuild = await buildTaskPrompt(settings, "compress", { taskName: "compress", nodeContent: nodeDescriptions, candidateNodes: nodeDescriptions, currentRange: `${nodes[0]?.seq ?? "?"} ~ ${nodes[nodes.length - 1]?.seq ?? "?"}`, - graphStats: `node_count=${nodes.length}, node_type=${typeDef.id}`, + graphStats: compressionGraphStats.graphStats, ...getSTContextForPrompt(), }); const compressRegexInput = { entries: [] }; @@ -581,6 +637,7 @@ export async function compressAll( graph, typeDef, embeddingConfig, + schema, force, customPrompt, signal, diff --git a/maintenance/consolidator.js b/maintenance/consolidator.js index a965303..e450207 100644 --- a/maintenance/consolidator.js +++ b/maintenance/consolidator.js @@ -22,6 +22,7 @@ import { } from "../prompting/prompt-builder.js"; import { getSTContextForPrompt } from "../host/st-context.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; +import { buildTaskGraphStats } from "./task-graph-stats.js"; import { buildNodeVectorText, findSimilarNodesByText, @@ -132,6 +133,27 @@ function canMergeTemporalScopedMemories(leftNode, rightNode) { return isStoryTimeCompatible(leftNode, rightNode).compatible; } +function buildConsolidationRankingQueryText(newEntries = []) { + return (Array.isArray(newEntries) ? 
newEntries : []) + .map((entry, index) => { + const node = entry?.node; + const fieldsText = Object.entries(node?.fields || {}) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + const storyTimeLabel = describeNodeStoryTime(node); + return [ + `新记忆#${index + 1}`, + `类型=${String(node?.type || "").trim()}`, + storyTimeLabel ? `剧情时间=${storyTimeLabel}` : "", + fieldsText, + ] + .filter(Boolean) + .join(" | "); + }) + .filter(Boolean) + .join("\n"); +} + export async function analyzeAutoConsolidationGate({ graph, newNodeIds, @@ -297,6 +319,7 @@ export async function consolidateMemories({ graph, newNodeIds, embeddingConfig, + schema = [], options = {}, customPrompt, signal, @@ -491,13 +514,33 @@ export async function consolidateMemories({ } const userPrompt = userPromptSections.join("\n\n"); + const newNodeIdSet = new Set(newEntries.map((entry) => String(entry?.id || "").trim())); + const consolidationGraphStats = await buildTaskGraphStats({ + graph, + schema, + userMessage: buildConsolidationRankingQueryText(newEntries), + recentMessages: [], + embeddingConfig, + signal, + activeNodes: activeNodes.filter( + (node) => !newNodeIdSet.has(String(node?.id || "").trim()), + ), + rankingOptions: { + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + }, + relevantHeading: "与本轮整合最相关的既有节点", + }); let decision; const consolidationPromptBuild = await buildTaskPrompt(settings, "consolidation", { taskName: "consolidation", candidateNodes: userPrompt, candidateText: userPrompt, - graphStats: `new_entries=${newEntries.length}, threshold=${conflictThreshold}`, + graphStats: consolidationGraphStats.graphStats, ...getSTContextForPrompt(), }); const consolidationRegexInput = { entries: [] }; diff --git a/maintenance/extraction-context.js b/maintenance/extraction-context.js index 201f05d..15eecec 100644 --- a/maintenance/extraction-context.js +++ b/maintenance/extraction-context.js @@ -289,6 +289,32 @@ function 
resolveSpeakerName(message = {}, role = "assistant", names = {}) { return role || "assistant"; } +function shouldHideSpeakerLabel(message = {}, role = "assistant", names = {}) { + if (message?.hideSpeakerLabel === true) { + return true; + } + if (message?.hideSpeakerLabel === false) { + return false; + } + if (role !== "assistant") { + return false; + } + if (String(message?.source || "").trim() === "worldInfo-atDepth") { + return false; + } + const explicitSpeaker = String( + message?.speaker ?? message?.name ?? message?.displayName ?? "", + ).trim(); + if (!explicitSpeaker) { + return true; + } + const activeCharName = String(names?.charName || "").trim(); + if (!activeCharName) { + return false; + } + return explicitSpeaker === activeCharName; +} + function normalizeExtractionMessage(message = {}, index = 0, names = {}) { const role = normalizeRole( message?.role ?? (message?.is_user === true ? "user" : "assistant"), @@ -296,6 +322,7 @@ function normalizeExtractionMessage(message = {}, index = 0, names = {}) { const content = String(resolveMessageContent(message) || "").trim(); const rawContent = String(resolveMessageRawContent(message) || content).trim(); const speaker = resolveSpeakerName(message, role, names); + const hideSpeakerLabel = shouldHideSpeakerLabel(message, role, names); const seq = Number.isFinite(Number(message?.seq)) ? Number(message.seq) : null; return { @@ -304,9 +331,11 @@ function normalizeExtractionMessage(message = {}, index = 0, names = {}) { role, speaker, name: speaker, + hideSpeakerLabel, content, rawContent, sourceType: role === "user" ? "user_input" : "ai_output", + isContextOnly: message?.isContextOnly === true, }; } @@ -322,18 +351,39 @@ function countRoles(messages = []) { } export function formatExtractionTranscript(messages = []) { - return (Array.isArray(messages) ? messages : []) - .map((message, index) => { - const seqLabel = Number.isFinite(Number(message?.seq)) - ? 
`#${Number(message.seq)}` - : `#${index + 1}`; - const role = normalizeRole(message?.role || "assistant"); - const speaker = String(message?.speaker || message?.name || "").trim(); - const speakerLabel = speaker ? `|${speaker}` : ""; - return `${seqLabel} [${role}${speakerLabel}]: ${String(message?.content || "")}`; - }) - .filter((item) => String(item || "").trim()) - .join("\n\n"); + const safeMessages = Array.isArray(messages) ? messages : []; + const hasContextMessages = safeMessages.some((m) => m?.isContextOnly === true); + const hasTargetMessages = safeMessages.some((m) => m?.isContextOnly !== true); + const lines = []; + let inContext = null; + + for (let index = 0; index < safeMessages.length; index += 1) { + const message = safeMessages[index]; + const isContext = message?.isContextOnly === true; + + if (hasContextMessages && hasTargetMessages && isContext !== inContext) { + if (isContext) { + lines.push("--- 以下是上下文回顾(已提取过),仅供理解剧情 ---"); + } else { + lines.push("--- 以下是本次需要提取记忆的新对话内容 ---"); + } + inContext = isContext; + } + + const seqLabel = Number.isFinite(Number(message?.seq)) + ? `#${Number(message.seq)}` + : `#${index + 1}`; + const role = normalizeRole(message?.role || "assistant"); + const speaker = String(message?.speaker || message?.name || "").trim(); + const speakerLabel = + message?.hideSpeakerLabel === true || !speaker ? 
"" : `|${speaker}`; + const line = `${seqLabel} [${role}${speakerLabel}]: ${String(message?.content || "")}`; + if (String(line || "").trim()) { + lines.push(line); + } + } + + return lines.join("\n\n"); } export function buildExtractionInputContext( diff --git a/maintenance/extractor.js b/maintenance/extractor.js index efafd7b..6186755 100644 --- a/maintenance/extractor.js +++ b/maintenance/extractor.js @@ -16,7 +16,7 @@ import { updateNode, } from "../graph/graph.js"; import { callLLMForJSON } from "../llm/llm.js"; -import { ensureEventTitle, getNodeDisplayName } from "../graph/node-labels.js"; +import { ensureEventTitle } from "../graph/node-labels.js"; import { normalizeMemoryScope, isObjectiveScope, @@ -45,6 +45,7 @@ import { RELATION_TYPES } from "../graph/schema.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; import { getSTContextForPrompt, getSTContextSnapshot } from "../host/st-context.js"; import { buildExtractionInputContext } from "./extraction-context.js"; +import { buildTaskGraphStats } from "./task-graph-stats.js"; import { aliasSetMatchesValue, buildUserPovAliasNormalizedSet, @@ -148,6 +149,46 @@ function resolveExtractPromptStructuredMode(settings) { return "both"; } +function formatExtractRankingMessage(message = {}) { + const role = String(message?.role || "assistant").trim().toLowerCase() === "user" + ? "user" + : "assistant"; + const content = String(message?.content || "").trim(); + if (!content) return ""; + return `[${role}]: ${content}`; +} + +function buildExtractRankingQueryText(messages = []) { + const normalizedMessages = Array.isArray(messages) ? 
messages : []; + const targetLines = normalizedMessages + .filter((message) => message?.isContextOnly !== true) + .map((message) => formatExtractRankingMessage(message)) + .filter(Boolean); + if (targetLines.length > 0) { + return targetLines.join("\n"); + } + return normalizedMessages + .map((message) => formatExtractRankingMessage(message)) + .filter(Boolean) + .join("\n"); +} + +function buildReflectionRankingQueryText({ + eventSummary = "", + characterSummary = "", + threadSummary = "", + contradictionSummary = "", +} = {}) { + return [ + eventSummary ? `最近事件:\n${eventSummary}` : "", + characterSummary ? `近期角色状态:\n${characterSummary}` : "", + threadSummary ? `当前主线:\n${threadSummary}` : "", + contradictionSummary ? `已知矛盾:\n${contradictionSummary}` : "", + ] + .filter(Boolean) + .join("\n\n"); +} + function isAbortError(error) { return error?.name === "AbortError"; } @@ -873,6 +914,8 @@ export async function extractMemories({ content: message?.content, speaker: message?.speaker, name: message?.name, + hideSpeakerLabel: message?.hideSpeakerLabel === true, + isContextOnly: message?.isContextOnly === true, })) : []; @@ -898,8 +941,26 @@ export async function extractMemories({ ? 
dialogueText : structuredMessages; - // 构建当前图概览(让 LLM 知道已有哪些节点,避免重复) - const graphOverview = buildGraphOverview(graph, schema); + const extractGraphRankingQuery = buildExtractRankingQueryText(structuredMessages); + const extractGraphStats = await buildTaskGraphStats({ + graph, + schema, + userMessage: extractGraphRankingQuery, + recentMessages: [], + embeddingConfig, + signal, + rankingOptions: { + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + }, + relevantHeading: "与当前提取片段最相关的既有节点", + }); + const extractGraphRanking = extractGraphStats.ranking; + const extractGraphRelevantNodes = extractGraphStats.relevantReferenceMap; + const graphOverview = extractGraphStats.graphStats; // 构建 Schema 描述 const schemaDescription = buildSchemaDescription(schema); @@ -922,6 +983,14 @@ export async function extractMemories({ `storyTimeContext=${storyTimeContext ? "present" : "none"}, ` + `worldbookMode=${String(settings?.extractWorldbookMode || "active")}`, ); + if (extractGraphRanking) { + debugLog( + `[ST-BME][extract-graph] relevantNodes=${extractGraphRelevantNodes.references.length}, ` + + `vectorMergedHits=${Number(extractGraphRanking?.diagnostics?.vectorMergedHits || 0)}, ` + + `diffusionHits=${Number(extractGraphRanking?.diagnostics?.diffusionHits || 0)}, ` + + `lexicalBoostedNodes=${Number(extractGraphRanking?.diagnostics?.lexicalBoostedNodes || 0)}`, + ); + } const extractWorldbookMode = String(settings?.extractWorldbookMode || "active").trim().toLowerCase(); const promptBuild = await buildTaskPrompt(settings, "extract", { @@ -957,15 +1026,39 @@ export async function extractMemories({ // 用户提示词 — Phase 3 分层信息结构 const userPromptSections = []; - // Layer 1: 当前对话切片 - if (dialogueText) { - userPromptSections.push("## 当前对话内容(需提取记忆)", dialogueText, ""); - } else if (structuredMode === "structured" && structuredMessages.length > 0) { - userPromptSections.push( - "## 当前对话内容(结构化消息,需提取记忆)", - "(结构化消息已通过 profile 
blocks 注入,请参考上方 recentMessages 块。)", - "", - ); + // Layer 1: 当前对话切片(区分上下文回顾 vs 提取目标) + { + const hasContextMessages = structuredMessages.some((m) => m?.isContextOnly === true); + const hasTargetMessages = structuredMessages.some((m) => m?.isContextOnly !== true); + if (dialogueText) { + if (hasContextMessages && hasTargetMessages) { + userPromptSections.push( + "## 对话内容", + "以下对话包含两部分:已提取过的上下文回顾(仅供理解前情)和本次需要提取记忆的新内容。" + + "请**只从新内容中提取记忆**,不要重复提取上下文回顾中已有的信息。", + dialogueText, + "", + ); + } else { + userPromptSections.push("## 当前对话内容(需提取记忆)", dialogueText, ""); + } + } else if (structuredMode === "structured" && structuredMessages.length > 0) { + if (hasContextMessages && hasTargetMessages) { + userPromptSections.push( + "## 对话内容(结构化消息)", + "以下结构化消息包含两部分:标记为 isContextOnly 的是已提取过的上下文回顾(仅供理解前情)," + + "其余是本次需要提取记忆的新内容。请**只从 isContextOnly 为 false 的消息中提取记忆**。" + + "(结构化消息已通过 profile blocks 注入,请参考上方 recentMessages 块。)", + "", + ); + } else { + userPromptSections.push( + "## 当前对话内容(结构化消息,需提取记忆)", + "(结构化消息已通过 profile blocks 注入,请参考上方 recentMessages 块。)", + "", + ); + } + } } // Layer 2: 当前图谱状态 @@ -1724,30 +1817,6 @@ async function generateNodeEmbeddings(graph, embeddingConfig, signal) { } } -/** - * 构建图谱概览文本(给 LLM 看) - */ -function buildGraphOverview(graph, schema) { - const activeNodes = graph.nodes - .filter((n) => !n.archived) - .sort((a, b) => (a.seq || 0) - (b.seq || 0)); - if (activeNodes.length === 0) return ""; - - const lines = []; - for (const typeDef of schema) { - const nodesOfType = activeNodes.filter((n) => n.type === typeDef.id); - if (nodesOfType.length === 0) continue; - - lines.push(`### ${typeDef.label} (${nodesOfType.length} 个节点)`); - for (const node of nodesOfType.slice(-10)) { - // 只展示最近 10 个 - lines.push(` - [${node.id}] ${getNodeDisplayName(node)}`); - } - } - - return lines.join("\n"); -} - /** * 构建 Schema 描述文本 */ @@ -2015,6 +2084,8 @@ export async function generateSynopsis({ export async function generateReflection({ graph, currentSeq, + schema = 
[], + embeddingConfig, customPrompt, signal, settings = {}, @@ -2064,6 +2135,27 @@ export async function generateReflection({ const contradictionSummary = contradictEdges .map((e) => `${e.fromId} -> ${e.toId} (${e.relation})`) .join("\n"); + const reflectionGraphStats = await buildTaskGraphStats({ + graph, + schema, + userMessage: buildReflectionRankingQueryText({ + eventSummary, + characterSummary, + threadSummary, + contradictionSummary, + }), + recentMessages: [], + embeddingConfig, + signal, + rankingOptions: { + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + }, + relevantHeading: "与当前反思最相关的既有节点", + }); const reflectionPromptBuild = await buildTaskPrompt(settings, "reflection", { taskName: "reflection", @@ -2071,7 +2163,7 @@ export async function generateReflection({ characterSummary: characterSummary || "(无)", threadSummary: threadSummary || "(无)", contradictionSummary: contradictionSummary || "(无)", - graphStats: `event=${recentEvents.length}, character=${recentCharacters.length}, thread=${recentThreads.length}`, + graphStats: reflectionGraphStats.graphStats, ...getSTContextForPrompt(), }); const reflectionRegexInput = { entries: [] }; diff --git a/maintenance/task-graph-stats.js b/maintenance/task-graph-stats.js new file mode 100644 index 0000000..15727e2 --- /dev/null +++ b/maintenance/task-graph-stats.js @@ -0,0 +1,203 @@ +import { getActiveNodes } from "../graph/graph.js"; +import { createPromptNodeReferenceMap } from "../prompting/prompt-node-references.js"; +import { rankNodesForTaskContext } from "../retrieval/shared-ranking.js"; + +const DEFAULT_TYPE_LABELS = Object.freeze({ + event: "事件", + character: "角色", + location: "地点", + rule: "规则", + thread: "主线", + synopsis: "全局概要", + reflection: "反思", + pov_memory: "主观记忆", +}); + +function createTypeLabelMap(schema = []) { + return new Map( + (Array.isArray(schema) ? 
schema : []) + .filter((typeDef) => String(typeDef?.id || "").trim()) + .map((typeDef) => [ + String(typeDef?.id || "").trim(), + String(typeDef?.label || typeDef?.id || "").trim(), + ]), + ); +} + +function resolveTypeLabel(typeId = "", typeLabelMap = new Map()) { + const normalizedTypeId = String(typeId || "").trim(); + return ( + typeLabelMap.get(normalizedTypeId) || + DEFAULT_TYPE_LABELS[normalizedTypeId] || + normalizedTypeId || + "节点" + ); +} + +function listGraphTypeCounts(activeNodes = [], schema = [], typeLabelMap = new Map()) { + const safeActiveNodes = Array.isArray(activeNodes) ? activeNodes : []; + if (Array.isArray(schema) && schema.length > 0) { + return schema + .map((typeDef) => { + const typeId = String(typeDef?.id || "").trim(); + const count = safeActiveNodes.filter((node) => node?.type === typeId).length; + return { + typeId, + label: resolveTypeLabel(typeId, typeLabelMap), + count, + }; + }) + .filter((entry) => entry.count > 0); + } + + const countMap = new Map(); + for (const node of safeActiveNodes) { + const typeId = String(node?.type || "").trim(); + if (!typeId) continue; + countMap.set(typeId, (countMap.get(typeId) || 0) + 1); + } + return [...countMap.entries()] + .map(([typeId, count]) => ({ + typeId, + label: resolveTypeLabel(typeId, typeLabelMap), + count, + })) + .sort((left, right) => left.typeId.localeCompare(right.typeId)); +} + +export function buildRelevantNodeReferenceMap( + scoredNodes = [], + schema = [], + { + maxCount = 6, + prefix = "G", + maxLength = 28, + } = {}, +) { + const typeLabelMap = createTypeLabelMap(schema); + const relevantNodes = (Array.isArray(scoredNodes) ? 
scoredNodes : []) + .filter( + (entry) => + entry?.node && + !entry.node.archived && + ((Number(entry?.vectorScore) || 0) > 0 || + (Number(entry?.graphScore) || 0) > 0 || + (Number(entry?.lexicalScore) || 0) > 0), + ) + .slice(0, Math.max(1, maxCount)); + + return createPromptNodeReferenceMap(relevantNodes, { + prefix, + maxLength, + buildMeta: ({ entry, node }) => ({ + typeLabel: resolveTypeLabel(node?.type, typeLabelMap), + score: + Math.round((Number(entry?.weightedScore ?? entry?.finalScore) || 0) * 1000) / + 1000, + }), + }); +} + +export function buildGraphOverview( + graph, + schema = [], + relevantReferenceMap = null, + { + relevantHeading = "与当前任务最相关的既有节点", + } = {}, +) { + const activeNodes = graph?.nodes + ?.filter((node) => node && !node.archived) + ?.sort((left, right) => (left.seq || 0) - (right.seq || 0)); + if (!Array.isArray(activeNodes) || activeNodes.length === 0) { + return ""; + } + + const typeLabelMap = createTypeLabelMap(schema); + const typeCounts = listGraphTypeCounts(activeNodes, schema, typeLabelMap); + const lines = ["### 图谱节点统计"]; + + for (const entry of typeCounts) { + lines.push(` - ${entry.label}: ${entry.count}`); + } + + const references = Array.isArray(relevantReferenceMap?.references) + ? 
relevantReferenceMap.references + : []; + if (references.length > 0) { + lines.push("", `### ${String(relevantHeading || "与当前任务最相关的既有节点").trim() || "与当前任务最相关的既有节点"}`); + for (const reference of references) { + const typeLabel = + String(reference?.meta?.typeLabel || reference?.meta?.type || "节点").trim() || + "节点"; + const label = String(reference?.meta?.label || "—").trim() || "—"; + const score = Number(reference?.meta?.score || 0).toFixed(3); + lines.push(` - [${reference.key}|${typeLabel}] ${label} (score=${score})`); + } + } + + return lines.join("\n"); +} + +function normalizeActiveNodes(graph, activeNodes = null) { + if (Array.isArray(activeNodes)) { + return activeNodes.filter((node) => node && !node.archived); + } + return getActiveNodes(graph).filter((node) => node && !node.archived); +} + +export async function buildTaskGraphStats({ + graph, + schema = [], + userMessage = "", + recentMessages = [], + embeddingConfig, + signal, + activeNodes = null, + rankingOptions = {}, + relevantHeading = "与当前任务最相关的既有节点", + maxRelevantNodes = 6, + prefix = "G", + maxLabelLength = 28, +} = {}) { + const normalizedActiveNodes = normalizeActiveNodes(graph, activeNodes); + const normalizedUserMessage = String(userMessage || "").trim(); + + let ranking = null; + if (graph && normalizedActiveNodes.length > 0 && normalizedUserMessage) { + ranking = await rankNodesForTaskContext({ + graph, + userMessage: normalizedUserMessage, + recentMessages, + embeddingConfig, + signal, + options: { + activeNodes: normalizedActiveNodes, + topK: 12, + diffusionTopK: 48, + enableContextQueryBlend: false, + enableMultiIntent: true, + maxTextLength: 1200, + ...rankingOptions, + }, + }); + } + + const relevantReferenceMap = buildRelevantNodeReferenceMap( + ranking?.scoredNodes, + schema, + { + maxCount: maxRelevantNodes, + prefix, + maxLength: maxLabelLength, + }, + ); + + return { + ranking, + relevantReferenceMap, + graphStats: buildGraphOverview(graph, schema, relevantReferenceMap, { + 
relevantHeading, + }), + }; +} diff --git a/manifest.json b/manifest.json index 41dcfde..be58f68 100644 --- a/manifest.json +++ b/manifest.json @@ -6,6 +6,6 @@ "js": "index.js", "css": "style.css", "author": "Youzini", - "version": "4.6.6", + "version": "4.7.4", "homePage": "https://github.com/Youzini-afk/ST-Bionic-Memory-Ecology" } diff --git a/prompting/prompt-builder.js b/prompting/prompt-builder.js index 7b84777..faaa446 100644 --- a/prompting/prompt-builder.js +++ b/prompting/prompt-builder.js @@ -296,6 +296,9 @@ function getPromptMessageLikeDescriptor(value) { role: role === "user" ? "user" : "assistant", seq: getOptionalFiniteNumber(value.seq), speaker, + hideSpeakerLabel: value?.hideSpeakerLabel === true, + isContextOnly: + typeof value.isContextOnly === "boolean" ? value.isContextOnly : null, }; } @@ -308,6 +311,9 @@ function getPromptMessageLikeDescriptor(value) { role: value.is_user === true ? "user" : "assistant", seq: getOptionalFiniteNumber(value.seq), speaker, + hideSpeakerLabel: value?.hideSpeakerLabel === true, + isContextOnly: + typeof value.isContextOnly === "boolean" ? value.isContextOnly : null, }; } @@ -322,23 +328,62 @@ function isPromptMessageArray(value) { ); } +export const EXTRACTION_CONTEXT_REVIEW_HEADER = + "--- 以下是上下文回顾(已提取过),仅供理解剧情 ---"; +export const EXTRACTION_TARGET_CONTENT_HEADER = + "--- 以下是本次需要提取记忆的新对话内容 ---"; +export const RECALL_TARGET_CONTENT_HEADER = + "--- 以下是本次需要召回记忆的新对话内容 ---"; + +function getPromptMessageContextGroup(value) { + const descriptor = getPromptMessageLikeDescriptor(value); + if (!descriptor || typeof descriptor.isContextOnly !== "boolean") { + return null; + } + return descriptor.isContextOnly ? 
"context" : "target"; +} + +function getPromptMessageContextHeader(group = "") { + if (group === "context") { + return EXTRACTION_CONTEXT_REVIEW_HEADER; + } + if (group === "target") { + return EXTRACTION_TARGET_CONTENT_HEADER; + } + return ""; +} + function formatPromptMessageTranscript(value) { const entries = Array.isArray(value) ? value : [value]; - return entries - .map((entry, index) => { - const descriptor = getPromptMessageLikeDescriptor(entry); - if (!descriptor) { - return ""; - } - const seqLabel = - descriptor.seq != null ? `#${descriptor.seq}` : `#${index + 1}`; - const speakerLabel = descriptor.speaker - ? `|${descriptor.speaker}` - : ""; - return `${seqLabel} [${descriptor.role}${speakerLabel}]: ${descriptor.content}`; - }) - .filter(Boolean) - .join("\n\n"); + const hasContextMessages = entries.some( + (entry) => getPromptMessageContextGroup(entry) === "context", + ); + const hasTargetMessages = entries.some( + (entry) => getPromptMessageContextGroup(entry) === "target", + ); + const lines = []; + let activeGroup = null; + + for (let index = 0; index < entries.length; index += 1) { + const entry = entries[index]; + const descriptor = getPromptMessageLikeDescriptor(entry); + if (!descriptor) { + continue; + } + const group = getPromptMessageContextGroup(entry); + if (hasContextMessages && hasTargetMessages && group && group !== activeGroup) { + lines.push(getPromptMessageContextHeader(group)); + activeGroup = group; + } + const seqLabel = + descriptor.seq != null ? `#${descriptor.seq}` : `#${index + 1}`; + const speakerLabel = !descriptor.hideSpeakerLabel && descriptor.speaker + ? 
`|${descriptor.speaker}` + : ""; + lines.push(`${seqLabel} [${descriptor.role}${speakerLabel}]: ${descriptor.content}`); + } + + return lines.filter(Boolean).join("\n\n"); } function stringifyInterpolatedValue(value) { @@ -1880,6 +1925,91 @@ function clonePayloadMessage(message = {}) { }); } +function splitSectionedTranscriptPayloadMessage(message = {}) { + const normalizedRole = normalizeRole(message?.role); + const sourceKey = String(message?.sourceKey || "").trim(); + const content = String(message?.content || "").trim(); + const targetSectionHeader = content.includes(RECALL_TARGET_CONTENT_HEADER) + ? RECALL_TARGET_CONTENT_HEADER + : content.includes(EXTRACTION_TARGET_CONTENT_HEADER) + ? EXTRACTION_TARGET_CONTENT_HEADER + : ""; + if ( + normalizedRole !== "system" || + !["recentMessages", "dialogueText"].includes(sourceKey) || + !content.includes(EXTRACTION_CONTEXT_REVIEW_HEADER) || + !targetSectionHeader + ) { + return [message]; + } + + const headerMatches = []; + let searchIndex = 0; + while (searchIndex < content.length) { + const contextIndex = content.indexOf( + EXTRACTION_CONTEXT_REVIEW_HEADER, + searchIndex, + ); + const targetIndex = targetSectionHeader + ? 
content.indexOf(targetSectionHeader, searchIndex) + : -1; + let nextIndex = -1; + let nextHeader = ""; + if (contextIndex >= 0 && (targetIndex < 0 || contextIndex <= targetIndex)) { + nextIndex = contextIndex; + nextHeader = EXTRACTION_CONTEXT_REVIEW_HEADER; + } else if (targetIndex >= 0) { + nextIndex = targetIndex; + nextHeader = targetSectionHeader; + } + if (nextIndex < 0 || !nextHeader) { + break; + } + headerMatches.push({ + index: nextIndex, + header: nextHeader, + }); + searchIndex = nextIndex + nextHeader.length; + } + + if (headerMatches.length < 2 || headerMatches[0].index !== 0) { + return [message]; + } + + const { role: _role, content: _content, ...sharedMeta } = message; + const splitMessages = []; + + for (let index = 0; index < headerMatches.length; index += 1) { + const current = headerMatches[index]; + const next = headerMatches[index + 1]; + const sectionBody = content + .slice(current.index + current.header.length, next ? next.index : content.length) + .trim(); + const transcriptSection = + current.header === EXTRACTION_CONTEXT_REVIEW_HEADER ? "context" : "target"; + splitMessages.push( + createExecutionMessage( + "system", + sectionBody ? `${current.header}\n\n${sectionBody}` : current.header, + { + ...sharedMeta, + sourceKey, + transcriptSection, + transcriptSectionPart: "section", + }, + ), + ); + } + + return splitMessages.filter(Boolean); +} + +function expandSectionedTranscriptPayloadMessages(messages = []) { + return (Array.isArray(messages) ? messages : []).flatMap((message) => + splitSectionedTranscriptPayloadMessage(message), + ); +} + function collectPayloadUserMessageTexts(messages = []) { return (Array.isArray(messages) ? 
messages : []) .filter((message) => String(message?.role || "").trim().toLowerCase() === "user") @@ -1978,8 +2108,11 @@ export function buildTaskLlmPayload(promptBuild = null, fallbackUserPrompt = "") !(isCustomFilter && messageUsesWorldInfoContent(message)), }, ); + const expandedExecutionMessages = expandSectionedTranscriptPayloadMessages( + executionMessages, + ); - const hasUserMessage = executionMessages.some( + const hasUserMessage = expandedExecutionMessages.some( (message) => message.role === "user", ); if (!hasUserMessage && rawExecutionMessages.length > 0) { @@ -1998,7 +2131,7 @@ export function buildTaskLlmPayload(promptBuild = null, fallbackUserPrompt = "") `after recreate=${userBlocksAfterRaw.length}, ` + `after sanitize=${userBlocksAfterSanitize.length}, ` + `blockedContents count=${blockedContents.length}, ` + - `total executionMessages=${executionMessages.length}`, + `total executionMessages=${expandedExecutionMessages.length}`, ); if (userBlocksBefore.length > 0) { for (const block of userBlocksBefore) { @@ -2016,17 +2149,19 @@ export function buildTaskLlmPayload(promptBuild = null, fallbackUserPrompt = "") } } const additionalMessages = - executionMessages.length > 0 + expandedExecutionMessages.length > 0 ? [] - : sanitizePromptMessages( - settings, - taskType, - rawPrivateTaskMessages, - { - blockedContents, - applySanitizer: (message) => - !(isCustomFilter && messageUsesWorldInfoContent(message)), - }, + : expandSectionedTranscriptPayloadMessages( + sanitizePromptMessages( + settings, + taskType, + rawPrivateTaskMessages, + { + blockedContents, + applySanitizer: (message) => + !(isCustomFilter && messageUsesWorldInfoContent(message)), + }, + ), ); const hasAdditionalUserMessage = additionalMessages.some( (message) => message.role === "user", @@ -2048,9 +2183,11 @@ export function buildTaskLlmPayload(promptBuild = null, fallbackUserPrompt = "") return { systemPrompt: - executionMessages.length > 0 ? 
"" : String(promptBuild?.systemPrompt || ""), + expandedExecutionMessages.length > 0 + ? "" + : String(promptBuild?.systemPrompt || ""), userPrompt: fallbackUserPromptResult.text, - promptMessages: executionMessages, + promptMessages: expandedExecutionMessages, additionalMessages, fallbackUserPromptSource: fallbackUserPromptResult.source, fallbackUserPromptApplied: Boolean(fallbackUserPromptResult.text), diff --git a/prompting/prompt-node-references.js b/prompting/prompt-node-references.js new file mode 100644 index 0000000..26961fd --- /dev/null +++ b/prompting/prompt-node-references.js @@ -0,0 +1,93 @@ +import { truncateNodeLabel } from "../graph/node-labels.js"; + +function normalizePromptNodeText(value) { + return String(value ?? "") + .replace(/\s+/g, " ") + .trim(); +} + +function resolvePromptNode(value = {}) { + if (value?.node && typeof value.node === "object") { + return value.node; + } + return value && typeof value === "object" ? value : {}; +} + +export function resolvePromptNodeId(value = {}) { + const node = resolvePromptNode(value); + return String(value?.nodeId || node?.id || "").trim(); +} + +export function getPromptNodeLabel(value = {}, { maxLength = 32 } = {}) { + const node = resolvePromptNode(value); + const fallbackId = typeof node?.id === "string" ? node.id.slice(0, 8) : ""; + const rawLabel = normalizePromptNodeText( + node?.fields?.title || + node?.fields?.name || + node?.fields?.summary || + node?.fields?.insight || + node?.fields?.belief || + node?.name || + fallbackId || + "—", + ); + return truncateNodeLabel(rawLabel || "—", maxLength); +} + +export function createPromptNodeReferenceMap( + entries = [], + { + prefix = "N", + maxLength = 32, + buildMeta = null, + } = {}, +) { + const keyToNodeId = {}; + const keyToMeta = {}; + const nodeIdToKey = {}; + const references = []; + + for (const [index, entry] of (Array.isArray(entries) ? 
entries : []).entries()) { + const node = resolvePromptNode(entry); + const nodeId = resolvePromptNodeId(entry); + if (!nodeId || nodeIdToKey[nodeId]) { + continue; + } + + const key = `${String(prefix || "N").trim() || "N"}${references.length + 1}`; + const label = getPromptNodeLabel(node, { maxLength }); + const extraMeta = typeof buildMeta === "function" + ? buildMeta({ + entry, + node, + nodeId, + key, + index, + label, + }) + : {}; + + keyToNodeId[key] = nodeId; + nodeIdToKey[nodeId] = key; + keyToMeta[key] = { + nodeId, + type: String(node?.type || ""), + label, + ...(extraMeta && typeof extraMeta === "object" ? extraMeta : {}), + }; + references.push({ + key, + nodeId, + node, + meta: keyToMeta[key], + }); + } + + return { + prefix: String(prefix || "N").trim() || "N", + references, + keyToNodeId, + keyToMeta, + nodeIdToKey, + }; +} diff --git a/retrieval/retriever.js b/retrieval/retriever.js index 17d03be..ea97f90 100644 --- a/retrieval/retriever.js +++ b/retrieval/retriever.js @@ -16,6 +16,8 @@ import { buildTaskExecutionDebugContext, buildTaskLlmPayload, buildTaskPrompt, + EXTRACTION_CONTEXT_REVIEW_HEADER, + RECALL_TARGET_CONTENT_HEADER, } from "../prompting/prompt-builder.js"; import { applyCooccurrenceBoost, @@ -23,9 +25,7 @@ import { collectSupplementalAnchorNodeIds, createCooccurrenceIndex, isEligibleAnchorNode, - mergeVectorResults, runResidualRecall, - splitIntentSegments, } from "./retrieval-enhancer.js"; import { MEMORY_SCOPE_BUCKETS, @@ -36,6 +36,7 @@ import { normalizeMemoryScope, resolveScopeBucketWeight, } from "../graph/memory-scope.js"; +import { rankNodesForTaskContext } from "./shared-ranking.js"; import { computeKnowledgeGateForNode, listKnowledgeOwners, @@ -54,8 +55,8 @@ import { } from "../graph/story-timeline.js"; import { getActiveSummaryEntries } from "../graph/summary-state.js"; import { applyTaskRegex } from "../prompting/task-regex.js"; +import { createPromptNodeReferenceMap } from "../prompting/prompt-node-references.js"; import 
{ getSTContextForPrompt } from "../host/st-context.js"; -import { findSimilarNodesByText, validateVectorConfig } from "../vector/vector-index.js"; function createAbortError(message = "操作已终止") { const error = new Error(message); @@ -94,6 +95,32 @@ function resolveTaskLlmSystemPrompt(promptPayload, fallbackSystemPrompt = "") { return String(promptPayload?.systemPrompt || fallbackSystemPrompt || ""); } +function buildRecallSectionedTranscript(recentMessages = []) { + const lines = (Array.isArray(recentMessages) ? recentMessages : []) + .map((line) => String(line || "").trim()) + .filter(Boolean); + if (lines.length === 0) { + return ""; + } + + const targetLines = [lines[lines.length - 1]].filter(Boolean); + const contextLines = lines.slice(0, -1).filter(Boolean); + const sections = []; + + if (contextLines.length > 0) { + sections.push( + `${EXTRACTION_CONTEXT_REVIEW_HEADER}\n\n${contextLines.join("\n---\n")}`, + ); + } + if (targetLines.length > 0) { + sections.push( + `${RECALL_TARGET_CONTENT_HEADER}\n\n${targetLines.join("\n---\n")}`, + ); + } + + return sections.join("\n\n"); +} + function buildRecallFallbackReason(llmResult) { const failureType = String(llmResult?.errorType || "").trim(); const failureReason = String(llmResult?.failureReason || "").trim(); @@ -241,14 +268,6 @@ function normalizeQueryText(value, maxLength = 400) { return normalized.slice(0, Math.max(1, maxLength)); } -function createTextPreview(text, maxLength = 120) { - const normalized = normalizeQueryText(text, maxLength + 4); - if (!normalized) return ""; - return normalized.length > maxLength - ? 
`${normalized.slice(0, maxLength)}...` - : normalized; -} - function normalizeRecallSelectionList(values = [], maxLength = 64) { const normalized = []; const seen = new Set(); @@ -262,262 +281,23 @@ function normalizeRecallSelectionList(values = [], maxLength = 64) { return normalized; } -function getRecallCandidateLabel(node = {}) { - return String( - node?.fields?.title || - node?.fields?.name || - node?.fields?.summary || - node?.fields?.insight || - node?.fields?.belief || - node?.id || - "", - ).trim(); -} - function createRecallCandidateKeyMaps(candidates = []) { - const candidateKeyToNodeId = {}; - const candidateKeyToCandidateMeta = {}; - const nodeIdToCandidateKey = {}; - - for (const [index, candidate] of (Array.isArray(candidates) ? candidates : []).entries()) { - const node = candidate?.node || {}; - const nodeId = String(candidate?.nodeId || node?.id || "").trim(); - if (!nodeId) continue; - const candidateKey = `R${index + 1}`; - candidateKeyToNodeId[candidateKey] = nodeId; - nodeIdToCandidateKey[nodeId] = candidateKey; - candidateKeyToCandidateMeta[candidateKey] = { - nodeId, - type: String(node?.type || ""), - label: getRecallCandidateLabel(node), - scopeBucket: String(candidate?.scopeBucket || ""), - temporalBucket: String(candidate?.temporalBucket || ""), + const referenceMap = createPromptNodeReferenceMap(candidates, { + prefix: "R", + maxLength: 80, + buildMeta: ({ entry }) => ({ + scopeBucket: String(entry?.scopeBucket || ""), + temporalBucket: String(entry?.temporalBucket || ""), score: Math.round( - (Number(candidate?.weightedScore ?? candidate?.finalScore) || 0) * 1000, + (Number(entry?.weightedScore ?? 
entry?.finalScore) || 0) * 1000, ) / 1000, - }; - } - + }), + }); return { - candidateKeyToNodeId, - candidateKeyToCandidateMeta, - nodeIdToCandidateKey, - }; -} - -function roundBlendWeight(value) { - return Math.round((Number(value) || 0) * 1000) / 1000; -} - -function uniqueStrings(values = [], maxLength = 400) { - const result = []; - const seen = new Set(); - - for (const value of values) { - const text = normalizeQueryText(value, maxLength); - const key = text.toLowerCase(); - if (!text || seen.has(key)) continue; - seen.add(key); - result.push(text); - } - - return result; -} - -function parseRecallContextLine(line = "") { - const raw = String(line ?? "").trim(); - if (!raw) return null; - - const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); - if (bracketMatch) { - const role = String(bracketMatch[1] || "").toLowerCase(); - const text = normalizeQueryText(bracketMatch[2] || ""); - return text ? { role, text } : null; - } - - const plainMatch = raw.match( - /^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i, - ); - if (!plainMatch) return null; - - const roleToken = String(plainMatch[1] || "").toLowerCase(); - const role = - roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" - ? "assistant" - : "user"; - const text = normalizeQueryText(plainMatch[2] || ""); - return text ? 
{ role, text } : null; -} - -function buildContextQueryBlend( - userMessage, - recentMessages = [], - { - enabled = true, - assistantWeight = 0.2, - previousUserWeight = 0.1, - maxTextLength = 400, - } = {}, -) { - const currentText = normalizeQueryText(userMessage, maxTextLength); - const normalizedAssistantWeight = clampRange(assistantWeight, 0.2, 0, 1); - const normalizedPreviousUserWeight = clampRange( - previousUserWeight, - 0.1, - 0, - 1, - ); - const currentWeight = Math.max( - 0, - 1 - normalizedAssistantWeight - normalizedPreviousUserWeight, - ); - - let assistantText = ""; - let previousUserText = ""; - const parsedMessages = Array.isArray(recentMessages) - ? recentMessages.map((line) => parseRecallContextLine(line)).filter(Boolean) - : []; - - for (let index = parsedMessages.length - 1; index >= 0; index--) { - const item = parsedMessages[index]; - if (!assistantText && item.role === "assistant") { - assistantText = normalizeQueryText(item.text, maxTextLength); - } - if ( - !previousUserText && - item.role === "user" && - normalizeQueryText(item.text, maxTextLength).toLowerCase() !== - currentText.toLowerCase() - ) { - previousUserText = normalizeQueryText(item.text, maxTextLength); - } - if (assistantText && previousUserText) break; - } - - const rawParts = [ - { - kind: "currentUser", - label: "当前用户消息", - text: currentText, - weight: enabled ? 
currentWeight : 1, - }, - ]; - - if (enabled && assistantText) { - rawParts.push({ - kind: "assistantContext", - label: "最近 assistant 回复", - text: assistantText, - weight: normalizedAssistantWeight, - }); - } - - if (enabled && previousUserText) { - rawParts.push({ - kind: "previousUser", - label: "上一条 user 消息", - text: previousUserText, - weight: normalizedPreviousUserWeight, - }); - } - - const dedupedParts = []; - const seen = new Set(); - for (const part of rawParts) { - const text = normalizeQueryText(part.text, maxTextLength); - const key = text.toLowerCase(); - if (!text || seen.has(key)) continue; - seen.add(key); - dedupedParts.push({ - ...part, - text, - }); - } - - if (dedupedParts.length === 0) { - return { - active: false, - parts: [], - currentText: "", - assistantText: "", - previousUserText: "", - combinedText: "", - }; - } - - const totalWeight = dedupedParts.reduce( - (sum, part) => sum + Math.max(0, Number(part.weight) || 0), - 0, - ); - const normalizedParts = dedupedParts.map((part) => ({ - ...part, - weight: - totalWeight > 0 - ? roundBlendWeight((Math.max(0, Number(part.weight) || 0) || 0) / totalWeight) - : roundBlendWeight(1 / dedupedParts.length), - })); - const combinedText = - normalizedParts.length <= 1 - ? 
normalizedParts[0]?.text || "" - : normalizedParts - .map((part) => `${part.label}:\n${part.text}`) - .join("\n\n"); - - return { - active: enabled && normalizedParts.length > 1, - parts: normalizedParts, - currentText: currentText || normalizedParts[0]?.text || "", - assistantText, - previousUserText, - combinedText, - }; -} - -function buildVectorQueryPlan( - blendPlan, - { enableMultiIntent = true, maxSegments = 4 } = {}, -) { - const plan = []; - let currentSegments = []; - - for (const part of blendPlan?.parts || []) { - let queries = [part.text]; - if (part.kind === "currentUser" && enableMultiIntent) { - currentSegments = splitIntentSegments(part.text, { maxSegments }); - queries = uniqueStrings([ - part.text, - ...currentSegments.filter((item) => item !== part.text), - ]); - } else { - queries = uniqueStrings([part.text]); - } - - plan.push({ - kind: part.kind, - label: part.label, - weight: part.weight, - queries, - }); - } - - return { - plan, - currentSegments, - }; -} - -function buildLexicalQuerySources( - userMessage, - { enableMultiIntent = true, maxSegments = 4 } = {}, -) { - const currentText = normalizeQueryText(userMessage, 400); - const segments = enableMultiIntent - ? splitIntentSegments(currentText, { maxSegments }) - : []; - return { - sources: uniqueStrings([currentText, ...segments]), - segments, + candidateKeyToNodeId: referenceMap.keyToNodeId, + candidateKeyToCandidateMeta: referenceMap.keyToMeta, + nodeIdToCandidateKey: referenceMap.nodeIdToKey, }; } @@ -722,13 +502,6 @@ function buildVisibilityTopHits(scoredNodes = [], maxCount = 6) { })); } -function scaleVectorResults(results = [], weight = 1) { - return (Array.isArray(results) ? 
results : []).map((item) => ({ - ...item, - score: (Number(item?.score) || 0) * Math.max(0, Number(weight) || 0), - })); -} - function pickActiveRegion(graph, optionValue = "") { const direct = String(optionValue || "").trim(); if (direct) return direct; @@ -1462,7 +1235,6 @@ export async function retrieve({ normalizedMaxRecallNodes, llmCandidatePool, ); - const vectorValidation = validateVectorConfig(embeddingConfig); const retrievalMeta = createRetrievalMeta(enableLLMRecall); retrievalMeta.activeRegion = activeRegion; retrievalMeta.activeRegionSource = activeRegionContext.source || ""; @@ -1490,29 +1262,6 @@ export async function retrieve({ retrievalMeta.knowledgeGateMode = enableCognitiveMemory ? "anchored-soft-visibility" : "disabled"; - const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, { - enabled: enableContextQueryBlend, - assistantWeight: contextAssistantWeight, - previousUserWeight: contextPreviousUserWeight, - }); - retrievalMeta.queryBlendActive = contextQueryBlend.active; - retrievalMeta.queryBlendParts = (contextQueryBlend.parts || []).map((part) => ({ - kind: part.kind, - label: part.label, - weight: part.weight, - text: createTextPreview(part.text), - length: part.text.length, - })); - retrievalMeta.queryBlendWeights = Object.fromEntries( - (contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]), - ); - const lexicalQuery = buildLexicalQuerySources( - contextQueryBlend.currentText || userMessage, - { - enableMultiIntent, - maxSegments: multiIntentMaxSegments, - }, - ); debugLog( `[ST-BME] 检索开始: ${nodeCount} 个活跃节点${enableVisibility ? 
" (认知边界已启用)" : ""}`, ); @@ -1567,49 +1316,85 @@ export async function retrieve({ }, }); } - - const vectorStartedAt = nowMs(); - if (enableVectorPrefilter && vectorValidation.valid) { - debugLog("[ST-BME] 第1层: 向量预筛"); - const queryPlan = buildVectorQueryPlan(contextQueryBlend, { + const sharedRanking = await rankNodesForTaskContext({ + graph, + userMessage, + recentMessages, + embeddingConfig, + signal, + options: { + topK: normalizedTopK, + diffusionTopK: normalizedDiffusionTopK, + enableVectorPrefilter, + enableGraphDiffusion, + enableContextQueryBlend, enableMultiIntent, - maxSegments: multiIntentMaxSegments, - }); - const groups = []; - - retrievalMeta.segmentsUsed = queryPlan.currentSegments; - for (const part of queryPlan.plan) { - for (const queryText of part.queries) { - const results = await vectorPreFilter( - graph, - queryText, - activeNodes, - embeddingConfig, - normalizedTopK, - signal, - ); - groups.push(scaleVectorResults(results, part.weight || 1)); - } - } - - const merged = mergeVectorResults( - groups, - Math.max(normalizedTopK * 2, 24), - ); - retrievalMeta.vectorHits = merged.rawHitCount; - retrievalMeta.vectorMergedHits = merged.results.length; - vectorResults = merged.results; - } else if (enableVectorPrefilter) { - pushSkipReason(retrievalMeta, "vector-config-invalid"); - } - retrievalMeta.timings.vector = roundMs(nowMs() - vectorStartedAt); - - exactEntityAnchors.push( - ...extractEntityAnchors( - contextQueryBlend.currentText || userMessage, + multiIntentMaxSegments, + contextAssistantWeight, + contextPreviousUserWeight, + teleportAlpha, + enableTemporalLinks, + temporalLinkStrength, + enableLexicalBoost, + lexicalWeight, + weights, activeNodes, - ), + }, + }); + const contextQueryBlend = sharedRanking.contextQueryBlend; + const lexicalQuery = sharedRanking.lexicalQuery; + retrievalMeta.queryBlendActive = Boolean( + sharedRanking?.diagnostics?.queryBlendActive, ); + retrievalMeta.queryBlendParts = Array.isArray( + 
sharedRanking?.diagnostics?.queryBlendParts, + ) + ? [...sharedRanking.diagnostics.queryBlendParts] + : []; + retrievalMeta.queryBlendWeights = { + ...(sharedRanking?.diagnostics?.queryBlendWeights || {}), + }; + retrievalMeta.segmentsUsed = Array.isArray(sharedRanking?.diagnostics?.segmentsUsed) + ? [...sharedRanking.diagnostics.segmentsUsed] + : []; + retrievalMeta.vectorHits = Number(sharedRanking?.diagnostics?.vectorHits || 0); + retrievalMeta.vectorMergedHits = Number( + sharedRanking?.diagnostics?.vectorMergedHits || 0, + ); + retrievalMeta.seedCount = Number(sharedRanking?.diagnostics?.seedCount || 0); + retrievalMeta.diffusionHits = Number( + sharedRanking?.diagnostics?.diffusionHits || 0, + ); + retrievalMeta.lexicalBoostedNodes = Number( + sharedRanking?.diagnostics?.lexicalBoostedNodes || 0, + ); + retrievalMeta.temporalSyntheticEdgeCount = Number( + sharedRanking?.diagnostics?.temporalSyntheticEdgeCount || 0, + ); + retrievalMeta.teleportAlpha = Number( + sharedRanking?.diagnostics?.teleportAlpha || teleportAlpha, + ); + retrievalMeta.lexicalTopHits = Array.isArray( + sharedRanking?.diagnostics?.lexicalTopHits, + ) + ? [...sharedRanking.diagnostics.lexicalTopHits] + : []; + retrievalMeta.timings.vector = Number( + sharedRanking?.diagnostics?.timings?.vector || 0, + ); + retrievalMeta.timings.diffusion = Number( + sharedRanking?.diagnostics?.timings?.diffusion || 0, + ); + for (const reason of sharedRanking?.diagnostics?.skipReasons || []) { + pushSkipReason(retrievalMeta, reason); + } + vectorResults = Array.isArray(sharedRanking?.vectorResults) + ? [...sharedRanking.vectorResults] + : []; + diffusionResults = Array.isArray(sharedRanking?.diffusionResults) + ? 
[...sharedRanking.diffusionResults] + : []; + exactEntityAnchors.push(...(sharedRanking?.exactEntityAnchors || [])); supplementalAnchorNodeIds = collectSupplementalAnchorNodeIds( graph, vectorResults, @@ -1650,7 +1435,7 @@ export async function retrieve({ retrievalMeta.timings.residual = roundMs(nowMs() - residualStartedAt); const diffusionStartedAt = nowMs(); - if (enableGraphDiffusion) { + if (enableGraphDiffusion && (enableCrossRecall || residualResult.triggered)) { debugLog("[ST-BME] 第2层: PEDSA 图扩散"); const seeds = [ ...vectorResults.map((v) => ({ id: v.nodeId, energy: v.score })), @@ -1705,9 +1490,11 @@ export async function retrieve({ return node && !node.archived; }); } + retrievalMeta.diffusionHits = diffusionResults.length; + } + if (enableGraphDiffusion && (enableCrossRecall || residualResult.triggered)) { + retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); } - retrievalMeta.diffusionHits = diffusionResults.length; - retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); debugLog("[ST-BME] 第3层: 混合评分"); @@ -2259,7 +2046,9 @@ export async function retrieve({ retrievalMeta.timings.total = roundMs(nowMs() - startedAt); return buildResult(graph, selectedNodeIds, schema, { - retrieval: retrievalMeta, + retrieval: { + ...retrievalMeta, + }, scopeContext: { enableScopedMemory, enablePovMemory, @@ -2295,62 +2084,6 @@ export async function retrieve({ }); } -/** - * 向量预筛选 - */ -async function vectorPreFilter( - graph, - userMessage, - activeNodes, - embeddingConfig, - topK, - signal, -) { - try { - return await findSimilarNodesByText( - graph, - userMessage, - embeddingConfig, - topK, - activeNodes, - signal, - ); - } catch (e) { - if (isAbortError(e)) { - throw e; - } - console.error("[ST-BME] 向量预筛失败:", e); - return []; - } -} - -/** - * 实体锚点提取 - * 从用户消息中提取名词/实体,匹配图中的节点名称 - */ -function extractEntityAnchors(userMessage, activeNodes) { - const anchors = []; - const seen = new Set(); - - for (const node of activeNodes) { - 
const candidates = [node.fields?.name, node.fields?.title] - .filter((value) => typeof value === "string") - .map((value) => value.trim()) - .filter((value) => value.length >= 2); - - for (const candidate of candidates) { - if (!userMessage.includes(candidate)) continue; - const key = `${node.id}:${candidate}`; - if (seen.has(key)) continue; - seen.add(key); - anchors.push({ nodeId: node.id, entity: candidate }); - break; - } - } - - return anchors; -} - function buildResidualBasisNodes( graph, exactEntityAnchors, @@ -2448,6 +2181,8 @@ async function llmRecall( ) { throwIfAborted(signal); const contextStr = recentMessages.join("\n---\n"); + const sectionedContextStr = + buildRecallSectionedTranscript(recentMessages) || contextStr; const sceneOwnerCandidateText = buildSceneOwnerCandidateText(sceneOwnerCandidates); const { candidateKeyToNodeId, @@ -2463,14 +2198,16 @@ async function llmRecall( const fieldsStr = Object.entries(node.fields) .map(([k, v]) => `${k}: ${v}`) .join(", "); - const candidateKey = `R${index + 1}`; + const candidateKey = + nodeIdToCandidateKey[String(c?.nodeId || node?.id || "").trim()] || + `R${index + 1}`; return `[${candidateKey}] 类型=${typeLabel}, 作用域=${describeMemoryScope(node.scope)}, 时间=${storyTimeLabel || "未标注"}, 时间桶=${String(c.temporalBucket || STORY_TEMPORAL_BUCKETS.UNDATED)}, 召回桶=${describeScopeBucket(c.scopeBucket)}, 认知=${String(c.knowledgeMode || "unknown")}, 可见性=${(Number(c.knowledgeVisibilityScore) || 0).toFixed(3)}, ${fieldsStr} (评分=${(c.weightedScore ?? 
c.finalScore).toFixed(3)})`; }) .join("\n"); const recallPromptBuild = await buildTaskPrompt(settings, "recall", { taskName: "recall", - recentMessages: contextStr || "(无)", + recentMessages: sectionedContextStr || "(无)", userMessage, candidateNodes: candidateDescriptions, candidateText: candidateDescriptions, @@ -2505,7 +2242,7 @@ async function llmRecall( activeStoryTimeLabel || "(未确定)", "", "## 最近对话上下文", - contextStr || "(无)", + sectionedContextStr || contextStr || "(无)", "", "## 用户最新输入", userMessage, diff --git a/retrieval/shared-ranking.js b/retrieval/shared-ranking.js new file mode 100644 index 0000000..6e1fba3 --- /dev/null +++ b/retrieval/shared-ranking.js @@ -0,0 +1,752 @@ +import { buildTemporalAdjacencyMap, getActiveNodes, getNode } from "../graph/graph.js"; +import { findSimilarNodesByText, validateVectorConfig } from "../vector/vector-index.js"; +import { hybridScore } from "./dynamics.js"; +import { diffuseAndRank } from "./diffusion.js"; +import { mergeVectorResults, splitIntentSegments } from "./retrieval-enhancer.js"; + +function nowMs() { + if (typeof performance?.now === "function") { + return performance.now(); + } + return Date.now(); +} + +function roundMs(value) { + return Math.round((Number(value) || 0) * 10) / 10; +} + +export function clampPositiveInt(value, fallback, min = 1) { + const parsed = Math.floor(Number(value)); + return Number.isFinite(parsed) && parsed >= min ? parsed : fallback; +} + +export function clampRange(value, fallback, min = 0, max = 1) { + const parsed = Number(value); + if (!Number.isFinite(parsed)) return fallback; + return Math.max(min, Math.min(max, parsed)); +} + +export function normalizeQueryText(value, maxLength = 400) { + const normalized = String(value ?? 
"") + .replace(/\r\n/g, "\n") + .replace(/\s+/g, " ") + .trim(); + if (!normalized) return ""; + return normalized.slice(0, Math.max(1, maxLength)); +} + +export function createTextPreview(text, maxLength = 120) { + const normalized = normalizeQueryText(text, maxLength + 4); + if (!normalized) return ""; + return normalized.length > maxLength + ? `${normalized.slice(0, maxLength)}...` + : normalized; +} + +function uniqueStrings(values = [], maxLength = 400) { + const result = []; + const seen = new Set(); + + for (const value of values) { + const text = normalizeQueryText(value, maxLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + result.push(text); + } + + return result; +} + +function parseContextLine(line = "") { + const raw = String(line ?? "").trim(); + if (!raw) return null; + + const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); + if (bracketMatch) { + const role = String(bracketMatch[1] || "").toLowerCase(); + const text = normalizeQueryText(bracketMatch[2] || ""); + return text ? { role, text } : null; + } + + const plainMatch = raw.match(/^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i); + if (!plainMatch) return null; + + const roleToken = String(plainMatch[1] || "").toLowerCase(); + const role = + roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" + ? "assistant" + : "user"; + const text = normalizeQueryText(plainMatch[2] || ""); + return text ? 
{ role, text } : null; +} + +export function buildContextQueryBlend( + userMessage, + recentMessages = [], + { + enabled = true, + assistantWeight = 0.2, + previousUserWeight = 0.1, + maxTextLength = 400, + } = {}, +) { + const currentText = normalizeQueryText(userMessage, maxTextLength); + const normalizedAssistantWeight = clampRange(assistantWeight, 0.2, 0, 1); + const normalizedPreviousUserWeight = clampRange( + previousUserWeight, + 0.1, + 0, + 1, + ); + const currentWeight = Math.max( + 0, + 1 - normalizedAssistantWeight - normalizedPreviousUserWeight, + ); + + let assistantText = ""; + let previousUserText = ""; + const parsedMessages = Array.isArray(recentMessages) + ? recentMessages.map((line) => parseContextLine(line)).filter(Boolean) + : []; + + for (let index = parsedMessages.length - 1; index >= 0; index -= 1) { + const item = parsedMessages[index]; + if (!assistantText && item.role === "assistant") { + assistantText = normalizeQueryText(item.text, maxTextLength); + } + if ( + !previousUserText && + item.role === "user" && + normalizeQueryText(item.text, maxTextLength).toLowerCase() !== + currentText.toLowerCase() + ) { + previousUserText = normalizeQueryText(item.text, maxTextLength); + } + if (assistantText && previousUserText) break; + } + + const rawParts = [ + { + kind: "currentUser", + label: "当前用户消息", + text: currentText, + weight: enabled ? 
currentWeight : 1, + }, + ]; + + if (enabled && assistantText) { + rawParts.push({ + kind: "assistantContext", + label: "最近 assistant 回复", + text: assistantText, + weight: normalizedAssistantWeight, + }); + } + + if (enabled && previousUserText) { + rawParts.push({ + kind: "previousUser", + label: "上一条 user 消息", + text: previousUserText, + weight: normalizedPreviousUserWeight, + }); + } + + const dedupedParts = []; + const seen = new Set(); + for (const part of rawParts) { + const text = normalizeQueryText(part.text, maxTextLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + dedupedParts.push({ + ...part, + text, + }); + } + + if (dedupedParts.length === 0) { + return { + active: false, + parts: [], + currentText: "", + assistantText: "", + previousUserText: "", + combinedText: "", + }; + } + + const totalWeight = dedupedParts.reduce( + (sum, part) => sum + Math.max(0, Number(part.weight) || 0), + 0, + ); + const normalizedParts = dedupedParts.map((part) => ({ + ...part, + weight: + totalWeight > 0 + ? Math.round( + ((Math.max(0, Number(part.weight) || 0) || 0) / totalWeight) * 1000, + ) / 1000 + : Math.round((1 / dedupedParts.length) * 1000) / 1000, + })); + const combinedText = + normalizedParts.length <= 1 + ? 
normalizedParts[0]?.text || "" + : normalizedParts + .map((part) => `${part.label}:\n${part.text}`) + .join("\n\n"); + + return { + active: enabled && normalizedParts.length > 1, + parts: normalizedParts, + currentText: currentText || normalizedParts[0]?.text || "", + assistantText, + previousUserText, + combinedText, + }; +} + +export function buildVectorQueryPlan( + blendPlan, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const plan = []; + let currentSegments = []; + + for (const part of blendPlan?.parts || []) { + let queries = [part.text]; + if (part.kind === "currentUser" && enableMultiIntent) { + currentSegments = splitIntentSegments(part.text, { maxSegments }); + queries = uniqueStrings([ + part.text, + ...currentSegments.filter((item) => item !== part.text), + ]); + } else { + queries = uniqueStrings([part.text]); + } + + plan.push({ + kind: part.kind, + label: part.label, + weight: part.weight, + queries, + }); + } + + return { + plan, + currentSegments, + }; +} + +export function buildLexicalQuerySources( + userMessage, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const currentText = normalizeQueryText(userMessage, 400); + const segments = enableMultiIntent + ? 
/**
 * Canonical form used for every lexical comparison in this section: the value
 * is passed through `normalizeQueryText` and then lower-cased.
 *
 * NOTE(review): the `600` argument is presumably a maximum length; confirm
 * against `normalizeQueryText`, which is defined outside this section.
 *
 * @param {*} [value=""] Raw text.
 * @returns {string} Normalized, lower-cased text.
 */
function normalizeLexicalText(value = "") {
  return normalizeQueryText(value, 600).toLowerCase();
}

/**
 * Splits text into de-duplicated matching units.
 *
 * Tokens are runs of `[a-z0-9]` or runs of CJK ideographs (U+4E00..U+9FFF).
 * Tokens shorter than two characters are dropped. A CJK run longer than two
 * characters additionally contributes every overlapping two-character slice,
 * so partial matches inside an unsegmented run are still found.
 *
 * @param {*} [text=""] Raw text.
 * @returns {string[]} Unique units in first-seen order.
 */
function buildLexicalUnits(text = "") {
  const normalized = normalizeLexicalText(text);
  if (!normalized) return [];

  const units = new Set();
  const tokens = normalized.match(/[a-z0-9]+|[\u4e00-\u9fff]+/g) ?? [];

  for (const token of tokens) {
    if (token.length < 2) continue;
    units.add(token);

    // The alternation above makes each token either pure ASCII or pure CJK.
    if (token.length > 2 && /[\u4e00-\u9fff]/.test(token)) {
      for (let start = 0; start + 2 <= token.length; start += 1) {
        units.add(token.slice(start, start + 2));
      }
    }
  }

  return [...units];
}

/**
 * Fraction of `sourceUnits` that also occur in `targetUnits`.
 *
 * @param {string[]} [sourceUnits=[]] Units of the query side.
 * @param {string[]} [targetUnits=[]] Units of the field side.
 * @returns {number} Value in [0, 1]; 0 when either side is empty.
 */
function computeTokenOverlapScore(sourceUnits = [], targetUnits = []) {
  if (!sourceUnits.length || !targetUnits.length) return 0;

  const targets = new Set(targetUnits);
  let shared = 0;
  for (const unit of sourceUnits) {
    if (targets.has(unit)) shared += 1;
  }
  return shared / Math.max(1, sourceUnits.length);
}

/**
 * Scores one field value against every query source and keeps the best hit.
 *
 * Per source: an exact match (after normalization) yields `exact`; a substring
 * relation in either direction (both sides at least two characters) yields
 * `includes`; otherwise/also the unit overlap ratio scaled by `overlap`.
 *
 * @param {*} fieldText Field value to score.
 * @param {Iterable<*>} [querySources=[]] Query texts; `null` is treated as empty.
 * @param {{exact?: number, includes?: number, overlap?: number}} [weights]
 * @returns {number} Best score, capped at 1.
 */
function scoreFieldMatch(
  fieldText,
  querySources = [],
  { exact = 1, includes = 0.9, overlap = 0.6 } = {},
) {
  const field = normalizeLexicalText(fieldText);
  if (!field) return 0;

  const fieldUnits = buildLexicalUnits(field);
  let best = 0;

  // `?? []` covers an explicit null, which the default parameter does not.
  for (const rawSource of querySources ?? []) {
    const source = normalizeLexicalText(rawSource);
    if (!source) continue;

    if (source === field) {
      best = Math.max(best, exact);
      continue;
    }

    const longEnough = Math.min(source.length, field.length) >= 2;
    if (longEnough && (source.includes(field) || field.includes(source))) {
      best = Math.max(best, includes);
    }

    const ratio = computeTokenOverlapScore(buildLexicalUnits(source), fieldUnits);
    best = Math.max(best, ratio * overlap);
  }

  return Math.min(1, best);
}

/**
 * Collects the trimmed, non-empty string values of the given fields.
 * A field may hold a string or an array; only string items are taken.
 *
 * @param {{fields?: Object}|null|undefined} node Graph node.
 * @param {string[]} [fieldNames=[]] Field names to read from `node.fields`.
 * @returns {string[]} Trimmed texts in field order.
 */
function collectNodeLexicalTexts(node, fieldNames = []) {
  const texts = [];
  const takeText = (candidate) => {
    if (typeof candidate !== "string") return;
    const trimmed = candidate.trim();
    if (trimmed) texts.push(trimmed);
  };

  for (const fieldName of fieldNames) {
    const raw = node?.fields?.[fieldName];
    if (Array.isArray(raw)) {
      for (const item of raw) takeText(item);
    } else {
      takeText(raw);
    }
  }
  return texts;
}

/**
 * Lexical relevance of a node for a set of query texts.
 *
 * Three partial scores are computed: the best match over the primary fields
 * (`name`, `title`), the best match over the secondary descriptive fields, and
 * a match of all collected texts joined together. The result is the largest
 * of the individually damped scores and a weighted sum of all three.
 *
 * @param {{fields?: Object}|null|undefined} node Graph node.
 * @param {string[]} [querySources=[]] Query texts.
 * @returns {number} Score in [0, 1]; 0 for a missing node or empty query.
 */
export function computeLexicalScore(node, querySources = []) {
  if (!node || !Array.isArray(querySources) || querySources.length === 0) {
    return 0;
  }

  const primaryTexts = collectNodeLexicalTexts(node, ["name", "title"]);
  const secondaryTexts = collectNodeLexicalTexts(node, [
    "summary",
    "insight",
    "state",
    "traits",
    "participants",
    "status",
  ]);
  const combinedText = [...primaryTexts, ...secondaryTexts].join(" ");

  const bestOf = (texts, weights) =>
    texts.reduce(
      (best, text) => Math.max(best, scoreFieldMatch(text, querySources, weights)),
      0,
    );

  const primaryScore = bestOf(primaryTexts, {
    exact: 1,
    includes: 0.92,
    overlap: 0.72,
  });
  const secondaryScore = bestOf(secondaryTexts, {
    exact: 0.82,
    includes: 0.68,
    overlap: 0.52,
  });
  const tokenScore = scoreFieldMatch(combinedText, querySources, {
    exact: 0.65,
    includes: 0.55,
    overlap: 0.45,
  });

  if (primaryScore <= 0 && secondaryScore <= 0 && tokenScore <= 0) {
    return 0;
  }

  return Math.min(
    1,
    Math.max(
      primaryScore,
      secondaryScore * 0.82,
      tokenScore * 0.7,
      primaryScore * 0.75 + secondaryScore * 0.35 + tokenScore * 0.2,
    ),
  );
}

/**
 * Returns copies of the vector hits with `score` multiplied by `weight`.
 * Non-numeric scores count as 0; negative or non-numeric weights count as 0.
 *
 * @param {Array<Object>} [results=[]] Vector hits; non-arrays yield `[]`.
 * @param {number} [weight=1] Multiplier.
 * @returns {Array<Object>} New array of shallow copies.
 */
export function scaleVectorResults(results = [], weight = 1) {
  if (!Array.isArray(results)) return [];

  const factor = Math.max(0, Number(weight) || 0);
  return results.map((item) => ({
    ...item,
    score: (Number(item?.score) || 0) * factor,
  }));
}

/**
 * @param {*} error Any thrown value.
 * @returns {boolean} True when the value is named `AbortError`.
 */
function isAbortError(error) {
  return error?.name === "AbortError";
}

/**
 * Best-effort vector lookup: forwards to `findSimilarNodesByText` (defined
 * outside this section). Abort errors propagate so cancellation is honoured;
 * any other failure is logged and reported as "no hits".
 *
 * @param {Object} graph Memory graph.
 * @param {string} userMessage Query text.
 * @param {Array<Object>} activeNodes Candidate nodes.
 * @param {Object} embeddingConfig Embedding configuration.
 * @param {number} topK Maximum number of hits.
 * @param {AbortSignal} [signal] Cancellation signal.
 * @returns {Promise<Array<Object>>} Hits, or `[]` after a non-abort failure.
 * @throws Rethrows errors named `AbortError`.
 */
export async function vectorPreFilter(
  graph,
  userMessage,
  activeNodes,
  embeddingConfig,
  topK,
  signal,
) {
  try {
    return await findSimilarNodesByText(
      graph,
      userMessage,
      embeddingConfig,
      topK,
      activeNodes,
      signal,
    );
  } catch (error) {
    if (isAbortError(error)) {
      throw error;
    }
    console.error("[ST-BME] 向量预筛失败:", error);
    return [];
  }
}

/**
 * Finds nodes whose `name` or `title` (trimmed, at least two characters)
 * occurs verbatim in the user message. At most one anchor is produced per
 * node visit, preferring `name` over `title`; a `nodeId:entity` pair is never
 * emitted twice.
 *
 * @param {*} userMessage Message text; falsy values are treated as empty.
 * @param {Array<Object>} activeNodes Candidate nodes; non-arrays yield `[]`.
 * @returns {Array<{nodeId: *, entity: string}>} Anchors in node order.
 */
export function extractEntityAnchors(userMessage, activeNodes) {
  // Stringify once instead of once per candidate of every node.
  const messageText = String(userMessage || "");
  if (!messageText || !Array.isArray(activeNodes)) return [];

  const anchors = [];
  const seen = new Set();

  for (const node of activeNodes) {
    const candidates = [node?.fields?.name, node?.fields?.title]
      .filter((value) => typeof value === "string")
      .map((value) => value.trim())
      .filter((value) => value.length >= 2);

    for (const candidate of candidates) {
      if (!messageText.includes(candidate)) continue;
      const key = `${node.id}:${candidate}`;
      // An already-emitted pair falls through to the next candidate.
      if (seen.has(key)) continue;
      seen.add(key);
      anchors.push({ nodeId: node.id, entity: candidate });
      break;
    }
  }

  return anchors;
}

/**
 * Diagnostic summary of the nodes with the highest positive lexical score.
 * Ties on lexical score are broken by final score. Scores are rounded to
 * three decimals.
 *
 * @param {Array<Object>} [scoredNodes=[]] Scored entries; non-arrays yield `[]`.
 * @param {number} [maxCount=5] Maximum entries; values below 1 are raised to 1.
 * @returns {Array<{nodeId: *, type: string, label: *, lexicalScore: number, finalScore: number}>}
 */
function buildLexicalTopHits(scoredNodes = [], maxCount = 5) {
  const lexicalOf = (item) => Number(item?.lexicalScore) || 0;
  const finalOf = (item) => Number(item?.finalScore) || 0;
  const roundTo3 = (value) => Math.round(value * 1000) / 1000;

  return (Array.isArray(scoredNodes) ? scoredNodes : [])
    .filter((item) => lexicalOf(item) > 0)
    .sort((a, b) => lexicalOf(b) - lexicalOf(a) || finalOf(b) - finalOf(a))
    .slice(0, Math.max(1, maxCount))
    .map((item) => ({
      nodeId: item.nodeId,
      type: item.node?.type || "",
      label:
        item.node?.fields?.name ||
        item.node?.fields?.title ||
        item.node?.fields?.summary ||
        item.nodeId,
      lexicalScore: roundTo3(lexicalOf(item)),
      finalScore: roundTo3(finalOf(item)),
    }));
}
/**
 * Ranks the active graph nodes for a task prompt by combining vector
 * similarity, graph diffusion from seed nodes and a lexical boost.
 *
 * Steps, in order:
 *  1. Resolve options and build the blended query, the vector query plan and
 *     the lexical query sources.
 *  2. Run one vector pre-filter per planned query (sequentially), scale each
 *     result group by its part weight and merge the groups.
 *  3. Seed a diffusion pass with the vector hits plus exact entity anchors
 *     (anchors carry a fixed energy of 2.0; per node the highest energy wins).
 *  4. Score every node that received a vector or diffusion score — or every
 *     active node when neither produced anything — with `hybridScore`.
 *  5. Sort descending by score, then lexical score, then node id.
 *
 * Helper functions referenced here (`clampPositiveInt`, `clampRange`,
 * `buildContextQueryBlend`, `mergeVectorResults`, `diffuseAndRank`,
 * `hybridScore`, ...) are defined outside this section.
 *
 * @param {Object} [params]
 * @param {Object} params.graph Memory graph; a falsy graph yields an empty ranking.
 * @param {string} params.userMessage Current user message.
 * @param {Array} [params.recentMessages=[]] Recent messages used for query blending.
 * @param {Object} params.embeddingConfig Embedding configuration.
 * @param {AbortSignal} [params.signal] Cancellation signal for vector lookups.
 * @param {Object} [params.options={}] Tuning switches and weights; see the
 *   option reads at the top of the function for names and defaults.
 * @returns {Promise<{activeNodes: Array, contextQueryBlend: Object, queryPlan: Object,
 *   lexicalQuery: Object, vectorResults: Array, exactEntityAnchors: Array,
 *   diffusionResults: Array, scoredNodes: Array, diagnostics: Object}>}
 * @throws Rethrows abort errors raised by the vector lookup.
 */
export async function rankNodesForTaskContext({
  graph,
  userMessage,
  recentMessages = [],
  embeddingConfig,
  signal = undefined,
  options = {},
} = {}) {
  const topK = clampPositiveInt(options.topK, 20);
  const diffusionTopK = clampPositiveInt(options.diffusionTopK, 100);
  const enableVectorPrefilter = options.enableVectorPrefilter ?? true;
  const enableGraphDiffusion = options.enableGraphDiffusion ?? true;
  const enableContextQueryBlend = options.enableContextQueryBlend ?? true;
  const enableMultiIntent = options.enableMultiIntent ?? true;
  const multiIntentMaxSegments = clampPositiveInt(
    options.multiIntentMaxSegments,
    4,
  );
  const contextAssistantWeight = clampRange(
    options.contextAssistantWeight,
    0.2,
    0,
    1,
  );
  const contextPreviousUserWeight = clampRange(
    options.contextPreviousUserWeight,
    0.1,
    0,
    1,
  );
  const enableLexicalBoost = options.enableLexicalBoost ?? true;
  const lexicalWeight = clampRange(options.lexicalWeight, 0.18, 0, 10);
  const teleportAlpha = clampRange(options.teleportAlpha, 0.15);
  const enableTemporalLinks = options.enableTemporalLinks ?? true;
  const temporalLinkStrength = clampRange(
    options.temporalLinkStrength,
    0.2,
    0,
    1,
  );
  const maxTextLength = clampPositiveInt(options.maxTextLength, 400, 32);
  const weights = options.weights ?? {};

  // Caller-supplied candidates win over the graph's own active set; archived
  // nodes are excluded in both cases.
  const activeNodes = Array.isArray(options.activeNodes)
    ? options.activeNodes.filter((node) => node && !node.archived)
    : getActiveNodes(graph).filter((node) => node && !node.archived);
  const vectorValidation = validateVectorConfig(embeddingConfig);
  const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, {
    enabled: enableContextQueryBlend,
    assistantWeight: contextAssistantWeight,
    previousUserWeight: contextPreviousUserWeight,
    maxTextLength,
  });
  const queryPlan = buildVectorQueryPlan(contextQueryBlend, {
    enableMultiIntent,
    maxSegments: multiIntentMaxSegments,
  });
  const lexicalQuery = buildLexicalQuerySources(
    contextQueryBlend.currentText || userMessage,
    {
      enableMultiIntent,
      maxSegments: multiIntentMaxSegments,
    },
  );

  const blendParts = contextQueryBlend.parts || [];
  const diagnostics = {
    queryBlendActive: contextQueryBlend.active,
    queryBlendParts: blendParts.map((part) => ({
      kind: part.kind,
      label: part.label,
      weight: part.weight,
      text: createTextPreview(part.text),
      // Diagnostics must never abort the ranking: a part without text is
      // reported with length 0 instead of raising a TypeError.
      length: part.text?.length ?? 0,
    })),
    queryBlendWeights: Object.fromEntries(
      blendParts.map((part) => [part.kind, part.weight]),
    ),
    segmentsUsed: [...(queryPlan.currentSegments || [])],
    vectorValidation,
    vectorHits: 0,
    vectorMergedHits: 0,
    seedCount: 0,
    diffusionHits: 0,
    temporalSyntheticEdgeCount: 0,
    teleportAlpha,
    lexicalBoostedNodes: 0,
    lexicalTopHits: [],
    skipReasons: [],
    timings: {
      vector: 0,
      diffusion: 0,
    },
  };

  if (!graph || activeNodes.length === 0) {
    return {
      activeNodes,
      contextQueryBlend,
      queryPlan,
      lexicalQuery,
      vectorResults: [],
      exactEntityAnchors: [],
      diffusionResults: [],
      scoredNodes: [],
      diagnostics,
    };
  }

  // --- Vector stage -------------------------------------------------------
  let vectorResults = [];
  const vectorStartedAt = nowMs();
  if (enableVectorPrefilter && vectorValidation.valid) {
    const groups = [];
    // Missing `plan` / `queries` are treated as empty, matching the guard
    // already applied to `currentSegments` above.
    for (const part of queryPlan.plan ?? []) {
      for (const queryText of part?.queries ?? []) {
        const results = await vectorPreFilter(
          graph,
          queryText,
          activeNodes,
          embeddingConfig,
          topK,
          signal,
        );
        // NOTE(review): `|| 1` turns an explicit weight of 0 into 1. Confirm
        // whether zero-weight parts can reach this point; if they can, `??`
        // is probably what is wanted.
        groups.push(scaleVectorResults(results, part.weight || 1));
      }
    }

    const merged = mergeVectorResults(groups, Math.max(topK * 2, 24));
    diagnostics.vectorHits = merged.rawHitCount;
    diagnostics.vectorMergedHits = merged.results.length;
    vectorResults = merged.results;
  } else if (enableVectorPrefilter) {
    diagnostics.skipReasons.push("vector-config-invalid");
  }
  diagnostics.timings.vector = roundMs(nowMs() - vectorStartedAt);

  const exactEntityAnchors = extractEntityAnchors(
    contextQueryBlend.currentText || userMessage,
    activeNodes,
  );

  // --- Diffusion stage ----------------------------------------------------
  let diffusionResults = [];
  const diffusionStartedAt = nowMs();
  if (enableGraphDiffusion) {
    const seeds = [
      ...vectorResults.map((item) => ({ id: item.nodeId, energy: item.score })),
      ...exactEntityAnchors.map((item) => ({ id: item.nodeId, energy: 2.0 })),
    ];
    // Keep the strongest energy per node; seeds without positive energy are
    // dropped by the comparison against the initial 0.
    const seedMap = new Map();
    for (const seed of seeds) {
      const existing = seedMap.get(seed.id) || 0;
      if (seed.energy > existing) {
        seedMap.set(seed.id, seed.energy);
      }
    }
    const uniqueSeeds = [...seedMap.entries()].map(([id, energy]) => ({
      id,
      energy,
    }));
    diagnostics.seedCount = uniqueSeeds.length;

    if (uniqueSeeds.length > 0) {
      const adjacencyMap = buildTemporalAdjacencyMap(graph, {
        includeTemporalLinks: enableTemporalLinks,
        temporalLinkStrength,
      });
      diagnostics.temporalSyntheticEdgeCount =
        Number(adjacencyMap?.syntheticEdgeCount) || 0;
      diffusionResults = diffuseAndRank(adjacencyMap, uniqueSeeds, {
        maxSteps: 2,
        decayFactor: 0.6,
        topK: diffusionTopK,
        teleportAlpha,
      }).filter((item) => {
        const node = getNode(graph, item.nodeId);
        return node && !node.archived;
      });
    }
  }
  diagnostics.diffusionHits = diffusionResults.length;
  diagnostics.timings.diffusion = roundMs(nowMs() - diffusionStartedAt);

  // --- Score assembly -----------------------------------------------------
  const scoreMap = new Map();
  for (const item of vectorResults) {
    const entry = scoreMap.get(item.nodeId) || { graphScore: 0, vectorScore: 0 };
    entry.vectorScore = item.score;
    scoreMap.set(item.nodeId, entry);
  }
  for (const item of diffusionResults) {
    const entry = scoreMap.get(item.nodeId) || { graphScore: 0, vectorScore: 0 };
    entry.graphScore = item.energy;
    scoreMap.set(item.nodeId, entry);
  }
  // Neither stage produced candidates: rank every active node on lexical
  // match and the remaining `hybridScore` inputs alone.
  if (scoreMap.size === 0) {
    for (const node of activeNodes) {
      if (!scoreMap.has(node.id)) {
        scoreMap.set(node.id, { graphScore: 0, vectorScore: 0 });
      }
    }
  }

  const scoredNodes = [];
  for (const [nodeId, scores] of scoreMap.entries()) {
    const node = getNode(graph, nodeId);
    if (!node || node.archived) continue;
    const lexicalScore = enableLexicalBoost
      ? computeLexicalScore(node, lexicalQuery.sources)
      : 0;
    const finalScore = hybridScore(
      {
        graphScore: scores.graphScore,
        vectorScore: scores.vectorScore,
        lexicalScore,
        importance: node.importance,
        createdTime: node.createdTime,
      },
      {
        ...weights,
        lexicalWeight: enableLexicalBoost ? lexicalWeight : 0,
      },
    );

    scoredNodes.push({
      nodeId,
      node,
      graphScore: scores.graphScore,
      vectorScore: scores.vectorScore,
      lexicalScore,
      finalScore,
      weightedScore: finalScore,
    });
  }

  // Deterministic order: score, then lexical score, then node id.
  scoredNodes.sort((left, right) => {
    const weightedDelta =
      (Number(right.weightedScore) || 0) - (Number(left.weightedScore) || 0);
    if (weightedDelta !== 0) return weightedDelta;
    const finalDelta =
      (Number(right.finalScore) || 0) - (Number(left.finalScore) || 0);
    if (finalDelta !== 0) return finalDelta;
    const lexicalDelta =
      (Number(right.lexicalScore) || 0) - (Number(left.lexicalScore) || 0);
    if (lexicalDelta !== 0) return lexicalDelta;
    return String(left.nodeId).localeCompare(String(right.nodeId));
  });

  diagnostics.lexicalBoostedNodes = scoredNodes.filter(
    (item) => (Number(item.lexicalScore) || 0) > 0,
  ).length;
  diagnostics.lexicalTopHits = buildLexicalTopHits(scoredNodes);

  return {
    activeNodes,
    contextQueryBlend,
    queryPlan,
    lexicalQuery,
    vectorResults,
    exactEntityAnchors,
    diffusionResults,
    scoredNodes,
    diagnostics,
  };
}
m.isContextOnly); + const target = messages.filter((m) => !m.isContextOnly); + + assert.ok( + contextOnly.length > 0, + "should have context-only messages when extractContextTurns > 0", + ); + assert.ok( + target.length > 0, + "should have extraction target messages", + ); + assert.ok( + contextOnly.every((m) => m.seq < 4), + "context-only messages should have seq < startIdx", + ); + assert.ok( + target.every((m) => m.seq >= 4), + "target messages should have seq >= startIdx", + ); + console.log(" ✓ buildExtractionMessages: isContextOnly flag marks context vs target"); +} + +{ + const messages = buildExtractionMessages(chat, 2, 6, { + extractContextTurns: 0, + }); + const contextOnly = messages.filter((m) => m.isContextOnly); + assert.equal( + contextOnly.length, + 0, + "no context-only messages when extractContextTurns=0 and startIdx=2", + ); + console.log(" ✓ buildExtractionMessages: no context-only when contextTurns=0"); +} + +{ + const messages = buildExtractionMessages(chat, 1, 6, { + extractContextTurns: 2, + }); + const contextOnly = messages.filter((m) => m.isContextOnly); + assert.equal( + contextOnly.length, + 0, + "no context-only when startIdx is already at the beginning", + ); + console.log(" ✓ buildExtractionMessages: no context-only when startIdx at beginning"); +} + +// ─── formatExtractionTranscript: section dividers ─── + +{ + const mixed = [ + { seq: 1, role: "user", content: "context user", speaker: "A", isContextOnly: true }, + { + seq: 2, + role: "assistant", + content: "context ai", + speaker: "B", + hideSpeakerLabel: true, + isContextOnly: true, + }, + { seq: 3, role: "user", content: "target user", speaker: "A", isContextOnly: false }, + { + seq: 4, + role: "assistant", + content: "target ai", + speaker: "B", + hideSpeakerLabel: true, + isContextOnly: false, + }, + ]; + const transcript = formatExtractionTranscript(mixed); + assert.match(transcript, /已提取过/, "transcript should contain context review header"); + assert.match(transcript, 
/本次需要提取/, "transcript should contain extraction target header"); + assert.ok( + transcript.indexOf("已提取过") < transcript.indexOf("本次需要提取"), + "context header should appear before target header", + ); + assert.match(transcript, /#1.*context user/, "context message should appear"); + assert.match(transcript, /#3.*target user/, "target message should appear"); + assert.match(transcript, /#2 \[assistant\]: context ai/, "assistant card name should be hidden"); + assert.doesNotMatch(transcript, /#2 \[assistant\|B\]:/, "assistant card name should not be rendered"); + console.log(" ✓ formatExtractionTranscript: section dividers for mixed context/target"); +} + +{ + const allTarget = [ + { seq: 3, role: "user", content: "user msg", speaker: "A", isContextOnly: false }, + { seq: 4, role: "assistant", content: "ai msg", speaker: "B", isContextOnly: false }, + ]; + const transcript = formatExtractionTranscript(allTarget); + assert.doesNotMatch(transcript, /已提取过/, "no context header when all are target"); + assert.doesNotMatch(transcript, /本次需要提取/, "no target header when all are target"); + console.log(" ✓ formatExtractionTranscript: no dividers when all messages are targets"); +} + +{ + const allContext = [ + { seq: 1, role: "user", content: "user msg", speaker: "A", isContextOnly: true }, + { seq: 2, role: "assistant", content: "ai msg", speaker: "B", isContextOnly: true }, + ]; + const transcript = formatExtractionTranscript(allContext); + assert.doesNotMatch(transcript, /已提取过/, "no dividers when all are context-only"); + assert.doesNotMatch(transcript, /本次需要提取/, "no dividers when all are context-only"); + console.log(" ✓ formatExtractionTranscript: no dividers when all messages are context-only"); +} + +// ─── buildExtractionInputContext: isContextOnly propagation ─── + +{ + const inputMessages = [ + { seq: 1, role: "user", content: "old question", name: "A", speaker: "A", isContextOnly: true }, + { seq: 2, role: "assistant", content: "old answer", name: "B", speaker: "B", 
isContextOnly: true }, + { seq: 3, role: "user", content: "new question", name: "A", speaker: "A", isContextOnly: false }, + { seq: 4, role: "assistant", content: "new answer", name: "B", speaker: "B", isContextOnly: false }, + ]; + const result = buildExtractionInputContext(inputMessages, { + settings: {}, + userName: "A", + charName: "B", + }); + const contextFiltered = result.filteredMessages.filter((m) => m.isContextOnly); + const targetFiltered = result.filteredMessages.filter((m) => !m.isContextOnly); + assert.equal(contextFiltered.length, 2, "context messages propagated through filtering"); + assert.equal(targetFiltered.length, 2, "target messages propagated through filtering"); + assert.equal( + result.filteredMessages.find((m) => m.seq === 2)?.hideSpeakerLabel, + true, + "active character assistant label should be hidden", + ); + assert.equal( + result.filteredMessages.find((m) => m.seq === 1)?.hideSpeakerLabel, + false, + "user label should remain visible", + ); + assert.match(result.filteredTranscript, /已提取过/, "transcript includes context header"); + assert.match(result.filteredTranscript, /本次需要提取/, "transcript includes target header"); + assert.match(result.filteredTranscript, /#2 \[assistant\]: old answer/, "assistant transcript should hide character name"); + assert.doesNotMatch(result.filteredTranscript, /#2 \[assistant\|B\]:/, "assistant transcript should not show character name"); + console.log(" ✓ buildExtractionInputContext: isContextOnly propagated to filteredMessages and transcript"); +} + +console.log("extraction-context-only-flag tests passed"); diff --git a/tests/extractor-input-context.mjs b/tests/extractor-input-context.mjs index 683e3ee..6ffd237 100644 --- a/tests/extractor-input-context.mjs +++ b/tests/extractor-input-context.mjs @@ -141,8 +141,9 @@ try { (message) => message.sourceKey === "recentMessages", ); assert.ok(recentBlock); - assert.match(String(recentBlock?.content || ""), /#10 \[assistant\|艾琳\]: 继续说明/); + 
assert.match(String(recentBlock?.content || ""), /#10 \[assistant\]: 继续说明/); assert.match(String(recentBlock?.content || ""), /#11 \[user\|玩家\]: 用户输入/); + assert.doesNotMatch(String(recentBlock?.content || ""), /#10 \[assistant\|艾琳\]:/); assert.doesNotMatch(String(recentBlock?.content || ""), /隐式思维|/); } finally { restore(); diff --git a/tests/extractor-phase3-layered-context.mjs b/tests/extractor-phase3-layered-context.mjs index f83a305..2b9709e 100644 --- a/tests/extractor-phase3-layered-context.mjs +++ b/tests/extractor-phase3-layered-context.mjs @@ -62,7 +62,7 @@ installResolveHooks([ }, ]); -const { createEmptyGraph, addNode, createNode } = await import("../graph/graph.js"); +const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import("../graph/graph.js"); const { DEFAULT_NODE_SCHEMA } = await import("../graph/schema.js"); const { extractMemories } = await import("../maintenance/extractor.js"); const { appendSummaryEntry } = await import("../graph/summary-state.js"); @@ -165,6 +165,90 @@ function collectAllPromptContent(captured) { } } +{ + const graph = createEmptyGraph(); + let captured = null; + const restore = setTestOverrides({ + llm: { + async callLLMForJSON(payload) { + captured = payload; + return { operations: [], cognitionUpdates: [], regionUpdates: {} }; + }, + }, + }); + + try { + const result = await extractMemories({ + graph, + messages: [ + { + seq: 10, + role: "user", + content: "第一轮消息", + name: "玩家", + speaker: "玩家", + isContextOnly: true, + }, + { + seq: 11, + role: "assistant", + content: "第一轮回复", + name: "艾琳", + speaker: "艾琳", + isContextOnly: true, + }, + { + seq: 12, + role: "user", + content: "第二轮消息", + name: "玩家", + speaker: "玩家", + isContextOnly: false, + }, + { + seq: 13, + role: "assistant", + content: "第二轮回复", + name: "艾琳", + speaker: "艾琳", + isContextOnly: false, + }, + ], + startSeq: 12, + endSeq: 13, + schema: DEFAULT_NODE_SCHEMA, + embeddingConfig: null, + settings: { ...defaultSettings }, + }); + + 
assert.equal(result.success, true); + assert.ok(captured); + + const recentMessages = (Array.isArray(captured.promptMessages) + ? captured.promptMessages + : [] + ).filter( + (m) => m.sourceKey === "recentMessages", + ); + assert.equal(recentMessages.length, 2, "recentMessages should split into 2 section system messages"); + assert.equal(recentMessages[0]?.role, "system"); + assert.equal(recentMessages[0]?.transcriptSection, "context"); + assert.match(String(recentMessages[0]?.content || ""), /^--- 以下是上下文回顾(已提取过),仅供理解剧情 ---/); + assert.match(String(recentMessages[0]?.content || ""), /#10 \[user\|玩家\]: 第一轮消息/); + assert.equal(recentMessages[1]?.role, "system"); + assert.equal(recentMessages[1]?.transcriptSection, "target"); + assert.match(String(recentMessages[1]?.content || ""), /^--- 以下是本次需要提取记忆的新对话内容 ---/); + assert.match(String(recentMessages[1]?.content || ""), /#12 \[user\|玩家\]: 第二轮消息/); + assert.ok( + recentMessages[0].content.includes("已提取过") && + recentMessages[1].content.includes("本次需要提取"), + "context and target sections should each be emitted as a single system message", + ); + } finally { + restore(); + } +} + // ── Test 2: extractRecentMessageCap limits messages ── { const graph = createEmptyGraph(); @@ -382,6 +466,97 @@ function collectAllPromptContent(captured) { } // ── Test 7: new settings exist in defaults ── +{ + const graph = createEmptyGraph(); + const confessionNode = addNode( + graph, + createNode({ + type: "event", + seq: 3, + importance: 8, + fields: { + title: "中文告白", + summary: "她认真地要求你再说一遍喜欢她。", + }, + }), + ); + const relationshipNode = addNode( + graph, + createNode({ + type: "thread", + seq: 4, + importance: 7, + fields: { + title: "感情升温", + summary: "两人的关系在这次告白后快速拉近。", + }, + }), + ); + addEdge( + graph, + createEdge({ + fromId: confessionNode.id, + toId: relationshipNode.id, + relation: "supports", + strength: 0.9, + }), + ); + + let captured = null; + const restore = setTestOverrides({ + llm: { + async callLLMForJSON(payload) { + 
captured = payload; + return { operations: [], cognitionUpdates: [], regionUpdates: {} }; + }, + }, + }); + + try { + const result = await extractMemories({ + graph, + messages: [ + { + seq: 10, + role: "user", + content: "中文告白之后,她还是很害羞。", + name: "玩家", + speaker: "玩家", + }, + { + seq: 11, + role: "assistant", + content: "这次中文告白让你们的感情升温了。", + name: "艾琳", + speaker: "艾琳", + }, + ], + startSeq: 10, + endSeq: 11, + schema: DEFAULT_NODE_SCHEMA, + embeddingConfig: null, + settings: { ...defaultSettings }, + }); + + assert.equal(result.success, true); + assert.ok(captured); + + const graphStatsBlock = (Array.isArray(captured.promptMessages) ? captured.promptMessages : []).find( + (m) => m.sourceKey === "graphStats", + ); + assert.ok(graphStatsBlock, "graphStats block should exist"); + const graphStatsContent = String(graphStatsBlock.content || ""); + assert.match(graphStatsContent, /### 图谱节点统计/); + assert.match(graphStatsContent, /事件: 1/); + assert.match(graphStatsContent, /主线: 1/); + assert.match(graphStatsContent, /\[G1\|事件\] 中文告白/); + assert.doesNotMatch(graphStatsContent, new RegExp(confessionNode.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"))); + } finally { + restore(); + } +} + +// ── Test 8: new settings exist in defaults ── { assert.equal(defaultSettings.extractRecentMessageCap, 0); assert.equal(defaultSettings.extractPromptStructuredMode, "both"); diff --git a/tests/extractor-phase5-context-fidelity.mjs b/tests/extractor-phase5-context-fidelity.mjs index ebb15bc..37d53bd 100644 --- a/tests/extractor-phase5-context-fidelity.mjs +++ b/tests/extractor-phase5-context-fidelity.mjs @@ -286,7 +286,7 @@ try { ).find((message) => message.sourceKey === "recentMessages"); assert.ok(recentBlock, "recentMessages block should exist"); const recentContent = String(recentBlock?.content || ""); - assert.match(recentContent, /#30 \[assistant\|艾琳\]: 艾琳说:去调查蓝钥匙。/); + assert.match(recentContent, /#30 \[assistant\]: 艾琳说:去调查蓝钥匙。/); assert.match( recentContent, /#31 \[assistant\|旁白\]: 
旁白补充:雨夜<\/status>巷子很安静。/, @@ -351,7 +351,7 @@ try { : [] ).find((message) => message.sourceKey === "recentMessages"); assert.ok(recentBlock, "recentMessages block should still exist when worldbook is disabled"); - assert.match(String(recentBlock?.content || ""), /#30 \[assistant\|艾琳\]: 艾琳说:去调查蓝钥匙。/); + assert.match(String(recentBlock?.content || ""), /#30 \[assistant\]: 艾琳说:去调查蓝钥匙。/); } finally { restore(); } diff --git a/tests/p0-regressions.mjs b/tests/p0-regressions.mjs index b68deca..ece6400 100644 --- a/tests/p0-regressions.mjs +++ b/tests/p0-regressions.mjs @@ -162,6 +162,7 @@ const { generateSynopsis, } = await import("../maintenance/extractor.js"); const { consolidateMemories } = await import("../maintenance/consolidator.js"); +const { retrieve } = await import("../retrieval/retriever.js"); const { createBatchJournalEntry, buildReverseJournalRecoveryPlan, @@ -169,6 +170,10 @@ const { rollbackBatch, } = await import("../runtime/runtime-state.js"); const { createDefaultTaskProfiles } = await import("../prompting/prompt-profiles.js"); +const { + EXTRACTION_CONTEXT_REVIEW_HEADER, + RECALL_TARGET_CONTENT_HEADER, +} = await import("../prompting/prompt-builder.js"); const extensionsApi = await import("../../../../extensions.js"); const llm = await import("../llm/llm.js"); const embedding = await import("../vector/embedding.js"); @@ -1997,14 +2002,34 @@ async function testCompressTypeAcceptsTopLevelFieldsResult() { keepRecentLeaves: 0, }, }; + const compressionSchema = [ + typeDef, + { + id: "thread", + label: "主线", + columns: [{ name: "title" }, { name: "summary" }, { name: "status" }], + }, + ]; const first = makeEvent(1, "事件甲"); const second = makeEvent(2, "事件乙"); + const relatedThread = createNode({ + type: "thread", + seq: 3, + fields: { + title: "事件甲余波", + summary: "Alice 被卷入的后续波动。", + status: "active", + }, + }); addNode(graph, first); addNode(graph, second); + addNode(graph, relatedThread); + const captured = []; const restoreOverrides = pushTestOverrides({ 
llm: { - async callLLMForJSON() { + async callLLMForJSON(params = {}) { + captured.push(params); return { title: "压缩事件", summary: "顶层返回的合并摘要", @@ -2020,8 +2045,12 @@ async function testCompressTypeAcceptsTopLevelFieldsResult() { graph, typeDef, embeddingConfig: null, + schema: compressionSchema, force: true, - settings: {}, + settings: { + taskProfilesVersion: 3, + taskProfiles: createDefaultTaskProfiles(), + }, }); assert.equal(result.created, 1); const compressed = graph.nodes.find( @@ -2029,6 +2058,21 @@ async function testCompressTypeAcceptsTopLevelFieldsResult() { ); assert.equal(compressed?.fields?.summary, "顶层返回的合并摘要"); assert.equal(compressed?.fields?.title, "压缩事件"); + assert.equal(captured.length, 1); + const graphStatsBlock = (Array.isArray(captured[0].promptMessages) + ? captured[0].promptMessages + : [] + ).find((message) => message.sourceKey === "graphStats"); + assert.ok(graphStatsBlock, "compress graphStats block should exist"); + const graphStatsContent = String(graphStatsBlock.content || ""); + assert.match(graphStatsContent, /### 图谱节点统计/); + assert.match(graphStatsContent, /事件: 2/); + assert.match(graphStatsContent, /主线: 1/); + assert.match(graphStatsContent, /\[G1\|主线\] 事件甲余波/); + assert.doesNotMatch( + graphStatsContent, + new RegExp(relatedThread.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")), + ); } finally { restoreOverrides(); } @@ -2060,17 +2104,22 @@ async function testConsolidatorMergeFallbackKeepsNodeWhenTargetMissing() { addNode(graph, target); addNode(graph, incoming); + const captured = []; const restoreOverrides = pushTestOverrides({ embedding: { async embedBatch() { return [[0.2, 0.3]]; }, + async embedText() { + return [0.2, 0.3]; + }, searchSimilar() { return [{ nodeId: target.id, score: 0.99 }]; }, }, llm: { - async callLLMForJSON() { + async callLLMForJSON(params = {}) { + captured.push(params); return { results: [ { @@ -2095,13 +2144,31 @@ async function testConsolidatorMergeFallbackKeepsNodeWhenTargetMissing() { apiUrl: 
"https://example.com/v1", model: "text-embedding-3-small", }, - settings: {}, + schema, + settings: { + taskProfilesVersion: 3, + taskProfiles: createDefaultTaskProfiles(), + }, }); assert.equal(stats.merged, 0); assert.equal(stats.kept, 1); assert.equal(incoming.archived, false); assert.deepEqual(target.embedding, [0.9, 0.1]); + assert.equal(captured.length, 1); + const graphStatsBlock = (Array.isArray(captured[0].promptMessages) + ? captured[0].promptMessages + : [] + ).find((message) => message.sourceKey === "graphStats"); + assert.ok(graphStatsBlock, "consolidation graphStats block should exist"); + const graphStatsContent = String(graphStatsBlock.content || ""); + assert.match(graphStatsContent, /### 图谱节点统计/); + assert.match(graphStatsContent, /事件: 2/); + assert.match(graphStatsContent, /\[G1\|事件\] 旧记忆/); + assert.doesNotMatch( + graphStatsContent, + new RegExp(target.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")), + ); } finally { restoreOverrides(); } @@ -2260,6 +2327,9 @@ async function testConsolidatorMergeUpdatesSeqRange() { async embedBatch() { return [[0.4, 0.5]]; }, + async embedText() { + return [0.4, 0.5]; + }, searchSimilar() { return [{ nodeId: target.id, score: 0.99 }]; }, @@ -6173,6 +6243,84 @@ async function testSynopsisUsesPromptMessagesWithoutFallbackSystemPrompt() { } } +async function testRecallUsesSectionedPromptMessagesForContextAndTarget() { + const graph = createEmptyGraph(); + addNode(graph, makeEvent(1, "仓库争执")); + addNode(graph, makeEvent(2, "走廊追问")); + + const captured = []; + const restoreOverrides = pushTestOverrides({ + llm: { + async callLLMForJSON(params = {}) { + captured.push(params); + return { + selected_keys: ["R1"], + reason: "R1: 与当前追问直接相关", + active_owner_keys: [], + active_owner_scores: [], + }; + }, + }, + }); + + try { + const result = await retrieve({ + graph, + userMessage: "她为什么突然改口?", + recentMessages: [ + "[assistant]: 她先否认自己去过仓库。", + "[user]: 我记得她当时很紧张。", + "[user]: 她为什么突然改口?", + ], + embeddingConfig: null, + 
schema, + settings: { + taskProfilesVersion: 3, + taskProfiles: createDefaultTaskProfiles(), + }, + options: { + topK: 4, + maxRecallNodes: 2, + enableLLMRecall: true, + enableVectorPrefilter: false, + enableGraphDiffusion: false, + llmCandidatePool: 2, + enableScopedMemory: false, + enablePovMemory: false, + enableRegionScopedObjective: false, + enableCognitiveMemory: false, + enableSpatialAdjacency: false, + enableStoryTimeline: false, + injectStoryTimeLabel: false, + injectUserPovMemory: false, + injectObjectiveGlobalMemory: false, + enableContextQueryBlend: true, + }, + }); + + assert.ok(Array.isArray(result?.selectedNodeIds)); + assert.equal(captured.length, 1); + const promptMessages = Array.isArray(captured[0].promptMessages) + ? captured[0].promptMessages + : []; + const recentMessageSections = promptMessages.filter( + (message) => message.sourceKey === "recentMessages", + ); + assert.equal(recentMessageSections.length, 2); + assert.equal(recentMessageSections[0].role, "system"); + assert.equal(recentMessageSections[1].role, "system"); + assert.equal(recentMessageSections[0].transcriptSection, "context"); + assert.equal(recentMessageSections[1].transcriptSection, "target"); + assert.match(recentMessageSections[0].content, new RegExp(EXTRACTION_CONTEXT_REVIEW_HEADER.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"))); + assert.match(recentMessageSections[1].content, new RegExp(RECALL_TARGET_CONTENT_HEADER.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"))); + assert.match(recentMessageSections[0].content, /她先否认自己去过仓库/); + assert.match(recentMessageSections[0].content, /我记得她当时很紧张/); + assert.match(recentMessageSections[1].content, /她为什么突然改口/); + } finally { + restoreOverrides(); + } +} + async function testReflectionUsesPromptMessagesWithoutFallbackSystemPrompt() { const graph = createEmptyGraph(); addNode( @@ -6212,17 +6360,26 @@ async function testReflectionUsesPromptMessagesWithoutFallbackSystemPrompt() { }, }), ); + const threadNode = createNode({ + type: "thread", + seq: 5, + 
fields: { + title: "信任危机", + status: "active", + }, + }); addNode( graph, - createNode({ - type: "thread", - seq: 5, - fields: { - title: "信任危机", - status: "active", - }, - }), + threadNode, ); + const reflectionSchema = [ + ...schema, + { + id: "thread", + label: "主线", + columns: [{ name: "title" }, { name: "status" }], + }, + ]; const captured = []; const restoreOverrides = pushTestOverrides({ @@ -6243,6 +6400,7 @@ async function testReflectionUsesPromptMessagesWithoutFallbackSystemPrompt() { const result = await generateReflection({ graph, currentSeq: 5, + schema: reflectionSchema, settings: { taskProfilesVersion: 3, taskProfiles: createDefaultTaskProfiles(), @@ -6254,6 +6412,21 @@ async function testReflectionUsesPromptMessagesWithoutFallbackSystemPrompt() { assert.equal(Array.isArray(captured[0].promptMessages), true); assert.ok(captured[0].promptMessages.length > 0); assert.equal(captured[0].systemPrompt, ""); + const graphStatsBlock = (Array.isArray(captured[0].promptMessages) + ? 
captured[0].promptMessages + : [] + ).find((message) => message.sourceKey === "graphStats"); + assert.ok(graphStatsBlock, "reflection graphStats block should exist"); + const graphStatsContent = String(graphStatsBlock.content || ""); + assert.match(graphStatsContent, /### 图谱节点统计/); + assert.match(graphStatsContent, /事件: 2/); + assert.match(graphStatsContent, /角色: 1/); + assert.match(graphStatsContent, /主线: 1/); + assert.match(graphStatsContent, /\[G1\|主线\] 信任危机/); + assert.doesNotMatch( + graphStatsContent, + new RegExp(threadNode.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")), + ); const reflectionNode = graph.nodes.find((node) => node.id === result); assert.equal( reflectionNode?.fields?.insight, @@ -6646,6 +6819,7 @@ await testLlmDebugSnapshotRedactsSecretsBeforeStorage(); await testEmbeddingUsesConfigTimeoutInsteadOfDefault(); await testLlmOutputRegexCleansResponseBeforeJsonParse(); await testSynopsisUsesPromptMessagesWithoutFallbackSystemPrompt(); +await testRecallUsesSectionedPromptMessagesForContextAndTarget(); await testReflectionUsesPromptMessagesWithoutFallbackSystemPrompt(); await testManualCompressSkipsWithoutCandidatesAndDoesNotPretendItRan(); await testManualCompressUsesForcedCompressionAndPersistsRealMutation(); diff --git a/tests/prompt-builder-mixed-transcript.mjs b/tests/prompt-builder-mixed-transcript.mjs index fa2bd83..eda3cd4 100644 --- a/tests/prompt-builder-mixed-transcript.mjs +++ b/tests/prompt-builder-mixed-transcript.mjs @@ -117,6 +117,8 @@ const promptBuild = await buildTaskPrompt(settings, "extract", { content: "继续说明", name: "艾琳", speaker: "艾琳", + hideSpeakerLabel: true, + isContextOnly: true, }, { seq: 42, @@ -124,6 +126,7 @@ const promptBuild = await buildTaskPrompt(settings, "extract", { content: "用户输入", name: "玩家", speaker: "玩家", + isContextOnly: false, }, ], graphStats: "node_count=1", @@ -131,17 +134,41 @@ const promptBuild = await buildTaskPrompt(settings, "extract", { currentRange: "41 ~ 42", }); const payload = 
buildTaskLlmPayload(promptBuild, "fallback-user"); -const recentBlock = payload.promptMessages.find( +const recentMessages = payload.promptMessages.filter( (message) => message.sourceKey === "recentMessages", ); -assert.match(String(recentBlock?.content || ""), /#41 \[assistant\|艾琳\]: 助手已净化/); -assert.match(String(recentBlock?.content || ""), /#42 \[user\|玩家\]: 用户已净化/); +assert.deepEqual( + recentMessages.map((message) => ({ + role: message.role, + sourceKey: message.sourceKey, + transcriptSection: message.transcriptSection, + transcriptSectionPart: message.transcriptSectionPart, + })), + [ + { + role: "system", + sourceKey: "recentMessages", + transcriptSection: "context", + transcriptSectionPart: "section", + }, + { + role: "system", + sourceKey: "recentMessages", + transcriptSection: "target", + transcriptSectionPart: "section", + }, + ], +); +assert.match(String(recentMessages[0]?.content || ""), /^--- 以下是上下文回顾(已提取过),仅供理解剧情 ---/); +assert.match(String(recentMessages[0]?.content || ""), /#41 \[assistant\]: 助手已净化/); +assert.match(String(recentMessages[1]?.content || ""), /^--- 以下是本次需要提取记忆的新对话内容 ---/); +assert.match(String(recentMessages[1]?.content || ""), /#42 \[user\|玩家\]: 用户已净化/); assert.doesNotMatch( - String(recentBlock?.content || ""), - /#41 \[assistant\|艾琳\]: 用户已净化/, + String(recentMessages[0]?.content || ""), + /#41 \[assistant\|艾琳\]:/, ); assert.doesNotMatch( - String(recentBlock?.content || ""), + String(recentMessages[1]?.content || ""), /#42 \[user\|玩家\]: 助手已净化/, ); diff --git a/tests/prompt-node-references.mjs b/tests/prompt-node-references.mjs new file mode 100644 index 0000000..d9d0483 --- /dev/null +++ b/tests/prompt-node-references.mjs @@ -0,0 +1,54 @@ +import assert from "node:assert/strict"; + +const { + createPromptNodeReferenceMap, + getPromptNodeLabel, + resolvePromptNodeId, +} = await import("../prompting/prompt-node-references.js"); + +const rawNodeId = "550e8400-e29b-41d4-a716-446655440000"; +const map = createPromptNodeReferenceMap( + [ 
+ { + nodeId: rawNodeId, + node: { + id: rawNodeId, + type: "event", + fields: { + title: "这是一个非常非常长的节点标题,用于测试提取提示里的标签截断行为", + }, + }, + score: 0.91, + }, + { + node: { + id: "node-2", + type: "thread", + fields: { + summary: "关系持续升温", + }, + }, + score: 0.77, + }, + ], + { + prefix: "G", + maxLength: 12, + buildMeta: ({ entry }) => ({ + score: entry.score, + }), + }, +); + +assert.deepEqual(Object.keys(map.keyToNodeId), ["G1", "G2"]); +assert.equal(map.keyToNodeId.G1, rawNodeId); +assert.equal(map.nodeIdToKey[rawNodeId], "G1"); +assert.equal(resolvePromptNodeId({ nodeId: rawNodeId }), rawNodeId); +assert.equal(resolvePromptNodeId({ node: { id: "node-2" } }), "node-2"); +assert.equal(getPromptNodeLabel({ id: "node-3", fields: { title: "短标题" } }), "短标题"); +assert.equal(map.keyToMeta.G1.score, 0.91); +assert.match(map.keyToMeta.G1.label, /^这是一个非常非常长的节…$/); +assert.equal(map.keyToMeta.G2.label, "关系持续升温"); +assert.equal(map.keyToMeta.G1.nodeId, rawNodeId); + +console.log("prompt-node-references tests passed"); diff --git a/tests/retrieval-config.mjs b/tests/retrieval-config.mjs index 1bcac6d..9ec80e5 100644 --- a/tests/retrieval-config.mjs +++ b/tests/retrieval-config.mjs @@ -78,6 +78,485 @@ function createGraphHelpers(graph) { }; } +function getPromptNodeLabel(node = {}, { maxLength = 32 } = {}) { + const raw = String( + node?.fields?.title || + node?.fields?.name || + node?.fields?.summary || + node?.fields?.insight || + node?.fields?.belief || + node?.id || + "—", + ) + .replace(/\s+/g, " ") + .trim(); + if (!raw) return "—"; + if (!Number.isFinite(maxLength) || maxLength < 2 || raw.length <= maxLength) { + return raw; + } + return `${raw.slice(0, Math.max(1, maxLength - 1)).trimEnd()}…`; +} + +function createPromptNodeReferenceMap(entries = [], { prefix = "N", buildMeta = null } = {}) { + const keyToNodeId = {}; + const keyToMeta = {}; + const nodeIdToKey = {}; + const references = []; + for (const [index, entry] of (Array.isArray(entries) ? 
entries : []).entries()) { + const node = entry?.node || entry || {}; + const nodeId = String(entry?.nodeId || node?.id || "").trim(); + if (!nodeId || nodeIdToKey[nodeId]) continue; + const key = `${String(prefix || "N").trim() || "N"}${references.length + 1}`; + keyToNodeId[key] = nodeId; + nodeIdToKey[nodeId] = key; + keyToMeta[key] = { + nodeId, + type: String(node?.type || ""), + label: getPromptNodeLabel(node), + ...((typeof buildMeta === "function" + ? buildMeta({ entry, node, nodeId, key, index, label: getPromptNodeLabel(node) }) + : {}) || {}), + }; + references.push({ key, nodeId, node, meta: keyToMeta[key] }); + } + return { + keyToNodeId, + keyToMeta, + nodeIdToKey, + references, + }; +} + +function normalizeQueryText(value, maxLength = 400) { + const normalized = String(value ?? "") + .replace(/\r\n/g, "\n") + .replace(/\s+/g, " ") + .trim(); + if (!normalized) return ""; + return normalized.slice(0, Math.max(1, maxLength)); +} + +function splitIntentSegments(text, { maxSegments = 4, minLength = 1 } = {}) { + const raw = String(text || "").trim(); + if (!raw) return []; + const segments = raw + .split(/[,,。.;;!!??\n]+|(?:和|顺便|另外|还有|对了|然后|而且|并且|同时)/) + .map((item) => item.trim()) + .filter((item) => item.length >= minLength); + return uniqueStrings(segments).slice(0, Math.max(1, maxSegments)); +} + +function uniqueStrings(values = [], maxLength = 400) { + const result = []; + const seen = new Set(); + for (const value of values) { + const text = normalizeQueryText(value, maxLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + result.push(text); + } + return result; +} + +function mergeVectorResults(groups, limit) { + const merged = new Map(); + let rawHitCount = 0; + for (const group of groups) { + for (const item of group) { + rawHitCount += 1; + const existing = merged.get(item.nodeId); + if (!existing || item.score > existing.score) { + merged.set(item.nodeId, item); + } + } + } + return { + 
rawHitCount, + results: [...merged.values()].slice(0, limit), + }; +} + +function parseContextLine(line = "") { + const raw = String(line ?? "").trim(); + if (!raw) return null; + const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i); + if (bracketMatch) { + const role = String(bracketMatch[1] || "").toLowerCase(); + const text = normalizeQueryText(bracketMatch[2] || ""); + return text ? { role, text } : null; + } + const plainMatch = raw.match(/^(user|assistant|用户|助手|ai)\s*[::]\s*([\s\S]*)$/i); + if (!plainMatch) return null; + const roleToken = String(plainMatch[1] || "").toLowerCase(); + const role = + roleToken === "assistant" || roleToken === "助手" || roleToken === "ai" + ? "assistant" + : "user"; + const text = normalizeQueryText(plainMatch[2] || ""); + return text ? { role, text } : null; +} + +function buildContextQueryBlend( + userMessage, + recentMessages = [], + { + enabled = true, + assistantWeight = 0.2, + previousUserWeight = 0.1, + maxTextLength = 400, + } = {}, +) { + const currentText = normalizeQueryText(userMessage, maxTextLength); + let assistantText = ""; + let previousUserText = ""; + const parsedMessages = Array.isArray(recentMessages) + ? 
recentMessages.map((line) => parseContextLine(line)).filter(Boolean) + : []; + + for (let index = parsedMessages.length - 1; index >= 0; index -= 1) { + const item = parsedMessages[index]; + if (!assistantText && item.role === "assistant") { + assistantText = normalizeQueryText(item.text, maxTextLength); + } + if ( + !previousUserText && + item.role === "user" && + normalizeQueryText(item.text, maxTextLength).toLowerCase() !== + currentText.toLowerCase() + ) { + previousUserText = normalizeQueryText(item.text, maxTextLength); + } + if (assistantText && previousUserText) break; + } + + const currentWeight = Math.max( + 0, + 1 - Number(assistantWeight || 0) - Number(previousUserWeight || 0), + ); + const rawParts = [ + { + kind: "currentUser", + label: "当前用户消息", + text: currentText, + weight: enabled ? currentWeight : 1, + }, + ]; + if (enabled && assistantText) { + rawParts.push({ + kind: "assistantContext", + label: "最近 assistant 回复", + text: assistantText, + weight: Number(assistantWeight || 0), + }); + } + if (enabled && previousUserText) { + rawParts.push({ + kind: "previousUser", + label: "上一条 user 消息", + text: previousUserText, + weight: Number(previousUserWeight || 0), + }); + } + + const dedupedParts = []; + const seen = new Set(); + for (const part of rawParts) { + const text = normalizeQueryText(part.text, maxTextLength); + const key = text.toLowerCase(); + if (!text || seen.has(key)) continue; + seen.add(key); + dedupedParts.push({ ...part, text }); + } + + const totalWeight = dedupedParts.reduce( + (sum, part) => sum + Math.max(0, Number(part.weight) || 0), + 0, + ); + const parts = dedupedParts.map((part) => ({ + ...part, + weight: + totalWeight > 0 + ? 
Math.round((Math.max(0, Number(part.weight) || 0) / totalWeight) * 1000) / + 1000 + : Math.round((1 / Math.max(1, dedupedParts.length)) * 1000) / 1000, + })); + + return { + active: enabled && parts.length > 1, + parts, + currentText: currentText || parts[0]?.text || "", + assistantText, + previousUserText, + combinedText: + parts.length <= 1 + ? parts[0]?.text || "" + : parts.map((part) => `${part.label}:\n${part.text}`).join("\n\n"), + }; +} + +function buildVectorQueryPlan( + blendPlan, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const plan = []; + let currentSegments = []; + for (const part of blendPlan?.parts || []) { + let queries = [part.text]; + if (part.kind === "currentUser" && enableMultiIntent) { + currentSegments = splitIntentSegments(part.text, { maxSegments }); + queries = uniqueStrings([ + part.text, + ...currentSegments.filter((item) => item !== part.text), + ]); + } else { + queries = uniqueStrings([part.text]); + } + plan.push({ + kind: part.kind, + label: part.label, + weight: part.weight, + queries, + }); + } + return { + plan, + currentSegments, + }; +} + +function buildLexicalQuerySources( + userMessage, + { enableMultiIntent = true, maxSegments = 4 } = {}, +) { + const currentText = normalizeQueryText(userMessage, 400); + const segments = enableMultiIntent + ? 
splitIntentSegments(currentText, { maxSegments }) + : []; + return { + sources: uniqueStrings([currentText, ...segments]), + segments, + }; +} + +function computeLexicalScoreForShared(node, querySources = []) { + const haystack = String( + node?.fields?.name || node?.fields?.title || node?.fields?.summary || "", + ).toLowerCase(); + if (!haystack) return 0; + for (const sourceText of querySources) { + const normalizedSource = String(sourceText || "").toLowerCase(); + if (normalizedSource && haystack.includes(normalizedSource.split(/\s+/)[0])) { + return 1; + } + } + return 0; +} + +function extractEntityAnchors(userMessage, activeNodes = []) { + const anchors = []; + const seen = new Set(); + for (const node of activeNodes) { + const candidates = [node?.fields?.name, node?.fields?.title] + .filter((value) => typeof value === "string") + .map((value) => value.trim()) + .filter((value) => value.length >= 2); + for (const candidate of candidates) { + if (!String(userMessage || "").includes(candidate)) continue; + const key = `${node.id}:${candidate}`; + if (seen.has(key)) continue; + seen.add(key); + anchors.push({ nodeId: node.id, entity: candidate }); + break; + } + } + return anchors; +} + +async function rankNodesForTaskContext({ + graph, + userMessage, + recentMessages = [], + embeddingConfig, + options = {}, +} = {}) { + const activeNodes = Array.isArray(options.activeNodes) + ? options.activeNodes.filter((node) => node && !node.archived) + : (graph?.nodes || []).filter((node) => node && !node.archived); + const topK = Math.max(1, Math.floor(Number(options.topK) || 20)); + const diffusionTopK = Math.max(1, Math.floor(Number(options.diffusionTopK) || 100)); + const enableVectorPrefilter = options.enableVectorPrefilter ?? true; + const enableGraphDiffusion = options.enableGraphDiffusion ?? true; + const enableContextQueryBlend = options.enableContextQueryBlend ?? true; + const enableMultiIntent = options.enableMultiIntent ?? 
true; + const multiIntentMaxSegments = Math.max( + 1, + Math.floor(Number(options.multiIntentMaxSegments) || 4), + ); + const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, { + enabled: enableContextQueryBlend, + assistantWeight: Number(options.contextAssistantWeight ?? 0.2), + previousUserWeight: Number(options.contextPreviousUserWeight ?? 0.1), + maxTextLength: Number(options.maxTextLength || 400), + }); + const queryPlan = buildVectorQueryPlan(contextQueryBlend, { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }); + const lexicalQuery = buildLexicalQuerySources( + contextQueryBlend.currentText || userMessage, + { + enableMultiIntent, + maxSegments: multiIntentMaxSegments, + }, + ); + const diagnostics = { + queryBlendActive: contextQueryBlend.active, + queryBlendParts: (contextQueryBlend.parts || []).map((part) => ({ + kind: part.kind, + label: part.label, + weight: part.weight, + text: part.text, + length: part.text.length, + })), + queryBlendWeights: Object.fromEntries( + (contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]), + ), + segmentsUsed: [...(queryPlan.currentSegments || [])], + vectorValidation: { valid: true }, + vectorHits: 0, + vectorMergedHits: 0, + seedCount: 0, + diffusionHits: 0, + temporalSyntheticEdgeCount: 0, + teleportAlpha: Number(options.teleportAlpha ?? 
0.15) || 0.15, + lexicalBoostedNodes: 0, + lexicalTopHits: [], + skipReasons: [], + timings: { vector: 0, diffusion: 0 }, + }; + + let vectorResults = []; + if (enableVectorPrefilter) { + const groups = []; + for (const part of queryPlan.plan) { + for (const queryText of part.queries) { + state.vectorCalls.push({ topK, message: queryText }); + const results = [ + { nodeId: "rule-1", score: 0.9 }, + { nodeId: "rule-2", score: 0.8 }, + { nodeId: "rule-3", score: 0.7 }, + ].map((item) => ({ + ...item, + score: item.score * Math.max(0, Number(part.weight) || 0), + })); + groups.push(results); + } + } + const merged = mergeVectorResults(groups, Math.max(topK * 2, 24)); + diagnostics.vectorHits = merged.rawHitCount; + diagnostics.vectorMergedHits = merged.results.length; + vectorResults = merged.results; + } + + const exactEntityAnchors = extractEntityAnchors( + contextQueryBlend.currentText || userMessage, + activeNodes, + ); + let diffusionResults = []; + if (enableGraphDiffusion) { + const seedMap = new Map(); + for (const item of vectorResults) { + seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, item.score)); + } + for (const item of exactEntityAnchors) { + seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, 2.0)); + } + const uniqueSeeds = [...seedMap.entries()].map(([id, energy]) => ({ id, energy })); + diagnostics.seedCount = uniqueSeeds.length; + if (uniqueSeeds.length > 0) { + state.diffusionCalls.push({ + seeds: uniqueSeeds, + options: { + maxSteps: 2, + decayFactor: 0.6, + topK: diffusionTopK, + teleportAlpha: diagnostics.teleportAlpha, + }, + }); + diffusionResults = [ + { nodeId: "rule-2", energy: 1.2 }, + { nodeId: "rule-3", energy: 0.9 }, + ]; + } + } + diagnostics.diffusionHits = diffusionResults.length; + + const scoreMap = new Map(); + for (const item of vectorResults) { + scoreMap.set(item.nodeId, { + graphScore: scoreMap.get(item.nodeId)?.graphScore || 0, + vectorScore: item.score, + }); + } + for (const item of 
diffusionResults) { + scoreMap.set(item.nodeId, { + graphScore: item.energy, + vectorScore: scoreMap.get(item.nodeId)?.vectorScore || 0, + }); + } + if (scoreMap.size === 0) { + for (const node of activeNodes) { + scoreMap.set(node.id, { graphScore: 0, vectorScore: 0 }); + } + } + const scoredNodes = [...scoreMap.entries()].map(([nodeId, scores]) => { + const node = activeNodes.find((item) => item.id === nodeId) || null; + const lexicalScore = computeLexicalScoreForShared(node, lexicalQuery.sources); + return { + nodeId, + node, + graphScore: scores.graphScore, + vectorScore: scores.vectorScore, + lexicalScore, + finalScore: + Number(scores.graphScore || 0) + + Number(scores.vectorScore || 0) + + Number(lexicalScore || 0) + + Number(node?.importance || 0), + weightedScore: + Number(scores.graphScore || 0) + + Number(scores.vectorScore || 0) + + Number(lexicalScore || 0) + + Number(node?.importance || 0), + }; + }); + diagnostics.lexicalBoostedNodes = scoredNodes.filter( + (item) => (Number(item.lexicalScore) || 0) > 0, + ).length; + diagnostics.lexicalTopHits = scoredNodes + .filter((item) => (Number(item.lexicalScore) || 0) > 0) + .slice(0, 5) + .map((item) => ({ + nodeId: item.nodeId, + label: item.node?.fields?.name || item.node?.fields?.title || item.nodeId, + lexicalScore: item.lexicalScore, + finalScore: item.finalScore, + })); + + return { + activeNodes, + contextQueryBlend, + queryPlan, + lexicalQuery, + vectorResults, + exactEntityAnchors, + diffusionResults, + scoredNodes, + diagnostics, + }; +} + const schema = [{ id: "rule", label: "规则", alwaysInject: false }]; const state = { @@ -93,6 +572,9 @@ const graph = createGraph(); const helpers = createGraphHelpers(graph); const retrieve = await loadRetrieve({ ...helpers, + createPromptNodeReferenceMap, + getPromptNodeLabel, + rankNodesForTaskContext, STORY_TEMPORAL_BUCKETS: { CURRENT: "current", ADJACENT_PAST: "adjacentPast", @@ -290,6 +772,10 @@ const retrieve = await loadRetrieve({ 
describeScopeBucket(bucket = "") { return String(bucket || ""); }, + EXTRACTION_CONTEXT_REVIEW_HEADER: + "--- 以下是上下文回顾(已提取过),仅供理解剧情 ---", + RECALL_TARGET_CONTENT_HEADER: + "--- 以下是本次需要召回记忆的新对话内容 ---", buildTaskPrompt() { return { systemPrompt: "" }; }, diff --git a/tests/shared-ranking.mjs b/tests/shared-ranking.mjs new file mode 100644 index 0000000..3781175 --- /dev/null +++ b/tests/shared-ranking.mjs @@ -0,0 +1,175 @@ +import assert from "node:assert/strict"; +import { + installResolveHooks, + toDataModuleUrl, +} from "./helpers/register-hooks-compat.mjs"; + +const extensionsShimSource = [ + "export const extension_settings = {};", + "export function getContext() {", + " return {", + " chat: [],", + " chatMetadata: {},", + " extensionSettings: {},", + " powerUserSettings: {},", + " characters: {},", + " characterId: null,", + " name1: '玩家',", + " name2: '艾琳',", + " chatId: 'test-chat',", + " };", + "}", +].join("\n"); + +const scriptShimSource = [ + "export function getRequestHeaders() {", + " return {};", + "}", + "export function substituteParamsExtended(value) {", + " return String(value ?? 
'');", + "}", +].join("\n"); + +installResolveHooks([ + { + specifiers: [ + "../../../extensions.js", + "../../../../extensions.js", + "../../../../../extensions.js", + ], + url: toDataModuleUrl(extensionsShimSource), + }, + { + specifiers: [ + "../../../../script.js", + "../../../../../script.js", + ], + url: toDataModuleUrl(scriptShimSource), + }, +]); + +const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import( + "../graph/graph.js" +); +const { rankNodesForTaskContext } = await import("../retrieval/shared-ranking.js"); + +function setTestOverrides(overrides = {}) { + globalThis.__stBmeTestOverrides = overrides; + return () => { + delete globalThis.__stBmeTestOverrides; + }; +} + +const graph = createEmptyGraph(); +const confession = addNode( + graph, + createNode({ + type: "event", + seq: 10, + importance: 8, + fields: { + title: "中文告白", + summary: "她认真地说喜欢你,并要求你再说一遍。", + }, + }), +); +const dateEvent = addNode( + graph, + createNode({ + type: "event", + seq: 11, + importance: 4, + fields: { + title: "节日约会", + summary: "她们一起逛街吃饭。", + }, + }), +); +const relationship = addNode( + graph, + createNode({ + type: "thread", + seq: 12, + importance: 7, + fields: { + title: "感情升温", + summary: "两人的恋爱关系快速升温。", + }, + }), +); +confession.embedding = [1, 0.3, 0.1]; +dateEvent.embedding = [0.2, 0.9, 0.1]; +relationship.embedding = [0.8, 0.6, 0.2]; +addEdge( + graph, + createEdge({ + fromId: confession.id, + toId: relationship.id, + relation: "supports", + strength: 0.9, + }), +); + +const graphBefore = JSON.stringify(graph); +const restore = setTestOverrides({ + embedding: { + async embedText() { + return [1, 0.5, 0.25]; + }, + searchSimilar(_queryVec, candidates) { + assert.ok(candidates.some((item) => item.nodeId === confession.id)); + return [ + { nodeId: confession.id, score: 0.97 }, + { nodeId: dateEvent.id, score: 0.23 }, + ]; + }, + }, +}); + +try { + const config = { + mode: "direct", + source: "direct", + apiUrl: "https://example.com/v1", 
+ apiKey: "", + model: "test-embedding", + }; + const first = await rankNodesForTaskContext({ + graph, + userMessage: "[user]: 中文告白后的关系进展", + embeddingConfig: config, + options: { + enableContextQueryBlend: false, + topK: 8, + diffusionTopK: 16, + }, + }); + const second = await rankNodesForTaskContext({ + graph, + userMessage: "[user]: 中文告白后的关系进展", + embeddingConfig: config, + options: { + enableContextQueryBlend: false, + topK: 8, + diffusionTopK: 16, + }, + }); + + assert.equal(JSON.stringify(graph), graphBefore, "shared ranking should be side-effect-free"); + assert.equal(first.scoredNodes[0]?.nodeId, confession.id); + assert.equal(second.scoredNodes[0]?.nodeId, confession.id); + assert.deepEqual( + first.scoredNodes.map((item) => item.nodeId), + second.scoredNodes.map((item) => item.nodeId), + "ranking order should stay deterministic under fixed inputs", + ); + const propagated = first.scoredNodes.find((item) => item.nodeId === relationship.id); + assert.ok(propagated, "diffusion should surface connected relationship node"); + assert.ok((Number(propagated?.graphScore) || 0) > 0, "connected node should receive graph diffusion score"); + assert.equal(first.diagnostics.vectorMergedHits, 2); + assert.ok(first.diagnostics.diffusionHits >= 1); +} finally { + restore(); +} + +console.log("shared-ranking tests passed"); diff --git a/ui/panel.js b/ui/panel.js index c2ef630..1f11e19 100644 --- a/ui/panel.js +++ b/ui/panel.js @@ -4077,7 +4077,11 @@ function _bindActions() { const btn = document.getElementById("bme-act-extract"); if (btn?.disabled) return; const mode = - String(document.getElementById("bme-extract-mode")?.value || "pending") + String( + document.getElementById("bme-extract-mode")?.value || + (_getSettings?.() || {}).extractActionMode || + "pending", + ) .trim() .toLowerCase() === "rerun" ? 
"rerun" @@ -4575,6 +4579,10 @@ function _refreshConfigTab() { "bme-setting-wi-filter-keywords", settings.worldInfoFilterCustomKeywords || "", ); + _setInputValue( + "bme-extract-mode", + settings.extractActionMode || "pending", + ); const wiFilterCustomSection = panelEl?.querySelector( "#bme-wi-filter-custom-section", ); @@ -4995,6 +5003,19 @@ function _bindConfigControls() { }); noticeDisplayModeEl.dataset.bmeBound = "true"; } + const extractModeEl = document.getElementById("bme-extract-mode"); + if (extractModeEl && extractModeEl.dataset.bmeBound !== "true") { + extractModeEl.addEventListener("change", () => { + _patchSettings({ + extractActionMode: + String(extractModeEl.value || "pending").trim().toLowerCase() === + "rerun" + ? "rerun" + : "pending", + }); + }); + extractModeEl.dataset.bmeBound = "true"; + } const cloudStorageModeEl = document.getElementById( "bme-setting-cloud-storage-mode", );