feat: improve retrieval recall and maintenance undo

Youzini-afk
2026-04-01 22:37:29 +08:00
parent 1dc87245a7
commit 6f8554e11a
10 changed files with 1550 additions and 63 deletions


@@ -126,10 +126,15 @@ function createRetrievalMeta(enableLLMRecall) {
diffusionHits: 0,
scoredCandidates: 0,
segmentsUsed: [],
queryBlendActive: false,
queryBlendParts: [],
queryBlendWeights: {},
vectorMergedHits: 0,
seedCount: 0,
temporalSyntheticEdgeCount: 0,
teleportAlpha: 0,
lexicalBoostedNodes: 0,
lexicalTopHits: [],
cooccurrenceBoostedNodes: 0,
candidatePoolBeforeDpp: 0,
candidatePoolAfterDpp: 0,
@@ -159,6 +164,421 @@ function clampRange(value, fallback, min = 0, max = 1) {
return Math.max(min, Math.min(max, parsed));
}
function normalizeQueryText(value, maxLength = 400) {
const normalized = String(value ?? "")
.replace(/\r\n/g, "\n")
.replace(/\s+/g, " ")
.trim();
if (!normalized) return "";
return normalized.slice(0, Math.max(1, maxLength));
}
function createTextPreview(text, maxLength = 120) {
const normalized = normalizeQueryText(text, maxLength + 4);
if (!normalized) return "";
return normalized.length > maxLength
? `${normalized.slice(0, maxLength)}...`
: normalized;
}
function roundBlendWeight(value) {
return Math.round((Number(value) || 0) * 1000) / 1000;
}
function uniqueStrings(values = [], maxLength = 400) {
const result = [];
const seen = new Set();
for (const value of values) {
const text = normalizeQueryText(value, maxLength);
const key = text.toLowerCase();
if (!text || seen.has(key)) continue;
seen.add(key);
result.push(text);
}
return result;
}
function parseRecallContextLine(line = "") {
const raw = String(line ?? "").trim();
if (!raw) return null;
const bracketMatch = raw.match(/^\[(user|assistant)\]\s*:\s*([\s\S]*)$/i);
if (bracketMatch) {
const role = String(bracketMatch[1] || "").toLowerCase();
const text = normalizeQueryText(bracketMatch[2] || "");
return text ? { role, text } : null;
}
const plainMatch = raw.match(
/^(user|assistant|用户|助手|ai)\s*[:：]\s*([\s\S]*)$/i,
);
if (!plainMatch) return null;
const roleToken = String(plainMatch[1] || "").toLowerCase();
const role =
roleToken === "assistant" || roleToken === "助手" || roleToken === "ai"
? "assistant"
: "user";
const text = normalizeQueryText(plainMatch[2] || "");
return text ? { role, text } : null;
}
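
For illustration, here is how the two supported line formats parse; the inputs are hypothetical, and the results follow directly from the regexes above:

// Illustrative only.
// parseRecallContextLine("[assistant]: Sure, here is the plan")
//   -> { role: "assistant", text: "Sure, here is the plan" }
// parseRecallContextLine("用户: 帮我总结一下")  // Chinese role token maps to "user"
//   -> { role: "user", text: "帮我总结一下" }
// parseRecallContextLine("random line") -> null
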
function buildContextQueryBlend(
userMessage,
recentMessages = [],
{
enabled = true,
assistantWeight = 0.2,
previousUserWeight = 0.1,
maxTextLength = 400,
} = {},
) {
const currentText = normalizeQueryText(userMessage, maxTextLength);
const normalizedAssistantWeight = clampRange(assistantWeight, 0.2, 0, 1);
const normalizedPreviousUserWeight = clampRange(
previousUserWeight,
0.1,
0,
1,
);
const currentWeight = Math.max(
0,
1 - normalizedAssistantWeight - normalizedPreviousUserWeight,
);
let assistantText = "";
let previousUserText = "";
const parsedMessages = Array.isArray(recentMessages)
? recentMessages.map((line) => parseRecallContextLine(line)).filter(Boolean)
: [];
for (let index = parsedMessages.length - 1; index >= 0; index--) {
const item = parsedMessages[index];
if (!assistantText && item.role === "assistant") {
assistantText = normalizeQueryText(item.text, maxTextLength);
}
if (
!previousUserText &&
item.role === "user" &&
normalizeQueryText(item.text, maxTextLength).toLowerCase() !==
currentText.toLowerCase()
) {
previousUserText = normalizeQueryText(item.text, maxTextLength);
}
if (assistantText && previousUserText) break;
}
const rawParts = [
{
kind: "currentUser",
label: "当前用户消息",
text: currentText,
weight: enabled ? currentWeight : 1,
},
];
if (enabled && assistantText) {
rawParts.push({
kind: "assistantContext",
label: "最近 assistant 回复",
text: assistantText,
weight: normalizedAssistantWeight,
});
}
if (enabled && previousUserText) {
rawParts.push({
kind: "previousUser",
label: "上一条 user 消息",
text: previousUserText,
weight: normalizedPreviousUserWeight,
});
}
const dedupedParts = [];
const seen = new Set();
for (const part of rawParts) {
const text = normalizeQueryText(part.text, maxTextLength);
const key = text.toLowerCase();
if (!text || seen.has(key)) continue;
seen.add(key);
dedupedParts.push({
...part,
text,
});
}
if (dedupedParts.length === 0) {
return {
active: false,
parts: [],
currentText: "",
assistantText: "",
previousUserText: "",
combinedText: "",
};
}
const totalWeight = dedupedParts.reduce(
(sum, part) => sum + Math.max(0, Number(part.weight) || 0),
0,
);
const normalizedParts = dedupedParts.map((part) => ({
...part,
weight:
totalWeight > 0
? roundBlendWeight(Math.max(0, Number(part.weight) || 0) / totalWeight)
: roundBlendWeight(1 / dedupedParts.length),
}));
const combinedText =
normalizedParts.length <= 1
? normalizedParts[0]?.text || ""
: normalizedParts
.map((part) => `${part.label}:\n${part.text}`)
.join("\n\n");
return {
active: enabled && normalizedParts.length > 1,
parts: normalizedParts,
currentText: currentText || normalizedParts[0]?.text || "",
assistantText,
previousUserText,
combinedText,
};
}
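
A minimal usage sketch with hypothetical inputs; with the default weights above, the three parts normalize to a 0.7 / 0.2 / 0.1 split:

// Illustrative only; inputs are made up for the example.
const blend = buildContextQueryBlend(
  "What did we decide about the cache?",
  [
    "[user]: Let's discuss the cache layer",
    "[assistant]: We agreed on an LRU cache",
  ],
);
// blend.active === true, blend.parts.length === 3
// weights: currentUser 0.7, assistantContext 0.2, previousUser 0.1
// blend.combinedText joins each part as "<label>:\n<text>" blocks
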
function buildVectorQueryPlan(
blendPlan,
{ enableMultiIntent = true, maxSegments = 4 } = {},
) {
const plan = [];
let currentSegments = [];
for (const part of blendPlan?.parts || []) {
let queries;
if (part.kind === "currentUser" && enableMultiIntent) {
currentSegments = splitIntentSegments(part.text, { maxSegments });
queries = uniqueStrings([
part.text,
...currentSegments.filter((item) => item !== part.text),
]);
} else {
queries = uniqueStrings([part.text]);
}
plan.push({
kind: part.kind,
label: part.label,
weight: part.weight,
queries,
});
}
return {
plan,
currentSegments,
};
}
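
Continuing the sketch, the vector plan fans each blend part out into one or more queries. The exact segment strings depend on splitIntentSegments, which is defined elsewhere, so the shape below is an assumption:

// Illustrative shape only.
const { plan, currentSegments } = buildVectorQueryPlan(blend, {
  enableMultiIntent: true,
  maxSegments: 4,
});
// plan resembles:
// [
//   { kind: "currentUser", weight: 0.7, queries: [fullText, ...intentSegments] },
//   { kind: "assistantContext", weight: 0.2, queries: [assistantText] },
//   { kind: "previousUser", weight: 0.1, queries: [previousUserText] },
// ]
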
function buildLexicalQuerySources(
userMessage,
{ enableMultiIntent = true, maxSegments = 4 } = {},
) {
const currentText = normalizeQueryText(userMessage, 400);
const segments = enableMultiIntent
? splitIntentSegments(currentText, { maxSegments })
: [];
return {
sources: uniqueStrings([currentText, ...segments]),
segments,
};
}
function normalizeLexicalText(value = "") {
return normalizeQueryText(value, 600).toLowerCase();
}
function buildLexicalUnits(text = "") {
const normalized = normalizeLexicalText(text);
if (!normalized) return [];
const rawTokens = normalized.match(/[a-z0-9]+|[\u4e00-\u9fff]+/g) || [];
const units = [];
for (const token of rawTokens) {
if (token.length >= 2) {
units.push(token);
}
if (/[\u4e00-\u9fff]/.test(token) && token.length > 2) {
for (let index = 0; index < token.length - 1; index++) {
units.push(token.slice(index, index + 2));
}
}
}
return [...new Set(units)];
}
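
For intuition, mixed Latin/CJK input tokenizes as below; CJK runs longer than two characters also emit overlapping bigrams, which is what lets partial Chinese phrases match:

// Illustrative only:
// buildLexicalUnits("Cache 缓存策略")
//   -> ["cache", "缓存策略", "缓存", "存策", "策略"]
// (Latin tokens shorter than 2 chars and single CJK characters are dropped.)
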
function computeTokenOverlapScore(sourceUnits = [], targetUnits = []) {
if (!sourceUnits.length || !targetUnits.length) return 0;
const targetSet = new Set(targetUnits);
let overlap = 0;
for (const unit of sourceUnits) {
if (targetSet.has(unit)) {
overlap += 1;
}
}
return overlap / Math.max(1, sourceUnits.length);
}
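
The overlap score is directional: it is the fraction of source units that appear in the target. A worked example:

// Illustrative only:
// computeTokenOverlapScore(["cache", "lru", "policy"], ["cache", "policy"])
//   -> 2 / 3 ≈ 0.667 (two of the three source units hit the target set)
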
function scoreFieldMatch(
fieldText,
querySources = [],
{ exact = 1, includes = 0.9, overlap = 0.6 } = {},
) {
const normalizedField = normalizeLexicalText(fieldText);
if (!normalizedField) return 0;
const fieldUnits = buildLexicalUnits(normalizedField);
let best = 0;
for (const sourceText of querySources) {
const normalizedSource = normalizeLexicalText(sourceText);
if (!normalizedSource) continue;
if (normalizedSource === normalizedField) {
best = Math.max(best, exact);
continue;
}
if (
Math.min(normalizedSource.length, normalizedField.length) >= 2 &&
(normalizedSource.includes(normalizedField) ||
normalizedField.includes(normalizedSource))
) {
best = Math.max(best, includes);
}
const overlapScore = computeTokenOverlapScore(
buildLexicalUnits(normalizedSource),
fieldUnits,
);
best = Math.max(best, overlapScore * overlap);
}
return Math.min(1, best);
}
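
Each match tier caps its contribution (exact > includes > token overlap), and the best tier wins. A worked example with the default caps:

// Illustrative only:
// scoreFieldMatch("cache policy", ["cache policy"]) -> 1    (exact match)
// scoreFieldMatch("cache policy", ["policy"])       -> 0.9  (substring, includes cap)
// scoreFieldMatch("cache policy", ["cache rules"])  -> 0.3  (overlap 1/2 * 0.6 cap)
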
function collectNodeLexicalTexts(node, fieldNames = []) {
const values = [];
for (const fieldName of fieldNames) {
const value = node?.fields?.[fieldName];
if (typeof value === "string" && value.trim()) {
values.push(value.trim());
} else if (Array.isArray(value)) {
for (const item of value) {
if (typeof item === "string" && item.trim()) {
values.push(item.trim());
}
}
}
}
return values;
}
function computeLexicalScore(node, querySources = []) {
if (!node || !Array.isArray(querySources) || querySources.length === 0) {
return 0;
}
const primaryTexts = collectNodeLexicalTexts(node, ["name", "title"]);
const secondaryTexts = collectNodeLexicalTexts(node, [
"summary",
"insight",
"state",
"traits",
"participants",
"status",
]);
const combinedText = [...primaryTexts, ...secondaryTexts].join(" ");
const primaryScore = primaryTexts.reduce(
(best, value) =>
Math.max(
best,
scoreFieldMatch(value, querySources, {
exact: 1,
includes: 0.92,
overlap: 0.72,
}),
),
0,
);
const secondaryScore = secondaryTexts.reduce(
(best, value) =>
Math.max(
best,
scoreFieldMatch(value, querySources, {
exact: 0.82,
includes: 0.68,
overlap: 0.52,
}),
),
0,
);
const tokenScore = scoreFieldMatch(combinedText, querySources, {
exact: 0.65,
includes: 0.55,
overlap: 0.45,
});
if (primaryScore <= 0 && secondaryScore <= 0 && tokenScore <= 0) {
return 0;
}
return Math.min(
1,
Math.max(
primaryScore,
secondaryScore * 0.82,
tokenScore * 0.7,
primaryScore * 0.75 + secondaryScore * 0.35 + tokenScore * 0.2,
),
);
}
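
A quick sanity check with a hypothetical node (field names follow the lists above). An exact match on a primary field such as name saturates the score:

// Illustrative only.
const node = {
  type: "entity",
  fields: { name: "缓存策略", summary: "LRU cache eviction policy" },
};
// computeLexicalScore(node, ["缓存策略"]) -> 1 (exact primary-field match)
// computeLexicalScore(node, ["eviction policy"]) -> secondary/token tiers apply, < 1
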
function buildLexicalTopHits(scoredNodes = [], maxCount = 5) {
return scoredNodes
.filter((item) => (Number(item?.lexicalScore) || 0) > 0)
.sort((a, b) => {
const lexicalDelta =
(Number(b?.lexicalScore) || 0) - (Number(a?.lexicalScore) || 0);
if (lexicalDelta !== 0) return lexicalDelta;
return (Number(b?.finalScore) || 0) - (Number(a?.finalScore) || 0);
})
.slice(0, Math.max(1, maxCount))
.map((item) => ({
nodeId: item.nodeId,
type: item.node?.type || "",
label:
item.node?.fields?.name ||
item.node?.fields?.title ||
item.node?.fields?.summary ||
item.nodeId,
lexicalScore: Math.round((Number(item.lexicalScore) || 0) * 1000) / 1000,
finalScore: Math.round((Number(item.finalScore) || 0) * 1000) / 1000,
}));
}
function scaleVectorResults(results = [], weight = 1) {
return (Array.isArray(results) ? results : []).map((item) => ({
...item,
score: (Number(item?.score) || 0) * Math.max(0, Number(weight) || 0),
}));
}
/**
* Three-layer hybrid retrieval pipeline
*
@@ -248,6 +668,21 @@ export async function retrieve({
10,
);
const residualTopK = clampPositiveInt(options.residualTopK, 5);
const enableContextQueryBlend = options.enableContextQueryBlend ?? true;
const contextAssistantWeight = clampRange(
options.contextAssistantWeight,
0.2,
0,
1,
);
const contextPreviousUserWeight = clampRange(
options.contextPreviousUserWeight,
0.1,
0,
1,
);
const enableLexicalBoost = options.enableLexicalBoost ?? true;
const lexicalWeight = clampRange(options.lexicalWeight, 0.18, 0, 10);
let activeNodes = getActiveNodes(graph).filter(
(node) =>
@@ -270,6 +705,29 @@ export async function retrieve({
);
const vectorValidation = validateVectorConfig(embeddingConfig);
const retrievalMeta = createRetrievalMeta(enableLLMRecall);
const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, {
enabled: enableContextQueryBlend,
assistantWeight: contextAssistantWeight,
previousUserWeight: contextPreviousUserWeight,
});
retrievalMeta.queryBlendActive = contextQueryBlend.active;
retrievalMeta.queryBlendParts = (contextQueryBlend.parts || []).map((part) => ({
kind: part.kind,
label: part.label,
weight: part.weight,
text: createTextPreview(part.text),
length: part.text.length,
}));
retrievalMeta.queryBlendWeights = Object.fromEntries(
(contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]),
);
const lexicalQuery = buildLexicalQuerySources(
contextQueryBlend.currentText || userMessage,
{
enableMultiIntent,
maxSegments: multiIntentMaxSegments,
},
);
console.log(
`[ST-BME] Retrieval start: ${nodeCount} active nodes${enableVisibility ? " (cognitive boundary enabled)" : ""}`,
);
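
For reference, a hypothetical options object exercising the new knobs; key names match the destructuring above, and the values shown are the fallbacks the code applies when a key is absent:

// Illustrative only.
const options = {
  enableContextQueryBlend: true,   // blend recent context into the query
  contextAssistantWeight: 0.2,     // clamped to [0, 1]
  contextPreviousUserWeight: 0.1,  // clamped to [0, 1]
  enableLexicalBoost: true,
  lexicalWeight: 0.18,             // clamped to [0, 10]
};
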
@@ -299,25 +757,25 @@ export async function retrieve({
const vectorStartedAt = nowMs();
if (enableVectorPrefilter && vectorValidation.valid) {
console.log("[ST-BME] 第1层: 向量预筛");
- const segments = enableMultiIntent
-   ? splitIntentSegments(userMessage, {
-       maxSegments: multiIntentMaxSegments,
-     })
-   : [];
- const queries = [userMessage, ...segments.filter((item) => item !== userMessage)];
+ const queryPlan = buildVectorQueryPlan(contextQueryBlend, {
+   enableMultiIntent,
+   maxSegments: multiIntentMaxSegments,
+ });
const groups = [];
- retrievalMeta.segmentsUsed = segments;
- for (const queryText of queries) {
-   const results = await vectorPreFilter(
-     graph,
-     queryText,
-     activeNodes,
-     embeddingConfig,
-     normalizedTopK,
-     signal,
-   );
-   groups.push(results);
+ retrievalMeta.segmentsUsed = queryPlan.currentSegments;
+ for (const part of queryPlan.plan) {
+   for (const queryText of part.queries) {
+     const results = await vectorPreFilter(
+       graph,
+       queryText,
+       activeNodes,
+       embeddingConfig,
+       normalizedTopK,
+       signal,
+     );
+     groups.push(scaleVectorResults(results, part.weight || 1));
+   }
}
const merged = mergeVectorResults(
@@ -332,7 +790,12 @@ export async function retrieve({
}
retrievalMeta.timings.vector = roundMs(nowMs() - vectorStartedAt);
- exactEntityAnchors.push(...extractEntityAnchors(userMessage, activeNodes));
+ exactEntityAnchors.push(
+   ...extractEntityAnchors(
+     contextQueryBlend.currentText || userMessage,
+     activeNodes,
+   ),
+ );
supplementalAnchorNodeIds = collectSupplementalAnchorNodeIds(
graph,
vectorResults,
@@ -354,7 +817,7 @@ export async function retrieve({
residualBasisMaxNodes,
);
residualResult = await runResidualRecall({
- queryText: userMessage,
+ queryText: contextQueryBlend.combinedText || userMessage,
graph,
embeddingConfig,
basisNodes,
@@ -514,22 +977,39 @@ export async function retrieve({
for (const [nodeId, scores] of scoreMap) {
const node = getNode(graph, nodeId);
if (!node || node.archived) continue;
+ const lexicalScore = enableLexicalBoost
+   ? computeLexicalScore(node, lexicalQuery.sources)
+   : 0;
const finalScore = hybridScore(
{
graphScore: scores.graphScore,
vectorScore: scores.vectorScore,
+ lexicalScore,
importance: node.importance,
createdTime: node.createdTime,
},
- weights,
+ {
+   ...weights,
+   lexicalWeight: enableLexicalBoost ? lexicalWeight : 0,
+ },
);
- scoredNodes.push({ nodeId, node, finalScore, ...scores });
+ scoredNodes.push({
+   nodeId,
+   node,
+   finalScore,
+   lexicalScore,
+   ...scores,
+ });
}
scoredNodes.sort((a, b) => b.finalScore - a.finalScore);
retrievalMeta.scoredCandidates = scoredNodes.length;
+ retrievalMeta.lexicalBoostedNodes = scoredNodes.filter(
+   (item) => (Number(item.lexicalScore) || 0) > 0,
+ ).length;
+ retrievalMeta.lexicalTopHits = buildLexicalTopHits(scoredNodes);
retrievalMeta.timings.scoring = roundMs(nowMs() - scoringStartedAt);
let selectedNodeIds;