feat: enhance recall pipeline retrieval stack

This commit is contained in:
Youzini-afk
2026-03-27 19:43:40 +08:00
parent 27aad180d3
commit 84211d9b9d
11 changed files with 1943 additions and 96 deletions

View File

@@ -16,6 +16,16 @@ import {
buildTaskLlmPayload,
buildTaskPrompt,
} from "./prompt-builder.js";
import {
applyCooccurrenceBoost,
applyDiversitySampling,
collectSupplementalAnchorNodeIds,
createCooccurrenceIndex,
isEligibleAnchorNode,
mergeVectorResults,
runResidualRecall,
splitIntentSegments,
} from "./retrieval-enhancer.js";
import { applyTaskRegex } from "./task-regex.js";
import { getSTContextForPrompt } from "./st-context.js";
import { findSimilarNodesByText, validateVectorConfig } from "./vector-index.js";
@@ -59,6 +69,65 @@ function throwIfAborted(signal) {
}
}
function nowMs() {
return typeof performance !== "undefined" && performance?.now
? performance.now()
: Date.now();
}
function roundMs(value) {
return Math.round((Number(value) || 0) * 10) / 10;
}
function pushSkipReason(meta, reason) {
if (!reason) return;
if (!Array.isArray(meta.skipReasons)) {
meta.skipReasons = [];
}
if (!meta.skipReasons.includes(reason)) {
meta.skipReasons.push(reason);
}
}
function createRetrievalMeta(enableLLMRecall) {
return {
vectorHits: 0,
diffusionHits: 0,
scoredCandidates: 0,
segmentsUsed: [],
vectorMergedHits: 0,
seedCount: 0,
temporalSyntheticEdgeCount: 0,
teleportAlpha: 0,
cooccurrenceBoostedNodes: 0,
candidatePoolBeforeDpp: 0,
candidatePoolAfterDpp: 0,
diversityApplied: false,
residualTriggered: false,
residualHits: 0,
skipReasons: [],
timings: {},
llm: {
enabled: enableLLMRecall,
status: enableLLMRecall ? "pending" : "disabled",
reason: enableLLMRecall ? "" : "LLM 精排已关闭",
candidatePool: 0,
selectedSeedCount: 0,
},
};
}
function clampPositiveInt(value, fallback, min = 1) {
const parsed = Math.floor(Number(value));
return Number.isFinite(parsed) && parsed >= min ? parsed : fallback;
}
function clampRange(value, fallback, min = 0, max = 1) {
const parsed = Number(value);
if (!Number.isFinite(parsed)) return fallback;
return Math.max(min, Math.min(max, parsed));
}
/**
* 三层混合检索管线
*
@@ -83,21 +152,71 @@ export async function retrieve({
onStreamProgress = null,
}) {
throwIfAborted(signal);
const topK = options.topK ?? 20;
const maxRecallNodes = options.maxRecallNodes ?? 8;
const startedAt = nowMs();
const topK = clampPositiveInt(options.topK, 20);
const maxRecallNodes = clampPositiveInt(options.maxRecallNodes, 8);
const enableLLMRecall = options.enableLLMRecall ?? true;
const enableVectorPrefilter = options.enableVectorPrefilter ?? true;
const enableGraphDiffusion = options.enableGraphDiffusion ?? true;
const diffusionTopK = options.diffusionTopK ?? 100;
const llmCandidatePool = options.llmCandidatePool ?? 30;
const diffusionTopK = clampPositiveInt(options.diffusionTopK, 100);
const llmCandidatePool = clampPositiveInt(options.llmCandidatePool, 30);
const weights = options.weights ?? {};
// v2 options
const enableVisibility = options.enableVisibility ?? false;
const visibilityFilter = options.visibilityFilter ?? null;
const enableCrossRecall = options.enableCrossRecall ?? false;
const enableProbRecall = options.enableProbRecall ?? false;
const probRecallChance = options.probRecallChance ?? 0.15;
const enableMultiIntent = options.enableMultiIntent ?? true;
const multiIntentMaxSegments = clampPositiveInt(
options.multiIntentMaxSegments,
4,
);
const teleportAlpha = clampRange(options.teleportAlpha, 0.15);
const enableTemporalLinks = options.enableTemporalLinks ?? true;
const temporalLinkStrength = clampRange(
options.temporalLinkStrength,
0.2,
);
const enableDiversitySampling = options.enableDiversitySampling ?? true;
const dppCandidateMultiplier = clampPositiveInt(
options.dppCandidateMultiplier,
3,
);
const dppQualityWeight = clampRange(
options.dppQualityWeight,
1.0,
0,
10,
);
const enableCooccurrenceBoost = options.enableCooccurrenceBoost ?? false;
const cooccurrenceScale = clampRange(
options.cooccurrenceScale,
0.1,
0,
10,
);
const cooccurrenceMaxNeighbors = clampPositiveInt(
options.cooccurrenceMaxNeighbors,
10,
);
const enableResidualRecall = options.enableResidualRecall ?? false;
const residualBasisMaxNodes = clampPositiveInt(
options.residualBasisMaxNodes,
24,
2,
);
const residualNmfTopics = clampPositiveInt(options.residualNmfTopics, 15);
const residualNmfNoveltyThreshold = clampRange(
options.residualNmfNoveltyThreshold,
0.4,
);
const residualThreshold = clampRange(
options.residualThreshold,
0.3,
0,
10,
);
const residualTopK = clampPositiveInt(options.residualTopK, 5);
let activeNodes = getActiveNodes(graph).filter(
(node) =>
@@ -106,7 +225,6 @@ export async function retrieve({
Number.isFinite(node.seqRange[1]),
);
// v2 ⑦: 认知边界过滤RoleRAG 启发)
if (enableVisibility && visibilityFilter) {
activeNodes = filterByVisibility(activeNodes, visibilityFilter);
}
@@ -119,66 +237,124 @@ export async function retrieve({
normalizedMaxRecallNodes,
llmCandidatePool,
);
const vectorValidation = validateVectorConfig(embeddingConfig);
const retrievalMeta = createRetrievalMeta(enableLLMRecall);
console.log(
`[ST-BME] 检索开始: ${nodeCount} 个活跃节点${enableVisibility ? " (认知边界已启用)" : ""}`,
);
let vectorResults = [];
let diffusionResults = [];
let useLLM = false;
let llmMeta = {
enabled: enableLLMRecall,
status: enableLLMRecall ? "pending" : "disabled",
reason: enableLLMRecall ? "" : "LLM 精排已关闭",
candidatePool: 0,
selectedSeedCount: 0,
};
let llmMeta = { ...retrievalMeta.llm };
const exactEntityAnchors = [];
let supplementalAnchorNodeIds = [];
if (nodeCount === 0) {
return buildResult(graph, [], schema, {
retrieval: {
vectorHits: 0,
diffusionHits: 0,
scoredCandidates: 0,
...retrievalMeta,
llm: {
...llmMeta,
status: enableLLMRecall ? "skipped" : "disabled",
reason: "当前没有可参与召回的活跃节点",
},
timings: {
total: roundMs(nowMs() - startedAt),
},
},
});
}
// ========== 第 1 层:向量预筛 ==========
if (
enableVectorPrefilter &&
validateVectorConfig(embeddingConfig).valid
) {
const vectorStartedAt = nowMs();
if (enableVectorPrefilter && vectorValidation.valid) {
console.log("[ST-BME] 第1层: 向量预筛");
vectorResults = await vectorPreFilter(
graph,
userMessage,
activeNodes,
embeddingConfig,
normalizedTopK,
signal,
);
}
const segments = enableMultiIntent
? splitIntentSegments(userMessage, {
maxSegments: multiIntentMaxSegments,
})
: [];
const queries = [userMessage, ...segments.filter((item) => item !== userMessage)];
const groups = [];
// ========== 第 2 层:图扩散 ==========
retrievalMeta.segmentsUsed = segments;
for (const queryText of queries) {
const results = await vectorPreFilter(
graph,
queryText,
activeNodes,
embeddingConfig,
normalizedTopK,
signal,
);
groups.push(results);
}
const merged = mergeVectorResults(
groups,
Math.max(normalizedTopK * 2, 24),
);
retrievalMeta.vectorHits = merged.rawHitCount;
retrievalMeta.vectorMergedHits = merged.results.length;
vectorResults = merged.results;
} else if (enableVectorPrefilter) {
pushSkipReason(retrievalMeta, "vector-config-invalid");
}
retrievalMeta.timings.vector = roundMs(nowMs() - vectorStartedAt);
exactEntityAnchors.push(...extractEntityAnchors(userMessage, activeNodes));
supplementalAnchorNodeIds = collectSupplementalAnchorNodeIds(
graph,
vectorResults,
exactEntityAnchors.map((item) => item.nodeId),
5,
);
let residualResult = {
triggered: false,
hits: [],
skipReason: "",
};
const residualStartedAt = nowMs();
if (enableResidualRecall) {
const basisNodes = buildResidualBasisNodes(
graph,
exactEntityAnchors,
vectorResults,
residualBasisMaxNodes,
);
residualResult = await runResidualRecall({
queryText: userMessage,
graph,
embeddingConfig,
basisNodes,
candidateNodes: activeNodes,
basisLimit: residualBasisMaxNodes,
nTopics: residualNmfTopics,
noveltyThreshold: residualNmfNoveltyThreshold,
residualThreshold,
residualTopK,
signal,
});
retrievalMeta.residualTriggered = Boolean(residualResult.triggered);
retrievalMeta.residualHits = residualResult.hits?.length || 0;
pushSkipReason(retrievalMeta, residualResult.skipReason);
}
retrievalMeta.timings.residual = roundMs(nowMs() - residualStartedAt);
const diffusionStartedAt = nowMs();
if (enableGraphDiffusion) {
console.log("[ST-BME] 第2层: PEDSA 图扩散");
const entityAnchors = extractEntityAnchors(userMessage, activeNodes);
const seeds = [
...vectorResults.map((v) => ({ id: v.nodeId, energy: v.score })),
...entityAnchors.map((a) => ({ id: a.nodeId, energy: 2.0 })),
...exactEntityAnchors.map((item) => ({ id: item.nodeId, energy: 2.0 })),
...(residualResult.hits || []).map((item) => ({
id: item.nodeId,
energy: item.score,
})),
];
// v2 ⑧: 双记忆交叉检索AriGraph 启发)
// 实体锚点命中后,沿边展开关联的情景节点作为额外种子
if (enableCrossRecall && entityAnchors.length > 0) {
for (const anchor of entityAnchors) {
if (enableCrossRecall && exactEntityAnchors.length > 0) {
for (const anchor of exactEntityAnchors) {
const connectedEdges = getNodeEdges(graph, anchor.nodeId);
for (const edge of connectedEdges) {
if (edge.invalidAt) continue;
@@ -192,7 +368,6 @@ export async function retrieve({
}
}
// 去重种子
const seedMap = new Map();
for (const s of seeds) {
const existing = seedMap.get(s.id) || 0;
@@ -202,41 +377,46 @@ export async function retrieve({
id,
energy,
}));
retrievalMeta.seedCount = uniqueSeeds.length;
if (uniqueSeeds.length > 0) {
const adjacencyMap = buildTemporalAdjacencyMap(graph);
const adjacencyMap = buildTemporalAdjacencyMap(graph, {
includeTemporalLinks: enableTemporalLinks,
temporalLinkStrength,
});
retrievalMeta.temporalSyntheticEdgeCount =
Number(adjacencyMap.syntheticEdgeCount) || 0;
retrievalMeta.teleportAlpha = teleportAlpha;
diffusionResults = diffuseAndRank(adjacencyMap, uniqueSeeds, {
maxSteps: 2,
decayFactor: 0.6,
topK: normalizedDiffusionTopK,
teleportAlpha,
}).filter((item) => {
const node = getNode(graph, item.nodeId);
return node && !node.archived;
});
}
}
retrievalMeta.diffusionHits = diffusionResults.length;
retrievalMeta.timings.diffusion = roundMs(nowMs() - diffusionStartedAt);
// ========== 第 3 层:混合评分 + 可选 LLM 精确 ==========
console.log("[ST-BME] 第3层: 混合评分");
// 构建评分表
const scoreMap = new Map();
// 添加向量得分
for (const v of vectorResults) {
const entry = scoreMap.get(v.nodeId) || { graphScore: 0, vectorScore: 0 };
entry.vectorScore = v.score;
scoreMap.set(v.nodeId, entry);
}
// 添加图扩散得分
for (const d of diffusionResults) {
const entry = scoreMap.get(d.nodeId) || { graphScore: 0, vectorScore: 0 };
entry.graphScore = d.energy;
scoreMap.set(d.nodeId, entry);
}
// 两个上游阶段都未产出候选时,退回到全部活跃节点参与评分
if (scoreMap.size === 0) {
for (const node of activeNodes) {
if (!scoreMap.has(node.id)) {
@@ -245,7 +425,60 @@ export async function retrieve({
}
}
// 计算混合得分
const cooccurrenceStartedAt = nowMs();
if (enableCooccurrenceBoost) {
const anchorWeights = new Map();
for (const anchor of exactEntityAnchors) {
anchorWeights.set(anchor.nodeId, 2.0);
}
for (const nodeId of supplementalAnchorNodeIds) {
const fallbackWeight =
scoreMap.get(nodeId)?.vectorScore ||
scoreMap.get(nodeId)?.graphScore ||
0.5;
anchorWeights.set(
nodeId,
Math.max(anchorWeights.get(nodeId) || 0, fallbackWeight),
);
}
if (anchorWeights.size > 0) {
const cooccurrenceIndex = createCooccurrenceIndex(graph, {
maxAnchorsPerBatch: 10,
eligibleNodes: activeNodes.filter(isEligibleAnchorNode),
});
const graphScores = new Map(
[...scoreMap.entries()].map(([nodeId, value]) => [
nodeId,
value.graphScore || 0,
]),
);
const boosted = applyCooccurrenceBoost(
graphScores,
anchorWeights,
cooccurrenceIndex,
{
scale: cooccurrenceScale,
maxNeighbors: cooccurrenceMaxNeighbors,
},
);
retrievalMeta.cooccurrenceBoostedNodes = boosted.boostedNodes.length;
for (const [nodeId, boostedScore] of boosted.scores.entries()) {
const entry = scoreMap.get(nodeId) || { graphScore: 0, vectorScore: 0 };
entry.graphScore = boostedScore;
scoreMap.set(nodeId, entry);
}
if (boosted.boostedNodes.length === 0) {
pushSkipReason(retrievalMeta, "cooccurrence-no-neighbors");
}
} else {
pushSkipReason(retrievalMeta, "cooccurrence-no-anchor");
}
}
retrievalMeta.timings.cooccurrence = roundMs(nowMs() - cooccurrenceStartedAt);
const scoringStartedAt = nowMs();
const scoredNodes = [];
for (const [nodeId, scores] of scoreMap) {
const node = getNode(graph, nodeId);
@@ -265,22 +498,29 @@ export async function retrieve({
}
scoredNodes.sort((a, b) => b.finalScore - a.finalScore);
// 决定是否使用 LLM 精确召回
useLLM = enableLLMRecall;
retrievalMeta.scoredCandidates = scoredNodes.length;
retrievalMeta.timings.scoring = roundMs(nowMs() - scoringStartedAt);
let selectedNodeIds;
let llmCandidates = [];
const diversityStartedAt = nowMs();
let llmDurationMs = 0;
if (useLLM && nodeCount > 0) {
if (enableLLMRecall && nodeCount > 0) {
console.log("[ST-BME] LLM 精确召回");
const candidateNodes = scoredNodes.slice(
0,
Math.min(normalizedLlmCandidatePool, scoredNodes.length),
llmCandidates = resolveCandidatePool(
scoredNodes,
normalizedLlmCandidatePool,
dppCandidateMultiplier,
enableDiversitySampling,
dppQualityWeight,
retrievalMeta,
);
const llmStartedAt = nowMs();
const llmResult = await llmRecall(
userMessage,
recentMessages,
candidateNodes,
llmCandidates,
graph,
schema,
normalizedMaxRecallNodes,
@@ -289,18 +529,25 @@ export async function retrieve({
signal,
onStreamProgress,
);
llmDurationMs = nowMs() - llmStartedAt;
selectedNodeIds = llmResult.selectedNodeIds;
llmMeta = {
enabled: true,
status: llmResult.status,
reason: llmResult.reason,
candidatePool: candidateNodes.length,
candidatePool: llmCandidates.length,
selectedSeedCount: llmResult.selectedNodeIds.length,
};
} else {
selectedNodeIds = scoredNodes
.slice(0, Math.min(normalizedTopK, scoredNodes.length))
.map((s) => s.nodeId);
const selectedCandidates = resolveCandidatePool(
scoredNodes,
normalizedTopK,
dppCandidateMultiplier,
enableDiversitySampling,
dppQualityWeight,
retrievalMeta,
);
selectedNodeIds = selectedCandidates.map((item) => item.nodeId);
llmMeta = {
enabled: false,
status: "disabled",
@@ -309,6 +556,8 @@ export async function retrieve({
selectedSeedCount: selectedNodeIds.length,
};
}
retrievalMeta.timings.diversity = roundMs(nowMs() - diversityStartedAt);
retrievalMeta.timings.llm = roundMs(llmDurationMs);
selectedNodeIds = reconstructSceneNodeIds(
graph,
@@ -325,8 +574,6 @@ export async function retrieve({
console.log(`[ST-BME] 检索完成: 选中 ${selectedNodeIds.length} 个节点`);
// v2 ⑧: 概率触发回忆
// 未被选中的高重要性节点有概率随机激活
if (enableProbRecall && probRecallChance > 0) {
const selectedSet = new Set(selectedNodeIds);
const probability = Math.max(0.01, Math.min(0.5, probRecallChance));
@@ -351,14 +598,11 @@ export async function retrieve({
}
selectedNodeIds = uniqueNodeIds(selectedNodeIds);
retrievalMeta.llm = llmMeta;
retrievalMeta.timings.total = roundMs(nowMs() - startedAt);
return buildResult(graph, selectedNodeIds, schema, {
retrieval: {
vectorHits: vectorResults.length,
diffusionHits: diffusionResults.length,
scoredCandidates: scoredNodes.length,
llm: llmMeta,
},
retrieval: retrievalMeta,
});
}
@@ -418,6 +662,84 @@ function extractEntityAnchors(userMessage, activeNodes) {
return anchors;
}
function buildResidualBasisNodes(
graph,
exactEntityAnchors,
vectorResults,
maxNodes = 24,
) {
const basis = [];
const seen = new Set();
for (const anchor of exactEntityAnchors || []) {
const node = getNode(graph, anchor?.nodeId);
if (
!node ||
seen.has(node.id) ||
!Array.isArray(node.embedding) ||
node.embedding.length === 0
) {
continue;
}
seen.add(node.id);
basis.push(node);
if (basis.length >= maxNodes) return basis;
}
for (const result of vectorResults || []) {
const node = getNode(graph, result?.nodeId);
if (
!isEligibleAnchorNode(node) ||
seen.has(node?.id) ||
!Array.isArray(node?.embedding) ||
node.embedding.length === 0
) {
continue;
}
seen.add(node.id);
basis.push(node);
if (basis.length >= maxNodes) break;
}
return basis;
}
function resolveCandidatePool(
scoredNodes,
targetCount,
multiplier,
enableDiversitySampling,
qualityWeight,
retrievalMeta,
) {
const safeTarget = Math.max(1, targetCount);
const fallback = scoredNodes.slice(0, Math.min(safeTarget, scoredNodes.length));
retrievalMeta.candidatePoolBeforeDpp = fallback.length;
retrievalMeta.candidatePoolAfterDpp = fallback.length;
retrievalMeta.diversityApplied = false;
if (!enableDiversitySampling) {
return fallback;
}
const poolLimit = Math.min(
scoredNodes.length,
Math.max(safeTarget, safeTarget * Math.max(1, multiplier)),
);
const pool = scoredNodes.slice(0, poolLimit);
retrievalMeta.candidatePoolBeforeDpp = pool.length;
const diversity = applyDiversitySampling(pool, {
k: safeTarget,
qualityWeight,
});
retrievalMeta.candidatePoolAfterDpp = diversity.afterCount;
retrievalMeta.diversityApplied = diversity.applied;
pushSkipReason(retrievalMeta, diversity.reason);
return diversity.applied ? diversity.selected : fallback;
}
/**
* LLM 精确召回
*/