feat: shared ranking core + prompt node references; recall reuses shared core for base query/vector/diffusion; remove retriever-local duplicate helpers; add regression tests

This commit is contained in:
Youzini-afk
2026-04-12 14:59:22 +08:00
parent 4b4f77caff
commit dc5051f2ef
8 changed files with 1855 additions and 402 deletions

View File

@@ -16,7 +16,7 @@ import {
updateNode,
} from "../graph/graph.js";
import { callLLMForJSON } from "../llm/llm.js";
import { ensureEventTitle, getNodeDisplayName } from "../graph/node-labels.js";
import { ensureEventTitle } from "../graph/node-labels.js";
import {
normalizeMemoryScope,
isObjectiveScope,
@@ -41,8 +41,10 @@ import {
buildTaskLlmPayload,
buildTaskPrompt,
} from "../prompting/prompt-builder.js";
import { createPromptNodeReferenceMap } from "../prompting/prompt-node-references.js";
import { RELATION_TYPES } from "../graph/schema.js";
import { applyTaskRegex } from "../prompting/task-regex.js";
import { rankNodesForTaskContext } from "../retrieval/shared-ranking.js";
import { getSTContextForPrompt, getSTContextSnapshot } from "../host/st-context.js";
import { buildExtractionInputContext } from "./extraction-context.js";
import {
@@ -148,6 +150,62 @@ function resolveExtractPromptStructuredMode(settings) {
return "both";
}
function formatExtractRankingMessage(message = {}) {
const role = String(message?.role || "assistant").trim().toLowerCase() === "user"
? "user"
: "assistant";
const content = String(message?.content || "").trim();
if (!content) return "";
return `[${role}]: ${content}`;
}
function buildExtractRankingQueryText(messages = []) {
const normalizedMessages = Array.isArray(messages) ? messages : [];
const targetLines = normalizedMessages
.filter((message) => message?.isContextOnly !== true)
.map((message) => formatExtractRankingMessage(message))
.filter(Boolean);
if (targetLines.length > 0) {
return targetLines.join("\n");
}
return normalizedMessages
.map((message) => formatExtractRankingMessage(message))
.filter(Boolean)
.join("\n");
}
function buildExtractRelevantNodeReferenceMap(scoredNodes = [], schema = [], maxCount = 6) {
const typeLabelById = new Map(
(Array.isArray(schema) ? schema : []).map((typeDef) => [
String(typeDef?.id || "").trim(),
String(typeDef?.label || typeDef?.id || "").trim(),
]),
);
const relevantNodes = (Array.isArray(scoredNodes) ? scoredNodes : [])
.filter((entry) =>
entry?.node &&
!entry.node.archived &&
((Number(entry?.vectorScore) || 0) > 0 ||
(Number(entry?.graphScore) || 0) > 0 ||
(Number(entry?.lexicalScore) || 0) > 0),
)
.slice(0, Math.max(1, maxCount));
return createPromptNodeReferenceMap(relevantNodes, {
prefix: "G",
maxLength: 28,
buildMeta: ({ entry, node }) => ({
typeLabel:
typeLabelById.get(String(node?.type || "").trim()) ||
String(node?.type || "节点").trim() ||
"节点",
score:
Math.round(
(Number(entry?.weightedScore ?? entry?.finalScore) || 0) * 1000,
) / 1000,
}),
});
}
function isAbortError(error) {
return error?.name === "AbortError";
}
@@ -900,8 +958,31 @@ export async function extractMemories({
? dialogueText
: structuredMessages;
const extractGraphRankingQuery = buildExtractRankingQueryText(structuredMessages);
const extractGraphRanking =
graph?.nodes?.some((node) => !node?.archived) && extractGraphRankingQuery
? await rankNodesForTaskContext({
graph,
userMessage: extractGraphRankingQuery,
recentMessages: [],
embeddingConfig,
signal,
options: {
topK: 12,
diffusionTopK: 48,
enableContextQueryBlend: false,
enableMultiIntent: true,
maxTextLength: 1200,
},
})
: null;
const extractGraphRelevantNodes = buildExtractRelevantNodeReferenceMap(
extractGraphRanking?.scoredNodes,
schema,
);
// 构建当前图概览(让 LLM 知道已有哪些节点,避免重复)
const graphOverview = buildGraphOverview(graph, schema);
const graphOverview = buildGraphOverview(graph, schema, extractGraphRelevantNodes);
// 构建 Schema 描述
const schemaDescription = buildSchemaDescription(schema);
@@ -924,6 +1005,14 @@ export async function extractMemories({
`storyTimeContext=${storyTimeContext ? "present" : "none"}, ` +
`worldbookMode=${String(settings?.extractWorldbookMode || "active")}`,
);
if (extractGraphRanking) {
debugLog(
`[ST-BME][extract-graph] relevantNodes=${extractGraphRelevantNodes.references.length}, ` +
`vectorMergedHits=${Number(extractGraphRanking?.diagnostics?.vectorMergedHits || 0)}, ` +
`diffusionHits=${Number(extractGraphRanking?.diagnostics?.diffusionHits || 0)}, ` +
`lexicalBoostedNodes=${Number(extractGraphRanking?.diagnostics?.lexicalBoostedNodes || 0)}`,
);
}
const extractWorldbookMode = String(settings?.extractWorldbookMode || "active").trim().toLowerCase();
const promptBuild = await buildTaskPrompt(settings, "extract", {
@@ -1753,21 +1842,31 @@ async function generateNodeEmbeddings(graph, embeddingConfig, signal) {
/**
* 构建图谱概览文本(给 LLM 看)
*/
function buildGraphOverview(graph, schema) {
function buildGraphOverview(graph, schema, relevantReferenceMap = null) {
const activeNodes = graph.nodes
.filter((n) => !n.archived)
.sort((a, b) => (a.seq || 0) - (b.seq || 0));
if (activeNodes.length === 0) return "";
const lines = [];
lines.push("### 图谱节点统计");
for (const typeDef of schema) {
const nodesOfType = activeNodes.filter((n) => n.type === typeDef.id);
if (nodesOfType.length === 0) continue;
lines.push(`### ${typeDef.label} (${nodesOfType.length} 个节点)`);
for (const node of nodesOfType.slice(-10)) {
// 只展示最近 10 个
lines.push(` - [${node.id}] ${getNodeDisplayName(node)}`);
lines.push(` - ${typeDef.label}: ${nodesOfType.length}`);
}
const references = Array.isArray(relevantReferenceMap?.references)
? relevantReferenceMap.references
: [];
if (references.length > 0) {
lines.push("", "### 与当前提取片段最相关的既有节点");
for (const reference of references) {
const typeLabel = String(reference?.meta?.typeLabel || reference?.meta?.type || "节点").trim() || "节点";
const label = String(reference?.meta?.label || "—").trim() || "—";
const score = Number(reference?.meta?.score || 0).toFixed(3);
lines.push(` - [${reference.key}|${typeLabel}] ${label} (score=${score})`);
}
}