From 84211d9b9d2d1531a809f011959f9b5d8c35757b Mon Sep 17 00:00:00 2001 From: Youzini-afk <13153778771cx@gmail.com> Date: Fri, 27 Mar 2026 19:43:40 +0800 Subject: [PATCH] feat: enhance recall pipeline retrieval stack --- diffusion.js | 38 +- graph.js | 54 ++- index.js | 47 +++ panel.html | 234 +++++++++++ panel.js | 143 +++++++ retrieval-enhancer.js | 795 +++++++++++++++++++++++++++++++++++ retriever.js | 458 +++++++++++++++++--- tests/default-settings.mjs | 17 + tests/graph-retrieval.mjs | 24 +- tests/retrieval-config.mjs | 75 +++- tests/retrieval-enhancer.mjs | 154 +++++++ 11 files changed, 1943 insertions(+), 96 deletions(-) create mode 100644 retrieval-enhancer.js create mode 100644 tests/retrieval-enhancer.mjs diff --git a/diffusion.js b/diffusion.js index 78dfa6e..be5bde0 100644 --- a/diffusion.js +++ b/diffusion.js @@ -38,6 +38,8 @@ const DEFAULT_OPTIONS = { minEnergy: 0.01, // 最小有效能量(低于此值视为不活跃) maxEnergy: 2.0, // 能量上限 minEnergy_clamp: -2.0, // 能量下限(抑制) + teleportAlpha: 0.0, // PPR 回拉概率 + inhibitMultiplier: 2.0, // 抑制边负向传播倍率 }; /** @@ -59,16 +61,21 @@ const DEFAULT_OPTIONS = { */ export function propagateActivation(adjacencyMap, seedNodes, options = {}) { const opts = { ...DEFAULT_OPTIONS, ...options }; + const teleportAlpha = clamp01(opts.teleportAlpha); /** @type {Map} */ let currentEnergy = new Map(); + /** @type {Map} */ + const initialEnergy = new Map(); for (const seed of seedNodes || []) { if (!seed?.id) continue; const clamped = clampEnergy(Number(seed.energy) || 0, opts); if (Math.abs(clamped) >= opts.minEnergy) { const existing = currentEnergy.get(seed.id) || 0; - currentEnergy.set(seed.id, clampEnergy(existing + clamped, opts)); + const next = clampEnergy(existing + clamped, opts); + currentEnergy.set(seed.id, next); + initialEnergy.set(seed.id, next); } } @@ -89,11 +96,18 @@ export function propagateActivation(adjacencyMap, seedNodes, options = {}) { for (const neighbor of neighbors) { if (!neighbor?.targetId) continue; let propagated = - energy * (Number(neighbor.strength) || 0) * opts.decayFactor; + energy * + (Number(neighbor.strength) || 0) * + opts.decayFactor * + (1 - teleportAlpha); // 抑制边:传递负能量 if (neighbor.edgeType === INHIBIT_EDGE_TYPE) { - propagated = -Math.abs(propagated); + propagated = + -Math.abs(energy) * + (Number(neighbor.strength) || 0) * + opts.decayFactor * + (Number(opts.inhibitMultiplier) || 1); } // 累加到邻居节点 @@ -112,6 +126,20 @@ export function propagateActivation(adjacencyMap, seedNodes, options = {}) { } } + if (teleportAlpha > 0) { + for (const [nodeId, seedEnergy] of initialEnergy) { + const current = nextEnergy.get(nodeId) || 0; + const teleported = + (1 - teleportAlpha) * current + teleportAlpha * seedEnergy; + const clamped = clampEnergy(teleported, opts); + if (Math.abs(clamped) >= opts.minEnergy) { + nextEnergy.set(nodeId, clamped); + } else { + nextEnergy.delete(nodeId); + } + } + } + // 动态剪枝:只保留 Top-K if (nextEnergy.size > opts.topK) { const sorted = [...nextEnergy.entries()].sort( @@ -152,6 +180,10 @@ function clampEnergy(energy, opts) { return Math.max(opts.minEnergy_clamp, Math.min(opts.maxEnergy, energy)); } +function clamp01(value) { + return Math.max(0, Math.min(1, Number(value) || 0)); +} + /** * 快捷方法:从种子列表创建扩散并返回按能量排序的结果 * diff --git a/graph.js b/graph.js index 2d4a462..8739088 100644 --- a/graph.js +++ b/graph.js @@ -372,11 +372,17 @@ export function buildAdjacencyMap(graph) { * @param {GraphState} graph * @returns {Map} */ -export function buildTemporalAdjacencyMap(graph) { +export function buildTemporalAdjacencyMap(graph, options = {}) { const adj = new Map(); + adj.syntheticEdgeCount = 0; const activeNodeIds = new Set( graph.nodes.filter((node) => !node.archived).map((node) => node.id), ); + const includeTemporalLinks = options.includeTemporalLinks !== false; + const temporalLinkStrength = Math.max( + 0, + Math.min(1, Number(options.temporalLinkStrength) || 0.2), + ); for (const edge of graph.edges) { if (!isEdgeActive(edge)) continue; @@ -384,24 +390,46 @@ export function buildTemporalAdjacencyMap(graph) { continue; } - if (!adj.has(edge.fromId)) adj.set(edge.fromId, []); - adj.get(edge.fromId).push({ - targetId: edge.toId, - strength: edge.strength, - edgeType: edge.edgeType, - }); + addAdjacencyPair(adj, edge.fromId, edge.toId, edge.strength, edge.edgeType); + } - if (!adj.has(edge.toId)) adj.set(edge.toId, []); - adj.get(edge.toId).push({ - targetId: edge.fromId, - strength: edge.strength, - edgeType: edge.edgeType, - }); + if (includeTemporalLinks && temporalLinkStrength > 0) { + const activeNodes = graph.nodes.filter( + (node) => !node.archived && activeNodeIds.has(node.id), + ); + const seenPairs = new Set(); + + for (const node of activeNodes) { + for (const neighborId of [node.prevId, node.nextId]) { + if (!neighborId || !activeNodeIds.has(neighborId)) continue; + const key = [node.id, neighborId].sort().join("::"); + if (seenPairs.has(key)) continue; + seenPairs.add(key); + addAdjacencyPair(adj, node.id, neighborId, temporalLinkStrength, 0); + adj.syntheticEdgeCount += 1; + } + } } return adj; } +function addAdjacencyPair(adj, fromId, toId, strength, edgeType) { + if (!adj.has(fromId)) adj.set(fromId, []); + adj.get(fromId).push({ + targetId: toId, + strength, + edgeType, + }); + + if (!adj.has(toId)) adj.set(toId, []); + adj.get(toId).push({ + targetId: fromId, + strength, + edgeType, + }); +} + function isEdgeActive(edge, now = Date.now()) { if (!edge) return false; if (edge.invalidAt && edge.invalidAt <= now) return false; diff --git a/index.js b/index.js index b127a8c..08c93db 100644 --- a/index.js +++ b/index.js @@ -173,6 +173,23 @@ const defaultSettings = { recallDiffusionTopK: 100, // 图扩散阶段保留的候选上限 recallLlmCandidatePool: 30, // 传给 LLM 精排的候选池大小 recallLlmContextMessages: 4, // 传给 LLM 精排的最近非系统消息数 + recallEnableMultiIntent: true, + recallMultiIntentMaxSegments: 4, + recallTeleportAlpha: 0.15, + recallEnableTemporalLinks: true, + recallTemporalLinkStrength: 0.2, + recallEnableDiversitySampling: true, + recallDppCandidateMultiplier: 3, + recallDppQualityWeight: 1.0, + recallEnableCooccurrenceBoost: false, + recallCooccurrenceScale: 0.1, + recallCooccurrenceMaxNeighbors: 10, + recallEnableResidualRecall: false, + recallResidualBasisMaxNodes: 24, + recallNmfTopics: 15, + recallNmfNoveltyThreshold: 0.4, + recallResidualThreshold: 0.3, + recallResidualTopK: 5, // 注入设置 injectPosition: "atDepth", // 注入位置 @@ -3637,7 +3654,13 @@ function applyRecallInjection(settings, recallInput, recentMessages, result) { recallInput.sourceLabel, `ctx ${recentMessages.length}`, `vector ${retrievalMeta.vectorHits ?? 0}`, + retrievalMeta.vectorMergedHits + ? `merged ${retrievalMeta.vectorMergedHits}` + : "", `diffusion ${retrievalMeta.diffusionHits ?? 0}`, + retrievalMeta.candidatePoolAfterDpp + ? `dpp ${retrievalMeta.candidatePoolAfterDpp}` + : "", `llm pool ${llmMeta.candidatePool ?? 0}`, `recall ${result.stats.recallCount}`, ] @@ -3782,6 +3805,30 @@ async function runRecall(options = {}) { enableCrossRecall: settings.enableCrossRecall ?? false, enableProbRecall: settings.enableProbRecall ?? false, probRecallChance: settings.probRecallChance ?? 0.15, + enableMultiIntent: settings.recallEnableMultiIntent ?? true, + multiIntentMaxSegments: settings.recallMultiIntentMaxSegments ?? 4, + teleportAlpha: settings.recallTeleportAlpha ?? 0.15, + enableTemporalLinks: settings.recallEnableTemporalLinks ?? true, + temporalLinkStrength: settings.recallTemporalLinkStrength ?? 0.2, + enableDiversitySampling: + settings.recallEnableDiversitySampling ?? true, + dppCandidateMultiplier: + settings.recallDppCandidateMultiplier ?? 3, + dppQualityWeight: settings.recallDppQualityWeight ?? 1.0, + enableCooccurrenceBoost: + settings.recallEnableCooccurrenceBoost ?? false, + cooccurrenceScale: settings.recallCooccurrenceScale ?? 0.1, + cooccurrenceMaxNeighbors: + settings.recallCooccurrenceMaxNeighbors ?? 10, + enableResidualRecall: + settings.recallEnableResidualRecall ?? false, + residualBasisMaxNodes: + settings.recallResidualBasisMaxNodes ?? 24, + residualNmfTopics: settings.recallNmfTopics ?? 15, + residualNmfNoveltyThreshold: + settings.recallNmfNoveltyThreshold ?? 0.4, + residualThreshold: settings.recallResidualThreshold ?? 0.3, + residualTopK: settings.recallResidualTopK ?? 5, }, }); diff --git a/panel.html b/panel.html index f6420b6..a00af7f 100644 --- a/panel.html +++ b/panel.html @@ -1095,6 +1095,151 @@ +
+
+
+
召回增强
+
+ 调整种子构建、扩散回拉、多样性去重和共现补强。 +
+
+
+ 在“功能开关”中启用后生效。 +
+
+ +
+ + +
+
+ + +
+ +
+ + +
+ +
+ + +
+
+ + +
+ +
+ + +
+
+ + +
+
+
+
+
+
+
弱信号召回
+
+ 仅在直连 embedding 且本地有足够向量时使用,用于补抓被主主题压住的弱线索。 +
+
+
+ 在“功能开关”中启用后生效。 +
+
+ +
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
{ + _patchSettings({ recallEnableMultiIntent: checked }); + }); + bindCheckbox("bme-setting-recall-temporal-links-enabled", (checked) => { + _patchSettings({ recallEnableTemporalLinks: checked }); + }); + bindCheckbox("bme-setting-recall-diversity-enabled", (checked) => { + _patchSettings({ recallEnableDiversitySampling: checked }); + }); + bindCheckbox("bme-setting-recall-cooccurrence-enabled", (checked) => { + _patchSettings({ recallEnableCooccurrenceBoost: checked }); + }); + bindCheckbox("bme-setting-recall-residual-enabled", (checked) => { + _patchSettings({ recallEnableResidualRecall: checked }); + }); bindCheckbox("bme-setting-consolidation-enabled", (checked) => { _patchSettings({ enableConsolidation: checked }); _refreshGuardedConfigStates(); @@ -1395,6 +1478,66 @@ function _bindConfigControls() { bindNumber("bme-setting-recall-llm-context-messages", 4, 0, 20, (value) => _patchSettings({ recallLlmContextMessages: value }), ); + bindNumber( + "bme-setting-recall-multi-intent-max-segments", + 4, + 1, + 8, + (value) => _patchSettings({ recallMultiIntentMaxSegments: value }), + ); + bindFloat("bme-setting-recall-teleport-alpha", 0.15, 0, 1, (value) => + _patchSettings({ recallTeleportAlpha: value }), + ); + bindFloat( + "bme-setting-recall-temporal-link-strength", + 0.2, + 0, + 1, + (value) => _patchSettings({ recallTemporalLinkStrength: value }), + ); + bindNumber( + "bme-setting-recall-dpp-candidate-multiplier", + 3, + 1, + 10, + (value) => _patchSettings({ recallDppCandidateMultiplier: value }), + ); + bindFloat("bme-setting-recall-dpp-quality-weight", 1.0, 0, 10, (value) => + _patchSettings({ recallDppQualityWeight: value }), + ); + bindFloat("bme-setting-recall-cooccurrence-scale", 0.1, 0, 10, (value) => + _patchSettings({ recallCooccurrenceScale: value }), + ); + bindNumber( + "bme-setting-recall-cooccurrence-max-neighbors", + 10, + 1, + 50, + (value) => _patchSettings({ recallCooccurrenceMaxNeighbors: value }), + ); + bindNumber( + "bme-setting-recall-residual-basis-max-nodes", + 24, + 2, + 64, + (value) => _patchSettings({ recallResidualBasisMaxNodes: value }), + ); + bindNumber("bme-setting-recall-nmf-topics", 15, 2, 64, (value) => + _patchSettings({ recallNmfTopics: value }), + ); + bindFloat( + "bme-setting-recall-nmf-novelty-threshold", + 0.4, + 0, + 1, + (value) => _patchSettings({ recallNmfNoveltyThreshold: value }), + ); + bindFloat("bme-setting-recall-residual-threshold", 0.3, 0, 10, (value) => + _patchSettings({ recallResidualThreshold: value }), + ); + bindNumber("bme-setting-recall-residual-top-k", 5, 1, 20, (value) => + _patchSettings({ recallResidualTopK: value }), + ); bindNumber("bme-setting-inject-depth", 9999, 0, 9999, (value) => _patchSettings({ injectDepth: value }), ); diff --git a/retrieval-enhancer.js b/retrieval-enhancer.js new file mode 100644 index 0000000..c8878f8 --- /dev/null +++ b/retrieval-enhancer.js @@ -0,0 +1,795 @@ +import { embedText, searchSimilar } from "./embedding.js"; +import { getNode } from "./graph.js"; +import { isDirectVectorConfig } from "./vector-index.js"; + +const COOCCURRENCE_EXCLUDED_TYPES = new Set([ + "event", + "synopsis", + "reflection", +]); + +const cooccurrenceCache = new WeakMap(); + +export function splitIntentSegments( + text, + { maxSegments = 4, minLength = 3 } = {}, +) { + const raw = String(text || "").trim(); + if (!raw) return []; + + const segments = raw + .split(/[,,。.;;!!??\n]+|(?:顺便|另外|还有|对了|然后|而且|并且|同时)/) + .map((item) => item.trim()) + .filter((item) => item.length >= minLength); + + return uniqueStrings(segments).slice(0, Math.max(1, maxSegments)); +} + +export function mergeVectorResults(resultGroups = [], limit = Infinity) { + const merged = new Map(); + let rawHitCount = 0; + + for (const group of resultGroups) { + for (const item of Array.isArray(group) ? group : []) { + if (!item?.nodeId) continue; + rawHitCount += 1; + const score = Number(item.score) || 0; + const existing = merged.get(item.nodeId); + if (!existing || score > existing.score) { + merged.set(item.nodeId, { ...item, score }); + } + } + } + + const results = [...merged.values()] + .sort((a, b) => { + if (b.score !== a.score) return b.score - a.score; + return String(a.nodeId).localeCompare(String(b.nodeId)); + }) + .slice(0, Number.isFinite(limit) ? limit : merged.size); + + return { + rawHitCount, + results, + }; +} + +export function isEligibleAnchorNode(node) { + if (!node || node.archived) return false; + if (COOCCURRENCE_EXCLUDED_TYPES.has(node.type)) return false; + return getAnchorTerms(node).length > 0; +} + +export function getAnchorTerms(node) { + return [node?.fields?.name, node?.fields?.title] + .filter((value) => typeof value === "string") + .map((value) => value.trim()) + .filter((value) => value.length >= 2); +} + +export function collectSupplementalAnchorNodeIds( + graph, + vectorResults = [], + primaryAnchorIds = [], + maxCount = 5, +) { + const selected = []; + const seen = new Set(primaryAnchorIds || []); + + for (const result of vectorResults) { + if (selected.length >= maxCount) break; + const node = getNode(graph, result?.nodeId); + if (!isEligibleAnchorNode(node) || seen.has(node.id)) continue; + seen.add(node.id); + selected.push(node.id); + } + + return selected; +} + +export function createCooccurrenceIndex( + graph, + { + maxAnchorsPerBatch = 10, + eligibleNodes = null, + } = {}, +) { + const nodes = Array.isArray(eligibleNodes) + ? eligibleNodes.filter(isEligibleAnchorNode) + : []; + const eligibleNodeKey = nodes.map((node) => node.id).sort().join("|"); + const cacheKey = [ + graph?.batchJournal?.length || 0, + graph?.nodes?.length || 0, + graph?.historyState?.lastProcessedAssistantFloor ?? -1, + maxAnchorsPerBatch, + eligibleNodeKey, + ].join(":"); + const cached = cooccurrenceCache.get(graph); + if (cached?.key === cacheKey) { + return cached.value; + } + + const index = new Map(); + let pairCount = 0; + let batchCount = 0; + let source = "seqRange"; + + if (nodes.length >= 2 && Array.isArray(graph?.batchJournal)) { + for (const journal of graph.batchJournal) { + const range = Array.isArray(journal?.processedRange) + ? journal.processedRange + : null; + if (!range || !Number.isFinite(range[0]) || !Number.isFinite(range[1])) { + continue; + } + + const batchNodes = nodes + .filter((node) => rangesOverlap(node.seqRange, range)) + .sort(compareBySeqDesc) + .slice(0, Math.max(2, maxAnchorsPerBatch)); + if (batchNodes.length < 2) continue; + + batchCount += 1; + pairCount += appendPairs(index, batchNodes, 1); + } + } + + if (batchCount === 0) { + source = "seqRange"; + pairCount = 0; + index.clear(); + + for (let i = 0; i < nodes.length; i++) { + for (let j = i + 1; j < nodes.length; j++) { + const overlap = rangeOverlapSize(nodes[i].seqRange, nodes[j].seqRange); + if (overlap <= 0) continue; + addCooccurrence(index, nodes[i].id, nodes[j].id, overlap); + addCooccurrence(index, nodes[j].id, nodes[i].id, overlap); + pairCount += 1; + } + } + } else { + source = "batchJournal"; + } + + const result = { + map: normalizeCooccurrenceMap(index), + source, + batchCount, + pairCount, + }; + cooccurrenceCache.set(graph, { key: cacheKey, value: result }); + return result; +} + +export function applyCooccurrenceBoost( + baseScores, + anchorWeights, + cooccurrenceIndex, + { scale = 0.1, maxNeighbors = 10 } = {}, +) { + const nextScores = new Map(baseScores || []); + const boostedNodes = []; + const map = cooccurrenceIndex?.map instanceof Map + ? cooccurrenceIndex.map + : new Map(); + + for (const [anchorId, anchorScore] of anchorWeights.entries()) { + const neighbors = map.get(anchorId) || []; + const capped = neighbors.slice(0, Math.max(1, maxNeighbors)); + + for (const item of capped) { + const bonus = + Math.max(0, Number(anchorScore) || 0) * + Math.log(1 + Math.max(0, Number(item.count) || 0)) * + Math.max(0, Number(scale) || 0); + if (!bonus) continue; + + nextScores.set(item.nodeId, (nextScores.get(item.nodeId) || 0) + bonus); + boostedNodes.push({ + anchorId, + nodeId: item.nodeId, + count: item.count, + bonus, + }); + } + } + + return { + scores: nextScores, + boostedNodes, + }; +} + +export function dppGreedySelect( + candidateVecs = [], + candidateScores = [], + k, + qualityWeight = 1, +) { + const total = Math.min(candidateVecs.length, candidateScores.length); + const target = Math.max(0, Math.min(k, total)); + if (target >= total) { + return Array.from({ length: total }, (_, index) => index); + } + + const normalized = candidateVecs.map((vector) => normalizeVector(vector)); + const q = candidateScores.map((score) => + Math.pow(Math.max(Number(score) || 0, 1e-10), Math.max(0, qualityWeight)), + ); + const diag = q.map((value) => value * value + 1e-8); + const chol = Array.from({ length: target }, () => + Array(total).fill(0), + ); + const selected = []; + + for (let j = 0; j < target; j++) { + let bestIndex = -1; + let bestValue = Number.NEGATIVE_INFINITY; + + for (let i = 0; i < total; i++) { + if (selected.includes(i)) continue; + if (diag[i] > bestValue) { + bestValue = diag[i]; + bestIndex = i; + } + } + + if (bestIndex === -1) break; + selected.push(bestIndex); + + if (j === target - 1 || diag[bestIndex] < 1e-10) { + continue; + } + + const row = normalized.map( + (vector, index) => q[bestIndex] * dot(normalized[bestIndex], vector) * q[index], + ); + const next = [...row]; + for (let i = 0; i < j; i++) { + const pivot = chol[i][bestIndex]; + for (let index = 0; index < total; index++) { + next[index] -= pivot * chol[i][index]; + } + } + + const inv = 1 / Math.sqrt(diag[bestIndex]); + for (let index = 0; index < total; index++) { + chol[j][index] = next[index] * inv; + diag[index] = Math.max(0, diag[index] - chol[j][index] ** 2); + } + } + + return selected; +} + +export function applyDiversitySampling( + candidates = [], + { k, qualityWeight = 1 } = {}, +) { + const target = Math.max(1, Math.floor(Number(k) || 0)); + if (candidates.length <= target) { + return { + applied: false, + reason: "candidate-pool-too-small", + selected: candidates.slice(0, target), + beforeCount: candidates.length, + afterCount: Math.min(candidates.length, target), + }; + } + + if ( + candidates.some( + (item) => + !Array.isArray(item?.node?.embedding) || item.node.embedding.length === 0, + ) + ) { + return { + applied: false, + reason: "candidate-embeddings-missing", + selected: candidates.slice(0, target), + beforeCount: candidates.length, + afterCount: Math.min(candidates.length, target), + }; + } + + const indexes = dppGreedySelect( + candidates.map((item) => item.node.embedding), + candidates.map((item) => item.finalScore), + target, + qualityWeight, + ); + + const selected = indexes + .map((index) => candidates[index]) + .filter(Boolean); + + if (selected.length !== target) { + return { + applied: false, + reason: "dpp-selection-incomplete", + selected: candidates.slice(0, target), + beforeCount: candidates.length, + afterCount: Math.min(candidates.length, target), + }; + } + + return { + applied: true, + reason: "", + selected, + beforeCount: candidates.length, + afterCount: selected.length, + }; +} + +export function nmfQueryAnalysis( + queryVec, + entityVecs, + { nTopics = 15, maxIter = 100, tolerance = 1e-4 } = {}, +) { + const vectors = normalizeMatrix(entityVecs); + const query = vectorAbs(queryVec); + if (vectors.length < 2 || query.length === 0) { + return { + semanticDepth: 0, + topicCoverage: 0, + novelty: 1, + topTopics: [], + }; + } + + const k = Math.min(Math.max(1, Math.floor(nTopics)), vectors.length); + const matrix = vectors.map((vector) => vectorAbs(vector)); + const { h } = nmfMultiplicativeUpdate(matrix, k, maxIter, tolerance); + const rawScores = h.map((topic) => dot(query, topic)); + const topics = softmax(rawScores); + + const entropy = -topics.reduce((sum, value) => { + return value > 1e-10 ? sum + value * Math.log(value) : sum; + }, 0); + const maxEntropy = k > 1 ? Math.log(k) : 1; + const semanticDepth = 1 - entropy / maxEntropy; + const topicCoverage = topics.filter((value) => value > 0.5 / k).length; + const reconstruction = Array(query.length).fill(0); + + for (let topicIndex = 0; topicIndex < topics.length; topicIndex++) { + const weight = topics[topicIndex]; + for (let dim = 0; dim < reconstruction.length; dim++) { + reconstruction[dim] += weight * h[topicIndex][dim]; + } + } + + const novelty = + l2Norm(subtractVectors(query, reconstruction)) / Math.max(l2Norm(query), 1e-10); + + return { + semanticDepth, + topicCoverage, + novelty, + topTopics: topics, + }; +} + +export function sparseCodeResidual( + queryVec, + entityVecs, + { lambda = 0.1, maxIter = 80 } = {}, +) { + const query = normalizeVector(queryVec, false); + const entities = normalizeMatrix(entityVecs); + const total = entities.length; + if (total === 0 || query.length === 0) { + return { + alpha: [], + residual: [...query], + residualNorm: l2Norm(query), + }; + } + + const gram = Array.from({ length: total }, () => Array(total).fill(0)); + const etq = Array(total).fill(0); + + for (let i = 0; i < total; i++) { + etq[i] = dot(entities[i], query); + for (let j = i; j < total; j++) { + const value = dot(entities[i], entities[j]); + gram[i][j] = value; + gram[j][i] = value; + } + } + + let lipschitz = 0; + for (let i = 0; i < total; i++) { + const rowSum = gram[i].reduce((sum, value) => sum + Math.abs(value), 0); + lipschitz = Math.max(lipschitz, rowSum); + } + if (lipschitz < 1e-10) { + return { + alpha: Array(total).fill(0), + residual: [...query], + residualNorm: l2Norm(query), + }; + } + + const step = 1 / lipschitz; + let alpha = Array(total).fill(0); + let y = [...alpha]; + let t = 1; + + for (let iteration = 0; iteration < maxIter; iteration++) { + const grad = matVecMul(gram, y).map((value, index) => value - etq[index]); + const nextAlpha = softThreshold( + y.map((value, index) => value - step * grad[index]), + lambda * step, + ); + const nextT = (1 + Math.sqrt(1 + 4 * t * t)) / 2; + const momentum = (t - 1) / nextT; + y = nextAlpha.map( + (value, index) => value + momentum * (value - alpha[index]), + ); + alpha = nextAlpha; + t = nextT; + } + + const reconstruction = Array(query.length).fill(0); + for (let i = 0; i < total; i++) { + if (Math.abs(alpha[i]) < 1e-10) continue; + for (let dim = 0; dim < query.length; dim++) { + reconstruction[dim] += alpha[i] * entities[i][dim]; + } + } + + const residual = subtractVectors(query, reconstruction); + return { + alpha, + residual, + residualNorm: l2Norm(residual), + }; +} + +export async function runResidualRecall({ + queryText, + graph, + embeddingConfig, + basisNodes = [], + candidateNodes = [], + basisLimit = 24, + nTopics = 15, + noveltyThreshold = 0.4, + residualThreshold = 0.3, + residualTopK = 5, + signal, +}) { + if (!isDirectVectorConfig(embeddingConfig)) { + return { + triggered: false, + hits: [], + skipReason: "residual-direct-mode-required", + }; + } + + const filteredBasis = basisNodes + .filter( + (node) => + Array.isArray(node?.embedding) && node.embedding.length > 0, + ) + .slice(0, Math.max(2, basisLimit)); + if (filteredBasis.length < 2) { + return { + triggered: false, + hits: [], + skipReason: "residual-basis-insufficient", + }; + } + + const queryVec = await embedText(queryText, embeddingConfig, { signal }); + if (!queryVec || queryVec.length === 0) { + return { + triggered: false, + hits: [], + skipReason: "residual-query-embedding-missing", + }; + } + + const nmfResult = nmfQueryAnalysis(queryVec, filteredBasis.map((node) => node.embedding), { + nTopics, + }); + if (!Number.isFinite(nmfResult.novelty) || nmfResult.novelty < noveltyThreshold) { + return { + triggered: false, + hits: [], + nmf: nmfResult, + skipReason: "residual-novelty-below-threshold", + }; + } + + const sparse = sparseCodeResidual(queryVec, filteredBasis.map((node) => node.embedding)); + if (!Number.isFinite(sparse.residualNorm) || sparse.residualNorm <= residualThreshold) { + return { + triggered: false, + hits: [], + nmf: nmfResult, + sparse, + skipReason: "residual-norm-below-threshold", + }; + } + + const searchableCandidates = (candidateNodes || []) + .filter( + (node) => + Array.isArray(node?.embedding) && + node.embedding.length > 0 && + !filteredBasis.some((basisNode) => basisNode.id === node.id), + ) + .map((node) => ({ + nodeId: node.id, + embedding: node.embedding, + })); + + if (searchableCandidates.length === 0) { + return { + triggered: true, + hits: [], + nmf: nmfResult, + sparse, + skipReason: "residual-search-space-empty", + }; + } + + const hits = searchSimilar(sparse.residual, searchableCandidates, residualTopK) + .map((item) => ({ + ...item, + node: getNode(graph, item.nodeId), + })) + .filter((item) => item.node); + + return { + triggered: true, + hits, + nmf: nmfResult, + sparse, + skipReason: hits.length > 0 ? "" : "residual-no-hit", + }; +} + +function uniqueStrings(items = []) { + return [...new Set(items.filter(Boolean))]; +} + +function normalizeCooccurrenceMap(index) { + const normalized = new Map(); + for (const [nodeId, neighborMap] of index.entries()) { + normalized.set( + nodeId, + [...neighborMap.entries()] + .map(([neighborId, count]) => ({ nodeId: neighborId, count })) + .sort((a, b) => { + if (b.count !== a.count) return b.count - a.count; + return String(a.nodeId).localeCompare(String(b.nodeId)); + }), + ); + } + return normalized; +} + +function appendPairs(index, nodes, increment) { + let count = 0; + for (let i = 0; i < nodes.length; i++) { + for (let j = i + 1; j < nodes.length; j++) { + addCooccurrence(index, nodes[i].id, nodes[j].id, increment); + addCooccurrence(index, nodes[j].id, nodes[i].id, increment); + count += 1; + } + } + return count; +} + +function addCooccurrence(index, fromId, toId, increment) { + if (!index.has(fromId)) { + index.set(fromId, new Map()); + } + const map = index.get(fromId); + map.set(toId, (map.get(toId) || 0) + increment); +} + +function rangesOverlap(a, b) { + return rangeOverlapSize(a, b) > 0; +} + +function rangeOverlapSize(a, b) { + const rangeA = normalizeRange(a); + const rangeB = normalizeRange(b); + if (!rangeA || !rangeB) return 0; + const start = Math.max(rangeA[0], rangeB[0]); + const end = Math.min(rangeA[1], rangeB[1]); + return end >= start ? end - start + 1 : 0; +} + +function normalizeRange(range) { + if (!Array.isArray(range) || range.length < 2) return null; + const start = Number(range[0]); + const end = Number(range[1]); + if (!Number.isFinite(start) || !Number.isFinite(end)) return null; + return [Math.min(start, end), Math.max(start, end)]; +} + +function compareBySeqDesc(a, b) { + const seqA = a?.seqRange?.[1] ?? a?.seq ?? 0; + const seqB = b?.seqRange?.[1] ?? b?.seq ?? 0; + if (seqB !== seqA) return seqB - seqA; + return (b.importance || 0) - (a.importance || 0); +} + +function vectorAbs(vector = []) { + return vector.map((value) => Math.abs(Number(value) || 0)); +} + +function normalizeVector(vector = [], useUnitNorm = true) { + const normalized = vector.map((value) => Number(value) || 0); + if (!useUnitNorm) return normalized; + const norm = l2Norm(normalized); + if (norm < 1e-10) return normalized.map(() => 0); + return normalized.map((value) => value / norm); +} + +function normalizeMatrix(vectors = []) { + return vectors + .filter((vector) => Array.isArray(vector) && vector.length > 0) + .map((vector) => normalizeVector(vector)); +} + +function dot(a = [], b = []) { + const length = Math.min(a.length, b.length); + let sum = 0; + for (let index = 0; index < length; index++) { + sum += (Number(a[index]) || 0) * (Number(b[index]) || 0); + } + return sum; +} + +function l2Norm(vector = []) { + return Math.sqrt(vector.reduce((sum, value) => sum + value * value, 0)); +} + +function subtractVectors(a = [], b = []) { + const length = Math.max(a.length, b.length); + const result = Array(length).fill(0); + for (let index = 0; index < length; index++) { + result[index] = (Number(a[index]) || 0) - (Number(b[index]) || 0); + } + return result; +} + +function matVecMul(matrix = [], vector = []) { + return matrix.map((row) => dot(row, vector)); +} + +function softThreshold(vector = [], threshold = 0) { + return vector.map((value) => { + const absValue = Math.abs(value); + if (absValue <= threshold) return 0; + return Math.sign(value) * (absValue - threshold); + }); +} + +function softmax(values = []) { + if (values.length === 0) return []; + const max = Math.max(...values); + const exp = values.map((value) => Math.exp(value - max)); + const total = exp.reduce((sum, value) => sum + value, 0) || 1; + return exp.map((value) => value / total); +} + +function nmfMultiplicativeUpdate(matrix, k, maxIter, tolerance) { + const m = matrix.length; + const d = matrix[0]?.length || 0; + const mean = + matrix.reduce((sum, row) => sum + row.reduce((acc, value) => acc + value, 0), 0) / + Math.max(1, m * d) || 0.01; + const avg = Math.max(Math.sqrt(mean / Math.max(1, k)), 0.01); + const rand = createDeterministicRandom(42); + const w = Array.from({ length: m }, () => + Array.from({ length: k }, () => Math.abs(avg + avg * 0.5 * (rand() - 0.5)) + 1e-6), + ); + const h = Array.from({ length: k }, () => + Array.from({ length: d }, () => Math.abs(avg + avg * 0.5 * (rand() - 0.5)) + 1e-6), + ); + const eps = 1e-10; + + for (let iteration = 0; iteration < maxIter; iteration++) { + const wtV = Array.from({ length: k }, () => Array(d).fill(0)); + const wtW = Array.from({ length: k }, () => Array(k).fill(0)); + + for (let i = 0; i < k; i++) { + for (let dim = 0; dim < d; dim++) { + let sum = 0; + for (let row = 0; row < m; row++) { + sum += w[row][i] * matrix[row][dim]; + } + wtV[i][dim] = sum; + } + for (let j = 0; j < k; j++) { + let sum = 0; + for (let row = 0; row < m; row++) { + sum += w[row][i] * w[row][j]; + } + wtW[i][j] = sum; + } + } + + for (let i = 0; i < k; i++) { + for (let dim = 0; dim < d; dim++) { + let denominator = 0; + for (let topic = 0; topic < k; topic++) { + denominator += wtW[i][topic] * h[topic][dim]; + } + h[i][dim] *= wtV[i][dim] / (denominator + eps); + } + } + + const vHt = Array.from({ length: m }, () => Array(k).fill(0)); + const hHt = Array.from({ length: k }, () => Array(k).fill(0)); + + for (let row = 0; row < m; row++) { + for (let topic = 0; topic < k; topic++) { + let sum = 0; + for (let dim = 0; dim < d; dim++) { + sum += matrix[row][dim] * h[topic][dim]; + } + vHt[row][topic] = sum; + } + } + + for (let i = 0; i < k; i++) { + for (let j = 0; j < k; j++) { + let sum = 0; + for (let dim = 0; dim < d; dim++) { + sum += h[i][dim] * h[j][dim]; + } + hHt[i][j] = sum; + } + } + + for (let row = 0; row < m; row++) { + for (let topic = 0; topic < k; topic++) { + let denominator = 0; + for (let inner = 0; inner < k; inner++) { + denominator += w[row][inner] * hHt[inner][topic]; + } + w[row][topic] *= vHt[row][topic] / (denominator + eps); + } + } + + if (iteration % 10 === 9) { + let residualSq = 0; + let matrixSq = 0; + for (let row = 0; row < m; row++) { + for (let dim = 0; dim < d; dim++) { + let reconstructed = 0; + for (let topic = 0; topic < k; topic++) { + reconstructed += w[row][topic] * h[topic][dim]; + } + const diff = matrix[row][dim] - reconstructed; + residualSq += diff * diff; + matrixSq += matrix[row][dim] * matrix[row][dim]; + } + } + + if (matrixSq > 0 && Math.sqrt(residualSq / matrixSq) < tolerance) { + break; + } + } + } + + return { w, h }; +} + +function createDeterministicRandom(seed) { + let current = seed >>> 0; + return () => { + current = (1664525 * current + 1013904223) >>> 0; + return current / 0xffffffff; + }; +} diff --git a/retriever.js b/retriever.js index 2e09937..a28d340 100644 --- a/retriever.js +++ b/retriever.js @@ -16,6 +16,16 @@ import { buildTaskLlmPayload, buildTaskPrompt, } from "./prompt-builder.js"; +import { + applyCooccurrenceBoost, + applyDiversitySampling, + collectSupplementalAnchorNodeIds, + createCooccurrenceIndex, + isEligibleAnchorNode, + mergeVectorResults, + runResidualRecall, + splitIntentSegments, +} from "./retrieval-enhancer.js"; import { applyTaskRegex } from "./task-regex.js"; import { getSTContextForPrompt } from "./st-context.js"; import { findSimilarNodesByText, validateVectorConfig } from "./vector-index.js"; @@ -59,6 +69,65 @@ function throwIfAborted(signal) { } } +function nowMs() { + return typeof performance !== "undefined" && performance?.now + ? performance.now() + : Date.now(); +} + +function roundMs(value) { + return Math.round((Number(value) || 0) * 10) / 10; +} + +function pushSkipReason(meta, reason) { + if (!reason) return; + if (!Array.isArray(meta.skipReasons)) { + meta.skipReasons = []; + } + if (!meta.skipReasons.includes(reason)) { + meta.skipReasons.push(reason); + } +} + +function createRetrievalMeta(enableLLMRecall) { + return { + vectorHits: 0, + diffusionHits: 0, + scoredCandidates: 0, + segmentsUsed: [], + vectorMergedHits: 0, + seedCount: 0, + temporalSyntheticEdgeCount: 0, + teleportAlpha: 0, + cooccurrenceBoostedNodes: 0, + candidatePoolBeforeDpp: 0, + candidatePoolAfterDpp: 0, + diversityApplied: false, + residualTriggered: false, + residualHits: 0, + skipReasons: [], + timings: {}, + llm: { + enabled: enableLLMRecall, + status: enableLLMRecall ? "pending" : "disabled", + reason: enableLLMRecall ? "" : "LLM 精排已关闭", + candidatePool: 0, + selectedSeedCount: 0, + }, + }; +} + +function clampPositiveInt(value, fallback, min = 1) { + const parsed = Math.floor(Number(value)); + return Number.isFinite(parsed) && parsed >= min ? parsed : fallback; +} + +function clampRange(value, fallback, min = 0, max = 1) { + const parsed = Number(value); + if (!Number.isFinite(parsed)) return fallback; + return Math.max(min, Math.min(max, parsed)); +} + /** * 三层混合检索管线 * @@ -83,21 +152,71 @@ export async function retrieve({ onStreamProgress = null, }) { throwIfAborted(signal); - const topK = options.topK ?? 20; - const maxRecallNodes = options.maxRecallNodes ?? 8; + const startedAt = nowMs(); + const topK = clampPositiveInt(options.topK, 20); + const maxRecallNodes = clampPositiveInt(options.maxRecallNodes, 8); const enableLLMRecall = options.enableLLMRecall ?? true; const enableVectorPrefilter = options.enableVectorPrefilter ?? true; const enableGraphDiffusion = options.enableGraphDiffusion ?? true; - const diffusionTopK = options.diffusionTopK ?? 100; - const llmCandidatePool = options.llmCandidatePool ?? 30; + const diffusionTopK = clampPositiveInt(options.diffusionTopK, 100); + const llmCandidatePool = clampPositiveInt(options.llmCandidatePool, 30); const weights = options.weights ?? {}; - - // v2 options const enableVisibility = options.enableVisibility ?? false; const visibilityFilter = options.visibilityFilter ?? null; const enableCrossRecall = options.enableCrossRecall ?? false; const enableProbRecall = options.enableProbRecall ?? false; const probRecallChance = options.probRecallChance ?? 0.15; + const enableMultiIntent = options.enableMultiIntent ?? true; + const multiIntentMaxSegments = clampPositiveInt( + options.multiIntentMaxSegments, + 4, + ); + const teleportAlpha = clampRange(options.teleportAlpha, 0.15); + const enableTemporalLinks = options.enableTemporalLinks ?? true; + const temporalLinkStrength = clampRange( + options.temporalLinkStrength, + 0.2, + ); + const enableDiversitySampling = options.enableDiversitySampling ?? true; + const dppCandidateMultiplier = clampPositiveInt( + options.dppCandidateMultiplier, + 3, + ); + const dppQualityWeight = clampRange( + options.dppQualityWeight, + 1.0, + 0, + 10, + ); + const enableCooccurrenceBoost = options.enableCooccurrenceBoost ?? false; + const cooccurrenceScale = clampRange( + options.cooccurrenceScale, + 0.1, + 0, + 10, + ); + const cooccurrenceMaxNeighbors = clampPositiveInt( + options.cooccurrenceMaxNeighbors, + 10, + ); + const enableResidualRecall = options.enableResidualRecall ?? false; + const residualBasisMaxNodes = clampPositiveInt( + options.residualBasisMaxNodes, + 24, + 2, + ); + const residualNmfTopics = clampPositiveInt(options.residualNmfTopics, 15); + const residualNmfNoveltyThreshold = clampRange( + options.residualNmfNoveltyThreshold, + 0.4, + ); + const residualThreshold = clampRange( + options.residualThreshold, + 0.3, + 0, + 10, + ); + const residualTopK = clampPositiveInt(options.residualTopK, 5); let activeNodes = getActiveNodes(graph).filter( (node) => @@ -106,7 +225,6 @@ export async function retrieve({ Number.isFinite(node.seqRange[1]), ); - // v2 ⑦: 认知边界过滤(RoleRAG 启发) if (enableVisibility && visibilityFilter) { activeNodes = filterByVisibility(activeNodes, visibilityFilter); } @@ -119,66 +237,124 @@ export async function retrieve({ normalizedMaxRecallNodes, llmCandidatePool, ); + const vectorValidation = validateVectorConfig(embeddingConfig); + const retrievalMeta = createRetrievalMeta(enableLLMRecall); console.log( `[ST-BME] 检索开始: ${nodeCount} 个活跃节点${enableVisibility ? " (认知边界已启用)" : ""}`, ); let vectorResults = []; let diffusionResults = []; - let useLLM = false; - let llmMeta = { - enabled: enableLLMRecall, - status: enableLLMRecall ? "pending" : "disabled", - reason: enableLLMRecall ? "" : "LLM 精排已关闭", - candidatePool: 0, - selectedSeedCount: 0, - }; + let llmMeta = { ...retrievalMeta.llm }; + const exactEntityAnchors = []; + let supplementalAnchorNodeIds = []; if (nodeCount === 0) { return buildResult(graph, [], schema, { retrieval: { - vectorHits: 0, - diffusionHits: 0, - scoredCandidates: 0, + ...retrievalMeta, llm: { ...llmMeta, status: enableLLMRecall ? "skipped" : "disabled", reason: "当前没有可参与召回的活跃节点", }, + timings: { + total: roundMs(nowMs() - startedAt), + }, }, }); } - // ========== 第 1 层:向量预筛 ========== - if ( - enableVectorPrefilter && - validateVectorConfig(embeddingConfig).valid - ) { + const vectorStartedAt = nowMs(); + if (enableVectorPrefilter && vectorValidation.valid) { console.log("[ST-BME] 第1层: 向量预筛"); - vectorResults = await vectorPreFilter( - graph, - userMessage, - activeNodes, - embeddingConfig, - normalizedTopK, - signal, - ); - } + const segments = enableMultiIntent + ? splitIntentSegments(userMessage, { + maxSegments: multiIntentMaxSegments, + }) + : []; + const queries = [userMessage, ...segments.filter((item) => item !== userMessage)]; + const groups = []; - // ========== 第 2 层:图扩散 ========== + retrievalMeta.segmentsUsed = segments; + for (const queryText of queries) { + const results = await vectorPreFilter( + graph, + queryText, + activeNodes, + embeddingConfig, + normalizedTopK, + signal, + ); + groups.push(results); + } + + const merged = mergeVectorResults( + groups, + Math.max(normalizedTopK * 2, 24), + ); + retrievalMeta.vectorHits = merged.rawHitCount; + retrievalMeta.vectorMergedHits = merged.results.length; + vectorResults = merged.results; + } else if (enableVectorPrefilter) { + pushSkipReason(retrievalMeta, "vector-config-invalid"); + } + retrievalMeta.timings.vector = roundMs(nowMs() - vectorStartedAt); + + exactEntityAnchors.push(...extractEntityAnchors(userMessage, activeNodes)); + supplementalAnchorNodeIds = collectSupplementalAnchorNodeIds( + graph, + vectorResults, + exactEntityAnchors.map((item) => item.nodeId), + 5, + ); + + let residualResult = { + triggered: false, + hits: [], + skipReason: "", + }; + const residualStartedAt = nowMs(); + if (enableResidualRecall) { + const basisNodes = buildResidualBasisNodes( + graph, + exactEntityAnchors, + vectorResults, + residualBasisMaxNodes, + ); + residualResult = await runResidualRecall({ + queryText: userMessage, + graph, + embeddingConfig, + basisNodes, + candidateNodes: activeNodes, + basisLimit: residualBasisMaxNodes, + nTopics: residualNmfTopics, + noveltyThreshold: residualNmfNoveltyThreshold, + residualThreshold, + residualTopK, + signal, + }); + retrievalMeta.residualTriggered = Boolean(residualResult.triggered); + retrievalMeta.residualHits = residualResult.hits?.length || 0; + pushSkipReason(retrievalMeta, residualResult.skipReason); + } + retrievalMeta.timings.residual = roundMs(nowMs() - residualStartedAt); + + const diffusionStartedAt = nowMs(); if (enableGraphDiffusion) { console.log("[ST-BME] 第2层: PEDSA 图扩散"); - const entityAnchors = extractEntityAnchors(userMessage, activeNodes); - const seeds = [ ...vectorResults.map((v) => ({ id: v.nodeId, energy: v.score })), - ...entityAnchors.map((a) => ({ id: a.nodeId, energy: 2.0 })), + ...exactEntityAnchors.map((item) => ({ id: item.nodeId, energy: 2.0 })), + ...(residualResult.hits || []).map((item) => ({ + id: item.nodeId, + energy: item.score, + })), ]; - // v2 ⑧: 双记忆交叉检索(AriGraph 启发) - // 实体锚点命中后,沿边展开关联的情景节点作为额外种子 - if (enableCrossRecall && entityAnchors.length > 0) { - for (const anchor of entityAnchors) { + if (enableCrossRecall && exactEntityAnchors.length > 0) { + for (const anchor of exactEntityAnchors) { const connectedEdges = getNodeEdges(graph, anchor.nodeId); for (const edge of connectedEdges) { if (edge.invalidAt) continue; @@ -192,7 +368,6 @@ export async function retrieve({ } } - // 去重种子 const seedMap = new Map(); for (const s of seeds) { const existing = seedMap.get(s.id) || 0; @@ -202,41 +377,46 @@ export async function retrieve({ id, energy, })); + retrievalMeta.seedCount = uniqueSeeds.length; if (uniqueSeeds.length > 0) { - const adjacencyMap = buildTemporalAdjacencyMap(graph); + const adjacencyMap = buildTemporalAdjacencyMap(graph, { + includeTemporalLinks: enableTemporalLinks, + temporalLinkStrength, + }); + retrievalMeta.temporalSyntheticEdgeCount = + Number(adjacencyMap.syntheticEdgeCount) || 0; + retrievalMeta.teleportAlpha = teleportAlpha; diffusionResults = diffuseAndRank(adjacencyMap, uniqueSeeds, { maxSteps: 2, decayFactor: 0.6, topK: normalizedDiffusionTopK, + teleportAlpha, }).filter((item) => { const node = getNode(graph, item.nodeId); return node && !node.archived; }); } } + retrievalMeta.diffusionHits = diffusionResults.length; + retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt); - // ========== 第 3 层:混合评分 + 可选 LLM 精确 ========== console.log("[ST-BME] 第3层: 混合评分"); - // 构建评分表 const scoreMap = new Map(); - // 添加向量得分 for (const v of vectorResults) { const entry = scoreMap.get(v.nodeId) || { graphScore: 0, vectorScore: 0 }; entry.vectorScore = v.score; scoreMap.set(v.nodeId, entry); } - // 添加图扩散得分 for (const d of diffusionResults) { const entry = scoreMap.get(d.nodeId) || { graphScore: 0, vectorScore: 0 }; entry.graphScore = d.energy; scoreMap.set(d.nodeId, entry); } - // 两个上游阶段都未产出候选时,退回到全部活跃节点参与评分 if (scoreMap.size === 0) { for (const node of activeNodes) { if (!scoreMap.has(node.id)) { @@ -245,7 +425,60 @@ export async function retrieve({ } } - // 计算混合得分 + const cooccurrenceStartedAt = nowMs(); + if (enableCooccurrenceBoost) { + const anchorWeights = new Map(); + for (const anchor of exactEntityAnchors) { + anchorWeights.set(anchor.nodeId, 2.0); + } + for (const nodeId of supplementalAnchorNodeIds) { + const fallbackWeight = + scoreMap.get(nodeId)?.vectorScore || + scoreMap.get(nodeId)?.graphScore || + 0.5; + anchorWeights.set( + nodeId, + Math.max(anchorWeights.get(nodeId) || 0, fallbackWeight), + ); + } + + if (anchorWeights.size > 0) { + const cooccurrenceIndex = createCooccurrenceIndex(graph, { + maxAnchorsPerBatch: 10, + eligibleNodes: activeNodes.filter(isEligibleAnchorNode), + }); + const graphScores = new Map( + [...scoreMap.entries()].map(([nodeId, value]) => [ + nodeId, + value.graphScore || 0, + ]), + ); + const boosted = applyCooccurrenceBoost( + graphScores, + anchorWeights, + cooccurrenceIndex, + { + scale: cooccurrenceScale, + maxNeighbors: cooccurrenceMaxNeighbors, + }, + ); + retrievalMeta.cooccurrenceBoostedNodes = boosted.boostedNodes.length; + + for (const [nodeId, boostedScore] of boosted.scores.entries()) { + const entry = scoreMap.get(nodeId) || { graphScore: 0, vectorScore: 0 }; + entry.graphScore = boostedScore; + scoreMap.set(nodeId, entry); + } + if (boosted.boostedNodes.length === 0) { + pushSkipReason(retrievalMeta, "cooccurrence-no-neighbors"); + } + } else { + pushSkipReason(retrievalMeta, "cooccurrence-no-anchor"); + } + } + retrievalMeta.timings.cooccurrence = roundMs(nowMs() - cooccurrenceStartedAt); + + const scoringStartedAt = nowMs(); const scoredNodes = []; for (const [nodeId, scores] of scoreMap) { const node = getNode(graph, nodeId); @@ -265,22 +498,29 @@ export async function retrieve({ } scoredNodes.sort((a, b) => b.finalScore - a.finalScore); - - // 决定是否使用 LLM 精确召回 - useLLM = enableLLMRecall; + retrievalMeta.scoredCandidates = scoredNodes.length; + retrievalMeta.timings.scoring = roundMs(nowMs() - scoringStartedAt); let selectedNodeIds; + let llmCandidates = []; + const diversityStartedAt = nowMs(); + let llmDurationMs = 0; - if (useLLM && nodeCount > 0) { + if (enableLLMRecall && nodeCount > 0) { console.log("[ST-BME] LLM 精确召回"); - const candidateNodes = scoredNodes.slice( - 0, - Math.min(normalizedLlmCandidatePool, scoredNodes.length), + llmCandidates = resolveCandidatePool( + scoredNodes, + normalizedLlmCandidatePool, + dppCandidateMultiplier, + enableDiversitySampling, + dppQualityWeight, + retrievalMeta, ); + const llmStartedAt = nowMs(); const llmResult = await llmRecall( userMessage, recentMessages, - candidateNodes, + llmCandidates, graph, schema, normalizedMaxRecallNodes, @@ -289,18 +529,25 @@ export async function retrieve({ signal, onStreamProgress, ); + llmDurationMs = nowMs() - llmStartedAt; selectedNodeIds = llmResult.selectedNodeIds; llmMeta = { enabled: true, status: llmResult.status, reason: llmResult.reason, - candidatePool: candidateNodes.length, + candidatePool: llmCandidates.length, selectedSeedCount: llmResult.selectedNodeIds.length, }; } else { - selectedNodeIds = scoredNodes - .slice(0, Math.min(normalizedTopK, scoredNodes.length)) - .map((s) => s.nodeId); + const selectedCandidates = resolveCandidatePool( + scoredNodes, + normalizedTopK, + dppCandidateMultiplier, + enableDiversitySampling, + dppQualityWeight, + retrievalMeta, + ); + selectedNodeIds = selectedCandidates.map((item) => item.nodeId); llmMeta = { enabled: false, status: "disabled", @@ -309,6 +556,8 @@ export async function retrieve({ selectedSeedCount: selectedNodeIds.length, }; } + retrievalMeta.timings.diversity = roundMs(nowMs() - diversityStartedAt); + retrievalMeta.timings.llm = roundMs(llmDurationMs); selectedNodeIds = reconstructSceneNodeIds( graph, @@ -325,8 +574,6 @@ export async function retrieve({ console.log(`[ST-BME] 检索完成: 选中 ${selectedNodeIds.length} 个节点`); - // v2 ⑧: 概率触发回忆 - // 未被选中的高重要性节点有概率随机激活 if (enableProbRecall && probRecallChance > 0) { const selectedSet = new Set(selectedNodeIds); const probability = Math.max(0.01, Math.min(0.5, probRecallChance)); @@ -351,14 +598,11 @@ export async function retrieve({ } selectedNodeIds = uniqueNodeIds(selectedNodeIds); + retrievalMeta.llm = llmMeta; + retrievalMeta.timings.total = roundMs(nowMs() - startedAt); return buildResult(graph, selectedNodeIds, schema, { - retrieval: { - vectorHits: vectorResults.length, - diffusionHits: diffusionResults.length, - scoredCandidates: scoredNodes.length, - llm: llmMeta, - }, + retrieval: retrievalMeta, }); } @@ -418,6 +662,84 @@ function extractEntityAnchors(userMessage, activeNodes) { return anchors; } +function buildResidualBasisNodes( + graph, + exactEntityAnchors, + vectorResults, + maxNodes = 24, +) { + const basis = []; + const seen = new Set(); + + for (const anchor of exactEntityAnchors || []) { + const node = getNode(graph, anchor?.nodeId); + if ( + !node || + seen.has(node.id) || + !Array.isArray(node.embedding) || + node.embedding.length === 0 + ) { + continue; + } + seen.add(node.id); + basis.push(node); + if (basis.length >= maxNodes) return basis; + } + + for (const result of vectorResults || []) { + const node = getNode(graph, result?.nodeId); + if ( + !isEligibleAnchorNode(node) || + seen.has(node?.id) || + !Array.isArray(node?.embedding) || + node.embedding.length === 0 + ) { + continue; + } + seen.add(node.id); + basis.push(node); + if (basis.length >= maxNodes) break; + } + + return basis; +} + +function resolveCandidatePool( + scoredNodes, + targetCount, + multiplier, + enableDiversitySampling, + qualityWeight, + retrievalMeta, +) { + const safeTarget = Math.max(1, targetCount); + const fallback = scoredNodes.slice(0, Math.min(safeTarget, scoredNodes.length)); + retrievalMeta.candidatePoolBeforeDpp = fallback.length; + retrievalMeta.candidatePoolAfterDpp = fallback.length; + retrievalMeta.diversityApplied = false; + + if (!enableDiversitySampling) { + return fallback; + } + + const poolLimit = Math.min( + scoredNodes.length, + Math.max(safeTarget, safeTarget * Math.max(1, multiplier)), + ); + const pool = scoredNodes.slice(0, poolLimit); + retrievalMeta.candidatePoolBeforeDpp = pool.length; + + const diversity = applyDiversitySampling(pool, { + k: safeTarget, + qualityWeight, + }); + retrievalMeta.candidatePoolAfterDpp = diversity.afterCount; + retrievalMeta.diversityApplied = diversity.applied; + pushSkipReason(retrievalMeta, diversity.reason); + + return diversity.applied ? diversity.selected : fallback; +} + /** * LLM 精确召回 */ diff --git a/tests/default-settings.mjs b/tests/default-settings.mjs index b8c99ea..848a6b4 100644 --- a/tests/default-settings.mjs +++ b/tests/default-settings.mjs @@ -44,6 +44,23 @@ assert.equal(defaultSettings.recallEnableGraphDiffusion, true); assert.equal(defaultSettings.recallDiffusionTopK, 100); assert.equal(defaultSettings.recallLlmCandidatePool, 30); assert.equal(defaultSettings.recallLlmContextMessages, 4); +assert.equal(defaultSettings.recallEnableMultiIntent, true); +assert.equal(defaultSettings.recallMultiIntentMaxSegments, 4); +assert.equal(defaultSettings.recallTeleportAlpha, 0.15); +assert.equal(defaultSettings.recallEnableTemporalLinks, true); +assert.equal(defaultSettings.recallTemporalLinkStrength, 0.2); +assert.equal(defaultSettings.recallEnableDiversitySampling, true); +assert.equal(defaultSettings.recallDppCandidateMultiplier, 3); +assert.equal(defaultSettings.recallDppQualityWeight, 1.0); +assert.equal(defaultSettings.recallEnableCooccurrenceBoost, false); +assert.equal(defaultSettings.recallCooccurrenceScale, 0.1); +assert.equal(defaultSettings.recallCooccurrenceMaxNeighbors, 10); +assert.equal(defaultSettings.recallEnableResidualRecall, false); +assert.equal(defaultSettings.recallResidualBasisMaxNodes, 24); +assert.equal(defaultSettings.recallNmfTopics, 15); +assert.equal(defaultSettings.recallNmfNoveltyThreshold, 0.4); +assert.equal(defaultSettings.recallResidualThreshold, 0.3); +assert.equal(defaultSettings.recallResidualTopK, 5); assert.equal(defaultSettings.injectDepth, 9999); assert.equal(defaultSettings.enableReflection, true); assert.equal(defaultSettings.embeddingTransportMode, "direct"); diff --git a/tests/graph-retrieval.mjs b/tests/graph-retrieval.mjs index 8db64f1..6dc04b7 100644 --- a/tests/graph-retrieval.mjs +++ b/tests/graph-retrieval.mjs @@ -61,16 +61,30 @@ const replacementEdge = createEdge({ assert.ok(addEdge(graph, replacementEdge)); assert.notEqual(replacementEdge.id, historicalEdge.id); -const adjacencyMap = buildTemporalAdjacencyMap(graph); +const adjacencyMap = buildTemporalAdjacencyMap(graph, { + includeTemporalLinks: true, + temporalLinkStrength: 0.2, +}); const event1Neighbors = adjacencyMap.get(event1.id) || []; -assert.equal(event1Neighbors.length, 1); -assert.equal(event1Neighbors[0].targetId, character.id); -assert.equal(event1Neighbors[0].strength, 0.7); +assert.equal(adjacencyMap.syntheticEdgeCount, 1); +assert.ok( + event1Neighbors.some( + (item) => item.targetId === character.id && item.strength === 0.7, + ), +); +assert.ok( + event1Neighbors.some( + (item) => item.targetId === event2.id && item.strength === 0.2, + ), +); const diffusion = diffuseAndRank(adjacencyMap, [ { id: event2.id, energy: 1 }, { id: event2.id, energy: 0.5 }, -]); +], { + teleportAlpha: 0.15, +}); assert.ok(diffusion.some((item) => item.nodeId === character.id)); +assert.ok(diffusion.some((item) => item.nodeId === event1.id)); console.log("graph-retrieval tests passed"); diff --git a/tests/retrieval-config.mjs b/tests/retrieval-config.mjs index 160acf7..baecf2a 100644 --- a/tests/retrieval-config.mjs +++ b/tests/retrieval-config.mjs @@ -96,14 +96,61 @@ const retrieve = await loadRetrieve({ applyTaskRegex(_settings, _taskType, _stage, text) { return text; }, + splitIntentSegments(text) { + if (String(text).includes("和")) { + return String(text).split("和").map((item) => item.trim()); + } + return []; + }, + mergeVectorResults(groups, limit) { + const merged = new Map(); + let rawHitCount = 0; + for (const group of groups) { + for (const item of group) { + rawHitCount += 1; + const existing = merged.get(item.nodeId); + if (!existing || item.score > existing.score) { + merged.set(item.nodeId, item); + } + } + } + return { + rawHitCount, + results: [...merged.values()].slice(0, limit), + }; + }, + collectSupplementalAnchorNodeIds() { + return []; + }, + isEligibleAnchorNode(node) { + return Boolean(node?.fields?.title || node?.fields?.name); + }, + createCooccurrenceIndex() { + return { map: new Map(), source: "batchJournal", batchCount: 0, pairCount: 0 }; + }, + applyCooccurrenceBoost(baseScores) { + return { scores: new Map(baseScores), boostedNodes: [] }; + }, + applyDiversitySampling(candidates, { k }) { + return { + applied: true, + reason: "", + selected: candidates.slice(0, k).reverse(), + beforeCount: candidates.length, + afterCount: Math.min(k, candidates.length), + }; + }, + async runResidualRecall() { + return { triggered: false, hits: [], skipReason: "residual-disabled-test" }; + }, hybridScore: ({ graphScore = 0, vectorScore = 0, importance = 0 }) => graphScore + vectorScore + importance, reinforceAccessBatch() {}, validateVectorConfig() { return { valid: true }; }, - async findSimilarNodesByText(_graph, _message, _embeddingConfig, topK) { - state.vectorCalls.push(topK); + async findSimilarNodesByText(_graph, message, _embeddingConfig, topK) { + state.vectorCalls.push({ topK, message }); return [ { nodeId: "rule-1", score: 0.9 }, { nodeId: "rule-2", score: 0.8 }, @@ -124,8 +171,8 @@ const retrieve = await loadRetrieve({ .filter((line) => line.trim().startsWith("[")).length; return { selected_ids: ["rule-2", "rule-1"] }; }, - getSTContextForPrompt() { - return {}; + getSTContextForPrompt() { + return {}; }, }); @@ -149,7 +196,7 @@ const noStageResult = await retrieve({ assert.equal(state.vectorCalls.length, 0); assert.equal(state.diffusionCalls.length, 0); assert.equal(state.llmCalls.length, 0); -assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-1", "rule-2"]); +assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-2", "rule-1"]); state.vectorCalls.length = 0; state.diffusionCalls.length = 0; @@ -170,12 +217,16 @@ const llmPoolResult = await retrieve({ llmCandidatePool: 2, }, }); -assert.deepEqual(state.vectorCalls, [4]); +assert.deepEqual(state.vectorCalls, [{ topK: 4, message: "请根据规则给出结论" }]); assert.equal(state.diffusionCalls.length, 0); assert.equal(state.llmCandidateCount, 2); assert.deepEqual(Array.from(llmPoolResult.selectedNodeIds), ["rule-2", "rule-1"]); assert.equal(llmPoolResult.meta.retrieval.llm.status, "llm"); assert.equal(llmPoolResult.meta.retrieval.llm.candidatePool, 2); +assert.equal(llmPoolResult.meta.retrieval.vectorMergedHits, 3); +assert.equal(llmPoolResult.meta.retrieval.diversityApplied, true); +assert.equal(llmPoolResult.meta.retrieval.candidatePoolBeforeDpp, 3); +assert.equal(llmPoolResult.meta.retrieval.candidatePoolAfterDpp, 2); state.vectorCalls.length = 0; state.diffusionCalls.length = 0; @@ -193,11 +244,21 @@ await retrieve({ enableGraphDiffusion: true, diffusionTopK: 7, enableLLMRecall: false, + enableMultiIntent: true, + multiIntentMaxSegments: 4, + enableTemporalLinks: true, + temporalLinkStrength: 0.2, + teleportAlpha: 0.15, }, }); -assert.deepEqual(state.vectorCalls, [3]); +assert.equal(state.vectorCalls.length, 3); +assert.deepEqual( + state.vectorCalls.map((item) => item.topK), + [3, 3, 3], +); assert.equal(state.diffusionCalls.length, 1); assert.equal(state.diffusionCalls[0].options.topK, 7); +assert.equal(state.diffusionCalls[0].options.teleportAlpha, 0.15); assert.equal(noStageResult.meta.retrieval.llm.status, "disabled"); console.log("retrieval-config tests passed"); diff --git a/tests/retrieval-enhancer.mjs b/tests/retrieval-enhancer.mjs new file mode 100644 index 0000000..5b5e54a --- /dev/null +++ b/tests/retrieval-enhancer.mjs @@ -0,0 +1,154 @@ +import assert from "node:assert/strict"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { fileURLToPath } from "node:url"; +import vm from "node:vm"; +import { addNode, createEmptyGraph, createNode } from "../graph.js"; + +async function loadEnhancer() { + const __dirname = path.dirname(fileURLToPath(import.meta.url)); + const enhancerPath = path.resolve(__dirname, "../retrieval-enhancer.js"); + const source = await fs.readFile(enhancerPath, "utf8"); + const transformed = `${source + .replace(/^import[\s\S]*?from\s+["'][^"']+["'];\r?\n/gm, "") + .replace(/export function /g, "function ") + .replace(/export async function /g, "async function ")} +this.exports = { + applyDiversitySampling, + createCooccurrenceIndex, + nmfQueryAnalysis, + sparseCodeResidual, + splitIntentSegments, +}; +`; + + const context = vm.createContext({ + Math, + Date, + console, + WeakMap, + Map, + Set, + Array, + Number, + String, + JSON, + embedText: async () => null, + searchSimilar: () => [], + getNode(graph, nodeId) { + return graph.nodes.find((node) => node.id === nodeId) || null; + }, + isDirectVectorConfig() { + return true; + }, + }); + new vm.Script(transformed).runInContext(context); + return context.exports; +} + +const { + applyDiversitySampling, + createCooccurrenceIndex, + nmfQueryAnalysis, + sparseCodeResidual, + splitIntentSegments, +} = await loadEnhancer(); + +const segments = splitIntentSegments("规则一,然后规则二。另外规则三", { + maxSegments: 4, +}); +assert.deepEqual(Array.from(segments), ["规则一", "规则二", "规则三"]); + +const diversity = applyDiversitySampling( + [ + { + nodeId: "a", + finalScore: 0.95, + node: { embedding: [1, 0, 0] }, + }, + { + nodeId: "b", + finalScore: 0.9, + node: { embedding: [0.99, 0.01, 0] }, + }, + { + nodeId: "c", + finalScore: 0.85, + node: { embedding: [0, 1, 0] }, + }, + ], + { k: 2, qualityWeight: 1.0 }, +); +assert.equal(diversity.applied, true); +assert.equal(diversity.selected.length, 2); +assert.ok(Array.from(diversity.selected).some((item) => item.nodeId === "a")); +assert.ok(Array.from(diversity.selected).some((item) => item.nodeId === "c")); + +const graph = createEmptyGraph(); +const ruleA = createNode({ + type: "rule", + seq: 1, + seqRange: [1, 2], + fields: { title: "规则A" }, +}); +const ruleB = createNode({ + type: "rule", + seq: 2, + seqRange: [2, 3], + fields: { title: "规则B" }, +}); +const location = createNode({ + type: "location", + seq: 2, + seqRange: [2, 2], + fields: { name: "酒馆" }, +}); +addNode(graph, ruleA); +addNode(graph, ruleB); +addNode(graph, location); +graph.batchJournal = [{ processedRange: [2, 2] }]; + +const cooccurrence = createCooccurrenceIndex(graph, { + eligibleNodes: graph.nodes, + maxAnchorsPerBatch: 10, +}); +assert.equal(cooccurrence.source, "batchJournal"); +assert.equal(cooccurrence.batchCount, 1); +assert.ok( + (cooccurrence.map.get(ruleA.id) || []).some((item) => item.nodeId === ruleB.id), +); + +graph.batchJournal = []; +const fallbackCooccurrence = createCooccurrenceIndex(graph, { + eligibleNodes: graph.nodes, + maxAnchorsPerBatch: 10, +}); +assert.equal(fallbackCooccurrence.source, "seqRange"); +assert.ok( + (fallbackCooccurrence.map.get(ruleA.id) || []).some( + (item) => item.nodeId === ruleB.id, + ), +); + +const nmf = nmfQueryAnalysis( + [0.8, 0.6, 0, 0], + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + ], + { nTopics: 2, maxIter: 50 }, +); +assert.ok(nmf.semanticDepth >= 0); +assert.ok(nmf.novelty >= 0); + +const sparse = sparseCodeResidual( + [0.8, 0.6, 0, 0], + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + ], + { lambda: 0.01, maxIter: 100 }, +); +assert.ok(sparse.residualNorm < 0.2); + +console.log("retrieval-enhancer tests passed");