feat: shared ranking core + prompt node references; recall reuses shared core for base query/vector/diffusion; remove retriever-local duplicate helpers; add regression tests

This commit is contained in:
Youzini-afk
2026-04-12 14:59:22 +08:00
parent 4b4f77caff
commit dc5051f2ef
8 changed files with 1855 additions and 402 deletions

View File

@@ -62,7 +62,7 @@ installResolveHooks([
},
]);
const { createEmptyGraph, addNode, createNode } = await import("../graph/graph.js");
const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import("../graph/graph.js");
const { DEFAULT_NODE_SCHEMA } = await import("../graph/schema.js");
const { extractMemories } = await import("../maintenance/extractor.js");
const { appendSummaryEntry } = await import("../graph/summary-state.js");
@@ -466,6 +466,97 @@ function collectAllPromptContent(captured) {
}
// ── Test 7: graphStats prompt block lists nodes by short reference key, not raw node id ──
{
// Seed a graph with one event node and one thread node joined by a "supports" edge.
const graph = createEmptyGraph();
const confessionNode = addNode(
graph,
createNode({
type: "event",
seq: 3,
importance: 8,
fields: {
title: "中文告白",
summary: "她认真地要求你再说一遍喜欢她。",
},
}),
);
const relationshipNode = addNode(
graph,
createNode({
type: "thread",
seq: 4,
importance: 7,
fields: {
title: "感情升温",
summary: "两人的关系在这次告白后快速拉近。",
},
}),
);
addEdge(
graph,
createEdge({
fromId: confessionNode.id,
toId: relationshipNode.id,
relation: "supports",
strength: 0.9,
}),
);
// Replace the LLM call with a stub that records the payload and returns an empty result.
let captured = null;
const restore = setTestOverrides({
llm: {
async callLLMForJSON(payload) {
captured = payload;
return { operations: [], cognitionUpdates: [], regionUpdates: {} };
},
},
});
try {
// Both messages mention the seeded node titles.
const result = await extractMemories({
graph,
messages: [
{
seq: 10,
role: "user",
content: "中文告白之后,她还是很害羞。",
name: "玩家",
speaker: "玩家",
},
{
seq: 11,
role: "assistant",
content: "这次中文告白让你们的感情升温了。",
name: "艾琳",
speaker: "艾琳",
},
],
startSeq: 10,
endSeq: 11,
schema: DEFAULT_NODE_SCHEMA,
embeddingConfig: null,
settings: { ...defaultSettings },
});
assert.equal(result.success, true);
assert.ok(captured);
// Locate the prompt message whose sourceKey is "graphStats" in the captured payload.
const graphStatsBlock = (Array.isArray(captured.promptMessages) ? captured.promptMessages : []).find(
(m) => m.sourceKey === "graphStats",
);
assert.ok(graphStatsBlock, "graphStats block should exist");
const graphStatsContent = String(graphStatsBlock.content || "");
// The block must contain the stats heading, per-type counts and a "[G1|…]" reference line.
assert.match(graphStatsContent, /### 图谱节点统计/);
assert.match(graphStatsContent, /事件: 1/);
assert.match(graphStatsContent, /主线: 1/);
assert.match(graphStatsContent, /\[G1\|事件\] 中文告白/);
// The raw node id (regex-escaped before matching) must not appear in the prompt text.
assert.doesNotMatch(graphStatsContent, new RegExp(confessionNode.id.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")));
} finally {
// Always undo the LLM override, even when an assertion fails.
restore();
}
}
// ── Test 8: new settings exist in defaults ──
{
assert.equal(defaultSettings.extractRecentMessageCap, 0);
assert.equal(defaultSettings.extractPromptStructuredMode, "both");

View File

@@ -0,0 +1,54 @@
import assert from "node:assert/strict";
const {
createPromptNodeReferenceMap,
getPromptNodeLabel,
resolvePromptNodeId,
} = await import("../prompting/prompt-node-references.js");
// Fixture: a UUID-shaped id that should be replaced by a short key in the reference map.
const rawNodeId = "550e8400-e29b-41d4-a716-446655440000";
// Build a reference map from two entries: one with an explicit nodeId and a long title,
// one that only carries an id on its nested node and has a summary instead of a title.
const map = createPromptNodeReferenceMap(
[
{
nodeId: rawNodeId,
node: {
id: rawNodeId,
type: "event",
fields: {
title: "这是一个非常非常长的节点标题,用于测试提取提示里的标签截断行为",
},
},
score: 0.91,
},
{
node: {
id: "node-2",
type: "thread",
fields: {
summary: "关系持续升温",
},
},
score: 0.77,
},
],
{
prefix: "G",
maxLength: 12,
// Extra metadata merged into each key's meta record.
buildMeta: ({ entry }) => ({
score: entry.score,
}),
},
);
// Keys are the prefix followed by a 1-based counter, in entry order.
assert.deepEqual(Object.keys(map.keyToNodeId), ["G1", "G2"]);
// Lookups work in both directions.
assert.equal(map.keyToNodeId.G1, rawNodeId);
assert.equal(map.nodeIdToKey[rawNodeId], "G1");
// resolvePromptNodeId reads either an explicit nodeId or the nested node's id.
assert.equal(resolvePromptNodeId({ nodeId: rawNodeId }), rawNodeId);
assert.equal(resolvePromptNodeId({ node: { id: "node-2" } }), "node-2");
// A title that fits within the limit is returned unchanged.
assert.equal(getPromptNodeLabel({ id: "node-3", fields: { title: "短标题" } }), "短标题");
// buildMeta output is merged into the meta record.
assert.equal(map.keyToMeta.G1.score, 0.91);
// With maxLength 12 the long title is cut to 11 characters plus a trailing ellipsis.
assert.match(map.keyToMeta.G1.label, /^这是一个非常非常长的节…$/);
// Without a title the summary is used as the label.
assert.equal(map.keyToMeta.G2.label, "关系持续升温");
assert.equal(map.keyToMeta.G1.nodeId, rawNodeId);
console.log("prompt-node-references tests passed");

View File

@@ -78,6 +78,485 @@ function createGraphHelpers(graph) {
};
}
/**
 * Build a short, single-line display label for a graph node.
 *
 * The first non-empty value among `fields.title`, `fields.name`,
 * `fields.summary`, `fields.insight`, `fields.belief` and `node.id` is used;
 * runs of whitespace are collapsed to one space. When nothing usable is
 * found the placeholder "—" is returned.
 *
 * @param {object} [node] - Graph node (may be null/undefined).
 * @param {object} [options]
 * @param {number} [options.maxLength=32] - Maximum label length in UTF-16
 *   units, including the trailing ellipsis. Non-finite values or values
 *   below 2 disable truncation.
 * @returns {string} The label, truncated with a trailing "…" when too long.
 */
function getPromptNodeLabel(node = {}, { maxLength = 32 } = {}) {
  const fields = node?.fields;
  const raw = String(
    fields?.title ||
      fields?.name ||
      fields?.summary ||
      fields?.insight ||
      fields?.belief ||
      node?.id ||
      "—",
  )
    .replace(/\s+/g, " ")
    .trim();
  if (!raw) return "—";
  if (!Number.isFinite(maxLength) || maxLength < 2 || raw.length <= maxLength) {
    return raw;
  }
  // Reserve one character for the ellipsis so the result never exceeds
  // maxLength and the reader can tell the label was cut.
  return `${raw.slice(0, Math.max(1, maxLength - 1)).trimEnd()}…`;
}
/**
 * Assign short, stable reference keys (e.g. "N1", "N2") to a list of nodes so
 * prompts can refer to nodes without exposing raw ids.
 *
 * Each entry may be a node itself or a wrapper `{ nodeId?, node?, ... }`.
 * Entries without a usable id and repeated ids are skipped; keys are numbered
 * by order of first appearance.
 *
 * @param {Array<object>} [entries] - Nodes or node wrappers; non-arrays are treated as empty.
 * @param {object} [options]
 * @param {string} [options.prefix="N"] - Key prefix; blank values fall back to "N".
 * @param {number} [options.maxLength=32] - Maximum label length forwarded to getPromptNodeLabel.
 * @param {Function|null} [options.buildMeta] - Optional callback
 *   `({ entry, node, nodeId, key, index, label }) => object` whose result is
 *   merged over the default meta record.
 * @returns {{keyToNodeId: object, keyToMeta: object, nodeIdToKey: object, references: Array<object>}}
 */
function createPromptNodeReferenceMap(
  entries = [],
  { prefix = "N", maxLength = 32, buildMeta = null } = {},
) {
  const keyPrefix = String(prefix || "N").trim() || "N";
  const keyToNodeId = {};
  const keyToMeta = {};
  const nodeIdToKey = {};
  const references = [];
  const list = Array.isArray(entries) ? entries : [];
  for (const [index, entry] of list.entries()) {
    const node = entry?.node || entry || {};
    const nodeId = String(entry?.nodeId || node?.id || "").trim();
    // Own-property check: a plain truthiness test would treat ids such as
    // "constructor" or "toString" as already registered.
    if (!nodeId || Object.prototype.hasOwnProperty.call(nodeIdToKey, nodeId)) {
      continue;
    }
    const key = `${keyPrefix}${references.length + 1}`;
    const label = getPromptNodeLabel(node, { maxLength });
    const extraMeta =
      (typeof buildMeta === "function"
        ? buildMeta({ entry, node, nodeId, key, index, label })
        : {}) || {};
    keyToNodeId[key] = nodeId;
    nodeIdToKey[nodeId] = key;
    keyToMeta[key] = {
      nodeId,
      type: String(node?.type || ""),
      label,
      ...extraMeta,
    };
    references.push({ key, nodeId, node, meta: keyToMeta[key] });
  }
  return {
    keyToNodeId,
    keyToMeta,
    nodeIdToKey,
    references,
  };
}
/**
 * Turn an arbitrary value into a compact single-line query string.
 *
 * Null/undefined become the empty string, every run of whitespace (including
 * line breaks) collapses to one space, the ends are trimmed and the result is
 * cut to at most `maxLength` characters (never fewer than one).
 *
 * @param {*} value - Value to normalise.
 * @param {number} [maxLength=400] - Upper bound on the returned length.
 * @returns {string} Normalised text, or "" when nothing remains.
 */
function normalizeQueryText(value, maxLength = 400) {
  const source = value ?? "";
  const unifiedNewlines = String(source).replace(/\r\n/g, "\n");
  const collapsed = unifiedNewlines.replace(/\s+/g, " ").trim();
  if (collapsed.length === 0) {
    return "";
  }
  const cap = Math.max(1, maxLength);
  return collapsed.slice(0, cap);
}
/**
 * Split a user message into separate "intent" segments.
 *
 * Segments are separated by sentence punctuation (ASCII `, . ; ! ?`, the
 * ideographic full stop and the fullwidth forms of comma, semicolon,
 * exclamation mark and question mark), by line breaks, or by a fixed list of
 * Chinese connective words. The pieces are trimmed, filtered by minimum
 * length, de-duplicated through `uniqueStrings` and capped.
 *
 * @param {string} text - Message to split.
 * @param {object} [options]
 * @param {number} [options.maxSegments=4] - Maximum number of segments (at least 1 is returned when available).
 * @param {number} [options.minLength=1] - Minimum trimmed length for a piece to be kept.
 * @returns {string[]} Segments in order of appearance.
 */
function splitIntentSegments(text, { maxSegments = 4, minLength = 1 } = {}) {
  const raw = String(text || "").trim();
  if (!raw) return [];
  // \uFF0C \uFF1B \uFF01 \uFF1F are the fullwidth comma, semicolon, exclamation
  // mark and question mark; without them text typed with a Chinese IME is
  // never split on punctuation.
  const segments = raw
    .split(
      /[,。.;!?\n\uFF0C\uFF1B\uFF01\uFF1F]+|(?:和|顺便|另外|还有|对了|然后|而且|并且|同时)/,
    )
    .map((item) => item.trim())
    .filter((item) => item.length >= minLength);
  return uniqueStrings(segments).slice(0, Math.max(1, maxSegments));
}
/**
 * Normalise a list of values and drop blanks and case-insensitive duplicates.
 *
 * Every value goes through `normalizeQueryText`; the first spelling of each
 * distinct (lower-cased) text is kept, in input order.
 *
 * @param {Iterable<*>} [values] - Values to normalise.
 * @param {number} [maxLength=400] - Length cap forwarded to normalizeQueryText.
 * @returns {string[]} Unique, non-empty normalised strings.
 */
function uniqueStrings(values = [], maxLength = 400) {
  const seenKeys = new Set();
  const output = [];
  for (const candidate of values) {
    const text = normalizeQueryText(candidate, maxLength);
    if (!text) {
      continue;
    }
    const key = text.toLowerCase();
    if (!seenKeys.has(key)) {
      seenKeys.add(key);
      output.push(text);
    }
  }
  return output;
}
/**
 * Merge several vector-search result lists into one ranked list.
 *
 * For every `nodeId` only the highest-scoring hit is kept. The merged hits
 * are ordered by descending score before being cut to `limit`, so truncation
 * always drops the weakest hits rather than whichever were inserted last.
 *
 * @param {Array<Array<{nodeId: string, score: number}>>} groups - Result lists,
 *   one per query; non-array groups and empty items are ignored.
 * @param {number} [limit] - Maximum number of merged results; all are returned when omitted.
 * @returns {{rawHitCount: number, results: Array<{nodeId: string, score: number}>}}
 *   `rawHitCount` is the number of hits seen before merging.
 */
function mergeVectorResults(groups, limit) {
  const bestByNode = new Map();
  let rawHitCount = 0;
  for (const group of Array.isArray(groups) ? groups : []) {
    for (const item of Array.isArray(group) ? group : []) {
      if (!item) continue;
      rawHitCount += 1;
      const existing = bestByNode.get(item.nodeId);
      if (!existing || item.score > existing.score) {
        bestByNode.set(item.nodeId, item);
      }
    }
  }
  // Array.prototype.sort is stable, so equal scores keep first-seen order.
  const ranked = [...bestByNode.values()].sort(
    (a, b) => (Number(b.score) || 0) - (Number(a.score) || 0),
  );
  return {
    rawHitCount,
    results: ranked.slice(0, limit),
  };
}
/**
 * Parse one line of recent chat context into `{ role, text }`.
 *
 * Two layouts are recognised (role tokens are case-insensitive):
 *   - bracketed: `[user]: text` / `[assistant]: text`
 *   - plain:     `user: text`, `assistant: text`, `ai: text`,
 *                `用户: text`, `助手: text`
 * Both the ASCII colon and the fullwidth colon (U+FF1A) are accepted as the
 * separator, since the Chinese role tokens are normally typed with the latter.
 *
 * @param {string} [line] - Raw context line.
 * @returns {{role: "user"|"assistant", text: string}|null} Parsed message, or
 *   null when the line is empty, has no recognised role prefix, or has no text.
 */
function parseContextLine(line = "") {
  const raw = String(line ?? "").trim();
  if (!raw) return null;
  const bracketMatch = raw.match(/^\[(user|assistant)\]\s*[:\uFF1A]\s*([\s\S]*)$/i);
  if (bracketMatch) {
    const role = String(bracketMatch[1] || "").toLowerCase();
    const text = normalizeQueryText(bracketMatch[2] || "");
    return text ? { role, text } : null;
  }
  const plainMatch = raw.match(
    /^(user|assistant|用户|助手|ai)\s*[:\uFF1A]\s*([\s\S]*)$/i,
  );
  if (!plainMatch) return null;
  const roleToken = String(plainMatch[1] || "").toLowerCase();
  // Everything that is not an assistant alias is treated as the user.
  const role =
    roleToken === "assistant" || roleToken === "助手" || roleToken === "ai"
      ? "assistant"
      : "user";
  const text = normalizeQueryText(plainMatch[2] || "");
  return text ? { role, text } : null;
}
/**
 * Combine the current user message with recent conversation context into a
 * weighted set of query parts.
 *
 * The most recent assistant line and the most recent user line that differs
 * from the current message are taken from `recentMessages` (parsed with
 * `parseContextLine`). When blending is enabled they are added as extra parts
 * with the given weights; the current message gets the remaining weight.
 * Parts with empty or duplicate text are dropped and the surviving weights
 * are re-normalised to sum to roughly 1 (rounded to three decimals).
 *
 * @param {*} userMessage - Current user message.
 * @param {Array<string>} [recentMessages] - Recent context lines; non-arrays are ignored.
 * @param {object} [options]
 * @param {boolean} [options.enabled=true] - Whether context parts are blended in.
 * @param {number} [options.assistantWeight=0.2] - Weight of the assistant part.
 * @param {number} [options.previousUserWeight=0.1] - Weight of the previous-user part.
 * @param {number} [options.maxTextLength=400] - Length cap applied to every text.
 * @returns {{active: boolean, parts: Array<object>, currentText: string,
 *   assistantText: string, previousUserText: string, combinedText: string}}
 */
function buildContextQueryBlend(
  userMessage,
  recentMessages = [],
  {
    enabled = true,
    assistantWeight = 0.2,
    previousUserWeight = 0.1,
    maxTextLength = 400,
  } = {},
) {
  const clean = (value) => normalizeQueryText(value, maxTextLength);
  const nonNegative = (value) => Math.max(0, Number(value) || 0);
  const roundTo3 = (value) => Math.round(value * 1000) / 1000;

  const currentText = clean(userMessage);
  const currentKey = currentText.toLowerCase();

  const history = Array.isArray(recentMessages)
    ? recentMessages.map((line) => parseContextLine(line)).filter(Boolean)
    : [];

  // Walk the history newest-first and pick the latest line of each role.
  let assistantText = "";
  let previousUserText = "";
  for (const message of [...history].reverse()) {
    if (message.role === "assistant" && !assistantText) {
      assistantText = clean(message.text);
    }
    if (message.role === "user" && !previousUserText) {
      const candidate = clean(message.text);
      // A repeat of the current message carries no extra context.
      if (candidate.toLowerCase() !== currentKey) {
        previousUserText = candidate;
      }
    }
    if (assistantText && previousUserText) {
      break;
    }
  }

  const assistantShare = Number(assistantWeight || 0);
  const previousUserShare = Number(previousUserWeight || 0);
  const currentShare = Math.max(0, 1 - assistantShare - previousUserShare);

  const candidates = [
    {
      kind: "currentUser",
      label: "当前用户消息",
      text: currentText,
      weight: enabled ? currentShare : 1,
    },
  ];
  if (enabled) {
    if (assistantText) {
      candidates.push({
        kind: "assistantContext",
        label: "最近 assistant 回复",
        text: assistantText,
        weight: assistantShare,
      });
    }
    if (previousUserText) {
      candidates.push({
        kind: "previousUser",
        label: "上一条 user 消息",
        text: previousUserText,
        weight: previousUserShare,
      });
    }
  }

  // Drop empty texts and case-insensitive duplicates, keeping the first one.
  const seenKeys = new Set();
  const uniqueParts = [];
  candidates.forEach((candidate) => {
    const text = clean(candidate.text);
    if (!text) return;
    const key = text.toLowerCase();
    if (seenKeys.has(key)) return;
    seenKeys.add(key);
    uniqueParts.push({ ...candidate, text });
  });

  let totalWeight = 0;
  for (const part of uniqueParts) {
    totalWeight += nonNegative(part.weight);
  }
  // With no positive weight at all, fall back to an even split.
  const evenShare = roundTo3(1 / Math.max(1, uniqueParts.length));
  const parts = uniqueParts.map((part) => {
    const weight =
      totalWeight > 0 ? roundTo3(nonNegative(part.weight) / totalWeight) : evenShare;
    return { ...part, weight };
  });

  let combinedText;
  if (parts.length <= 1) {
    combinedText = parts[0]?.text || "";
  } else {
    combinedText = parts.map((part) => `${part.label}:\n${part.text}`).join("\n\n");
  }

  return {
    active: enabled && parts.length > 1,
    parts,
    currentText: currentText || parts[0]?.text || "",
    assistantText,
    previousUserText,
    combinedText,
  };
}
/**
 * Expand a blended query into the list of texts to send to vector search.
 *
 * Every part of `blendPlan.parts` yields one plan entry. For the
 * "currentUser" part (when multi-intent is enabled) the full text is followed
 * by its intent segments; every other part contributes only its own text.
 * Query lists are de-duplicated with `uniqueStrings`.
 *
 * @param {{parts?: Array<object>}} blendPlan - Result of buildContextQueryBlend.
 * @param {object} [options]
 * @param {boolean} [options.enableMultiIntent=true] - Split the current message into segments.
 * @param {number} [options.maxSegments=4] - Segment cap forwarded to splitIntentSegments.
 * @returns {{plan: Array<{kind: *, label: *, weight: *, queries: string[]}>, currentSegments: string[]}}
 */
function buildVectorQueryPlan(
  blendPlan,
  { enableMultiIntent = true, maxSegments = 4 } = {},
) {
  const plan = [];
  let currentSegments = [];
  for (const part of blendPlan?.parts || []) {
    const { kind, label, weight, text } = part;
    const candidates = [text];
    if (enableMultiIntent && kind === "currentUser") {
      currentSegments = splitIntentSegments(text, { maxSegments });
      for (const segment of currentSegments) {
        // The whole message is already first in the list.
        if (segment !== text) {
          candidates.push(segment);
        }
      }
    }
    plan.push({ kind, label, weight, queries: uniqueStrings(candidates) });
  }
  return { plan, currentSegments };
}
/**
 * Produce the texts used for lexical (substring) matching against nodes.
 *
 * The user message is normalised (capped at 400 characters) and, when
 * multi-intent is enabled, split into intent segments. `sources` holds the
 * whole message followed by the segments, with duplicates removed.
 *
 * @param {*} userMessage - Current user message.
 * @param {object} [options]
 * @param {boolean} [options.enableMultiIntent=true] - Whether to add intent segments.
 * @param {number} [options.maxSegments=4] - Segment cap forwarded to splitIntentSegments.
 * @returns {{sources: string[], segments: string[]}}
 */
function buildLexicalQuerySources(
  userMessage,
  { enableMultiIntent = true, maxSegments = 4 } = {},
) {
  const currentText = normalizeQueryText(userMessage, 400);
  let segments = [];
  if (enableMultiIntent) {
    segments = splitIntentSegments(currentText, { maxSegments });
  }
  const sources = uniqueStrings([currentText, ...segments]);
  return { sources, segments };
}
/**
 * Binary lexical match between a node and a list of query texts.
 *
 * The node's `fields.name`, `fields.title` or `fields.summary` (first
 * non-empty one) is compared case-insensitively against the first
 * whitespace-delimited token of each query source.
 *
 * @param {object|null} node - Graph node (may be null).
 * @param {Iterable<string>} [querySources] - Query texts; null is treated as empty.
 * @returns {0|1} 1 when any source's first token occurs in the node text, else 0.
 */
function computeLexicalScoreForShared(node, querySources = []) {
  const haystack = String(
    node?.fields?.name || node?.fields?.title || node?.fields?.summary || "",
  ).toLowerCase();
  if (!haystack) return 0;
  for (const sourceText of querySources ?? []) {
    // Trim first: a leading space would otherwise produce an empty first
    // token, and `includes("")` is true for every haystack.
    const [firstToken] = String(sourceText || "")
      .toLowerCase()
      .trim()
      .split(/\s+/);
    if (firstToken && haystack.includes(firstToken)) {
      return 1;
    }
  }
  return 0;
}
/**
 * Find nodes whose name or title appears verbatim in the user message.
 *
 * For each node the trimmed `fields.name` and `fields.title` (strings of at
 * least two characters) are checked in that order; the first one contained in
 * the message that has not already been emitted for the same node id becomes
 * an anchor. At most one anchor is produced per pass over a node.
 *
 * @param {*} userMessage - Current user message.
 * @param {Iterable<object>} [activeNodes] - Candidate nodes.
 * @returns {Array<{nodeId: *, entity: string}>} Anchors in node order.
 */
function extractEntityAnchors(userMessage, activeNodes = []) {
  const messageText = String(userMessage || "");
  const emittedKeys = new Set();
  const anchors = [];
  for (const node of activeNodes) {
    const labels = [];
    for (const rawLabel of [node?.fields?.name, node?.fields?.title]) {
      if (typeof rawLabel !== "string") continue;
      const label = rawLabel.trim();
      if (label.length >= 2) {
        labels.push(label);
      }
    }
    const match = labels.find(
      (label) =>
        messageText.includes(label) && !emittedKeys.has(`${node.id}:${label}`),
    );
    if (match === undefined) continue;
    emittedKeys.add(`${node.id}:${match}`);
    anchors.push({ nodeId: node.id, entity: match });
  }
  return anchors;
}
/**
 * Test double for the shared ranking entry point.
 *
 * It builds the same query structures as the helpers above (context blend,
 * vector query plan, lexical sources, entity anchors) but does not run a real
 * vector search or graph diffusion: vector hits and diffusion results are
 * hard-coded ("rule-1".."rule-3"), and every simulated call is recorded on the
 * module-level `state` object (`state.vectorCalls`, `state.diffusionCalls`).
 *
 * @param {object} [params]
 * @param {object} [params.graph] - Graph; only used as the source of nodes when `options.activeNodes` is not an array.
 * @param {*} [params.userMessage] - Current user message.
 * @param {Array<string>} [params.recentMessages] - Recent context lines.
 * @param {*} [params.embeddingConfig] - Accepted for signature parity; not read here.
 * @param {object} [params.options] - Tuning flags and limits (see the reads below).
 * @returns {Promise<object>} The intermediate structures, `scoredNodes` and `diagnostics`.
 */
async function rankNodesForTaskContext({
graph,
userMessage,
recentMessages = [],
embeddingConfig,
options = {},
} = {}) {
// Candidate nodes: explicit list if given, else all graph nodes; archived nodes are excluded.
const activeNodes = Array.isArray(options.activeNodes)
? options.activeNodes.filter((node) => node && !node.archived)
: (graph?.nodes || []).filter((node) => node && !node.archived);
// Limits default to 20 / 100 and are clamped to at least 1.
const topK = Math.max(1, Math.floor(Number(options.topK) || 20));
const diffusionTopK = Math.max(1, Math.floor(Number(options.diffusionTopK) || 100));
// Feature flags default to enabled when not provided.
const enableVectorPrefilter = options.enableVectorPrefilter ?? true;
const enableGraphDiffusion = options.enableGraphDiffusion ?? true;
const enableContextQueryBlend = options.enableContextQueryBlend ?? true;
const enableMultiIntent = options.enableMultiIntent ?? true;
const multiIntentMaxSegments = Math.max(
1,
Math.floor(Number(options.multiIntentMaxSegments) || 4),
);
const contextQueryBlend = buildContextQueryBlend(userMessage, recentMessages, {
enabled: enableContextQueryBlend,
assistantWeight: Number(options.contextAssistantWeight ?? 0.2),
previousUserWeight: Number(options.contextPreviousUserWeight ?? 0.1),
maxTextLength: Number(options.maxTextLength || 400),
});
const queryPlan = buildVectorQueryPlan(contextQueryBlend, {
enableMultiIntent,
maxSegments: multiIntentMaxSegments,
});
// Lexical matching only uses the current user text, not the blended context.
const lexicalQuery = buildLexicalQuerySources(
contextQueryBlend.currentText || userMessage,
{
enableMultiIntent,
maxSegments: multiIntentMaxSegments,
},
);
// Diagnostics start with zeroed counters and are filled in by the stages below.
const diagnostics = {
queryBlendActive: contextQueryBlend.active,
queryBlendParts: (contextQueryBlend.parts || []).map((part) => ({
kind: part.kind,
label: part.label,
weight: part.weight,
text: part.text,
length: part.text.length,
})),
queryBlendWeights: Object.fromEntries(
(contextQueryBlend.parts || []).map((part) => [part.kind, part.weight]),
),
segmentsUsed: [...(queryPlan.currentSegments || [])],
vectorValidation: { valid: true },
vectorHits: 0,
vectorMergedHits: 0,
seedCount: 0,
diffusionHits: 0,
temporalSyntheticEdgeCount: 0,
// Missing, non-numeric or zero teleportAlpha all fall back to 0.15.
teleportAlpha: Number(options.teleportAlpha ?? 0.15) || 0.15,
lexicalBoostedNodes: 0,
lexicalTopHits: [],
skipReasons: [],
timings: { vector: 0, diffusion: 0 },
};
let vectorResults = [];
if (enableVectorPrefilter) {
const groups = [];
for (const part of queryPlan.plan) {
for (const queryText of part.queries) {
// Record the simulated vector search so tests can inspect which queries were issued.
state.vectorCalls.push({ topK, message: queryText });
// Canned hits, scaled by the (non-negative) weight of the originating query part.
const results = [
{ nodeId: "rule-1", score: 0.9 },
{ nodeId: "rule-2", score: 0.8 },
{ nodeId: "rule-3", score: 0.7 },
].map((item) => ({
...item,
score: item.score * Math.max(0, Number(part.weight) || 0),
}));
groups.push(results);
}
}
// Keep at most max(2 * topK, 24) merged hits.
const merged = mergeVectorResults(groups, Math.max(topK * 2, 24));
diagnostics.vectorHits = merged.rawHitCount;
diagnostics.vectorMergedHits = merged.results.length;
vectorResults = merged.results;
}
const exactEntityAnchors = extractEntityAnchors(
contextQueryBlend.currentText || userMessage,
activeNodes,
);
let diffusionResults = [];
if (enableGraphDiffusion) {
// Seeds: vector hits at their score, entity anchors at a fixed energy of 2.0; the maximum wins per node.
const seedMap = new Map();
for (const item of vectorResults) {
seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, item.score));
}
for (const item of exactEntityAnchors) {
seedMap.set(item.nodeId, Math.max(seedMap.get(item.nodeId) || 0, 2.0));
}
const uniqueSeeds = [...seedMap.entries()].map(([id, energy]) => ({ id, energy }));
diagnostics.seedCount = uniqueSeeds.length;
if (uniqueSeeds.length > 0) {
// Record the simulated diffusion call, then return canned diffusion results.
state.diffusionCalls.push({
seeds: uniqueSeeds,
options: {
maxSteps: 2,
decayFactor: 0.6,
topK: diffusionTopK,
teleportAlpha: diagnostics.teleportAlpha,
},
});
diffusionResults = [
{ nodeId: "rule-2", energy: 1.2 },
{ nodeId: "rule-3", energy: 0.9 },
];
}
}
diagnostics.diffusionHits = diffusionResults.length;
// Combine per-node vector and graph scores; each loop preserves the other score if already set.
const scoreMap = new Map();
for (const item of vectorResults) {
scoreMap.set(item.nodeId, {
graphScore: scoreMap.get(item.nodeId)?.graphScore || 0,
vectorScore: item.score,
});
}
for (const item of diffusionResults) {
scoreMap.set(item.nodeId, {
graphScore: item.energy,
vectorScore: scoreMap.get(item.nodeId)?.vectorScore || 0,
});
}
// With no hits at all, fall back to every active node with zero scores.
if (scoreMap.size === 0) {
for (const node of activeNodes) {
scoreMap.set(node.id, { graphScore: 0, vectorScore: 0 });
}
}
// Entries follow scoreMap insertion order; no sorting is applied here.
const scoredNodes = [...scoreMap.entries()].map(([nodeId, scores]) => {
// `node` is null when the scored id is not among the active nodes.
const node = activeNodes.find((item) => item.id === nodeId) || null;
const lexicalScore = computeLexicalScoreForShared(node, lexicalQuery.sources);
return {
nodeId,
node,
graphScore: scores.graphScore,
vectorScore: scores.vectorScore,
lexicalScore,
// finalScore and weightedScore are the same unweighted sum in this double.
finalScore:
Number(scores.graphScore || 0) +
Number(scores.vectorScore || 0) +
Number(lexicalScore || 0) +
Number(node?.importance || 0),
weightedScore:
Number(scores.graphScore || 0) +
Number(scores.vectorScore || 0) +
Number(lexicalScore || 0) +
Number(node?.importance || 0),
};
});
diagnostics.lexicalBoostedNodes = scoredNodes.filter(
(item) => (Number(item.lexicalScore) || 0) > 0,
).length;
// First five lexical matches, in scoredNodes order.
diagnostics.lexicalTopHits = scoredNodes
.filter((item) => (Number(item.lexicalScore) || 0) > 0)
.slice(0, 5)
.map((item) => ({
nodeId: item.nodeId,
label: item.node?.fields?.name || item.node?.fields?.title || item.nodeId,
lexicalScore: item.lexicalScore,
finalScore: item.finalScore,
}));
return {
activeNodes,
contextQueryBlend,
queryPlan,
lexicalQuery,
vectorResults,
exactEntityAnchors,
diffusionResults,
scoredNodes,
diagnostics,
};
}
const schema = [{ id: "rule", label: "规则", alwaysInject: false }];
const state = {
@@ -93,6 +572,9 @@ const graph = createGraph();
const helpers = createGraphHelpers(graph);
const retrieve = await loadRetrieve({
...helpers,
createPromptNodeReferenceMap,
getPromptNodeLabel,
rankNodesForTaskContext,
STORY_TEMPORAL_BUCKETS: {
CURRENT: "current",
ADJACENT_PAST: "adjacentPast",

175
tests/shared-ranking.mjs Normal file
View File

@@ -0,0 +1,175 @@
import assert from "node:assert/strict";
import {
installResolveHooks,
toDataModuleUrl,
} from "./helpers/register-hooks-compat.mjs";
// Source text of a replacement module for the host's "extensions.js":
// exports an empty settings object and a getContext() returning a fixed context.
const extensionsShimSource = [
"export const extension_settings = {};",
"export function getContext() {",
" return {",
" chat: [],",
" chatMetadata: {},",
" extensionSettings: {},",
" powerUserSettings: {},",
" characters: {},",
" characterId: null,",
" name1: '玩家',",
" name2: '艾琳',",
" chatId: 'test-chat',",
" };",
"}",
].join("\n");
// Source text of a replacement module for the host's "script.js":
// empty request headers and a pass-through parameter substitution.
const scriptShimSource = [
"export function getRequestHeaders() {",
" return {};",
"}",
"export function substituteParamsExtended(value) {",
" return String(value ?? '');",
"}",
].join("\n");
// Map the listed relative specifiers (several directory depths) to the shim modules above.
installResolveHooks([
{
specifiers: [
"../../../extensions.js",
"../../../../extensions.js",
"../../../../../extensions.js",
],
url: toDataModuleUrl(extensionsShimSource),
},
{
specifiers: [
"../../../../script.js",
"../../../../../script.js",
],
url: toDataModuleUrl(scriptShimSource),
},
]);
const { addEdge, addNode, createEdge, createEmptyGraph, createNode } = await import(
"../graph/graph.js"
);
const { rankNodesForTaskContext } = await import("../retrieval/shared-ranking.js");
/**
 * Publish test overrides on `globalThis.__stBmeTestOverrides`.
 *
 * @param {object} [overrides] - Override object to expose.
 * @returns {() => void} Function that removes the global again.
 */
function setTestOverrides(overrides = {}) {
  const slot = "__stBmeTestOverrides";
  globalThis[slot] = overrides;
  const restoreOverrides = () => {
    delete globalThis[slot];
  };
  return restoreOverrides;
}
// Fixture graph: two event nodes and one thread node.
const graph = createEmptyGraph();
const confession = addNode(
graph,
createNode({
type: "event",
seq: 10,
importance: 8,
fields: {
title: "中文告白",
summary: "她认真地说喜欢你,并要求你再说一遍。",
},
}),
);
const dateEvent = addNode(
graph,
createNode({
type: "event",
seq: 11,
importance: 4,
fields: {
title: "节日约会",
summary: "她们一起逛街吃饭。",
},
}),
);
const relationship = addNode(
graph,
createNode({
type: "thread",
seq: 12,
importance: 7,
fields: {
title: "感情升温",
summary: "两人的恋爱关系快速升温。",
},
}),
);
// Give every node an embedding vector.
confession.embedding = [1, 0.3, 0.1];
dateEvent.embedding = [0.2, 0.9, 0.1];
relationship.embedding = [0.8, 0.6, 0.2];
// Link the confession to the relationship thread; the relationship node is absent from
// the stubbed vector hits below, so the test expects it to surface via graph diffusion.
addEdge(
graph,
createEdge({
fromId: confession.id,
toId: relationship.id,
relation: "supports",
strength: 0.9,
}),
);
// Snapshot used to assert that ranking does not modify the graph.
const graphBefore = JSON.stringify(graph);
// Stub the embedding backend: fixed query vector and fixed similarity hits.
const restore = setTestOverrides({
embedding: {
async embedText() {
return [1, 0.5, 0.25];
},
searchSimilar(_queryVec, candidates) {
// The confession node must be among the candidates handed to the search.
assert.ok(candidates.some((item) => item.nodeId === confession.id));
return [
{ nodeId: confession.id, score: 0.97 },
{ nodeId: dateEvent.id, score: 0.23 },
];
},
},
});
try {
const config = {
mode: "direct",
source: "direct",
apiUrl: "https://example.com/v1",
apiKey: "",
model: "test-embedding",
};
// Run the same ranking twice with identical inputs.
const first = await rankNodesForTaskContext({
graph,
userMessage: "[user]: 中文告白后的关系进展",
embeddingConfig: config,
options: {
enableContextQueryBlend: false,
topK: 8,
diffusionTopK: 16,
},
});
const second = await rankNodesForTaskContext({
graph,
userMessage: "[user]: 中文告白后的关系进展",
embeddingConfig: config,
options: {
enableContextQueryBlend: false,
topK: 8,
diffusionTopK: 16,
},
});
// The serialized graph must be unchanged after both runs.
assert.equal(JSON.stringify(graph), graphBefore, "shared ranking should be side-effect-free");
// The confession node must rank first in both runs.
assert.equal(first.scoredNodes[0]?.nodeId, confession.id);
assert.equal(second.scoredNodes[0]?.nodeId, confession.id);
assert.deepEqual(
first.scoredNodes.map((item) => item.nodeId),
second.scoredNodes.map((item) => item.nodeId),
"ranking order should stay deterministic under fixed inputs",
);
// The relationship node must appear with a positive graph score.
const propagated = first.scoredNodes.find((item) => item.nodeId === relationship.id);
assert.ok(propagated, "diffusion should surface connected relationship node");
assert.ok((Number(propagated?.graphScore) || 0) > 0, "connected node should receive graph diffusion score");
// Two merged vector hits are expected, matching the two stubbed results.
assert.equal(first.diagnostics.vectorMergedHits, 2);
assert.ok(first.diagnostics.diffusionHits >= 1);
} finally {
// Remove the embedding override even when an assertion fails.
restore();
}
console.log("shared-ranking tests passed");