feat(authority): add recall candidate provider

This commit is contained in:
Youzini-afk
2026-04-28 15:01:14 +08:00
parent 93c562015f
commit 6c8c56df62
6 changed files with 1005 additions and 28 deletions

View File

@@ -0,0 +1,256 @@
import assert from "node:assert/strict";
import { addNode, createEmptyGraph, createNode } from "../graph/graph.js";
import {
installResolveHooks,
toDataModuleUrl,
} from "./helpers/register-hooks-compat.mjs";
installResolveHooks([
{
specifiers: ["../../../../../script.js"],
url: toDataModuleUrl("export function getRequestHeaders() { return {}; }"),
},
{
specifiers: ["../../../../extensions.js"],
url: toDataModuleUrl("export const extension_settings = { st_bme: {} };"),
},
]);
const { normalizeAuthorityVectorConfig } = await import(
"../vector/authority-vector-primary-adapter.js"
);
const { resolveAuthorityRecallCandidates } = await import(
"../retrieval/authority-candidate-provider.js"
);
function createRecallGraph() {
const graph = createEmptyGraph();
graph.historyState.chatId = "chat-authority-candidates";
graph.vectorIndexState.collectionId = "st-bme:chat-authority-candidates:nodes";
const first = createNode({
type: "event",
seq: 10,
fields: { title: "Alice enters the archive", summary: "Alice reaches the archive gate" },
importance: 6,
scope: {
layer: "objective",
ownerType: "",
ownerId: "",
ownerName: "",
bucket: "objectiveGlobal",
regionKey: "archive",
},
});
first.id = "node-archive";
first.storySegmentId = "seg-archive";
const second = createNode({
type: "event",
seq: 11,
fields: { title: "Bob opens the vault", summary: "Bob unlocks the hidden vault" },
importance: 7,
scope: {
layer: "objective",
ownerType: "",
ownerId: "",
ownerName: "",
bucket: "objectiveGlobal",
regionKey: "archive",
},
});
second.id = "node-vault";
second.storySegmentId = "seg-archive";
const third = createNode({
type: "pov_memory",
seq: 12,
fields: { title: "Alice remembers the key", summary: "Alice knows where the silver key is" },
importance: 9,
scope: {
layer: "pov",
ownerType: "character",
ownerId: "Alice",
ownerName: "Alice",
bucket: "characterPov",
regionKey: "archive",
},
});
third.id = "node-alice-memory";
third.storySegmentId = "seg-archive";
const fourth = createNode({
type: "event",
seq: 6,
fields: { title: "Market rumor", summary: "A rumor spreads in the market" },
importance: 2,
scope: {
layer: "objective",
ownerType: "",
ownerId: "",
ownerName: "",
bucket: "objectiveGlobal",
regionKey: "market",
},
});
fourth.id = "node-market";
fourth.storySegmentId = "seg-market";
addNode(graph, first);
addNode(graph, second);
addNode(graph, third);
addNode(graph, fourth);
return { graph, nodes: [first, second, third, fourth] };
}
function createMockTriviumClient({ failFilter = false, failSearch = false, failNeighbors = false } = {}) {
const calls = [];
return {
calls,
async filterWhere(payload = {}) {
calls.push(["filterWhere", payload]);
if (failFilter) {
throw new Error("filter-down");
}
return {
items: [
{ externalId: "node-archive" },
{ payload: { nodeId: "node-alice-memory" } },
],
};
},
async search(payload = {}) {
calls.push(["search", payload]);
if (failSearch) {
throw new Error("search-down");
}
return {
results: [
{ nodeId: "node-alice-memory", score: 0.96 },
{ nodeId: "node-vault", score: 0.88 },
{ nodeId: "node-outside", score: 0.77 },
],
};
},
async neighbors(payload = {}) {
calls.push(["neighbors", payload]);
if (failNeighbors) {
throw new Error("neighbors-down");
}
return {
neighbors: [
{ fromId: "node-alice-memory", toId: "node-vault" },
{ fromId: "node-alice-memory", toId: "node-archive" },
],
};
},
};
}
{
const { graph, nodes } = createRecallGraph();
const triviumClient = createMockTriviumClient();
const config = normalizeAuthorityVectorConfig(
{
authorityBaseUrl: "/api/plugins/authority",
authorityVectorFailOpen: true,
},
{ triviumClient },
);
const result = await resolveAuthorityRecallCandidates({
graph,
userMessage: "Alice 现在在 archive 里找 silver key 吗?",
recentMessages: ["assistant: Alice just reached the archive gate."],
embeddingConfig: config,
availableNodes: nodes,
activeRegion: "archive",
activeStoryContext: {
activeSegmentId: "seg-archive",
},
activeRecallOwnerKeys: ["character:Alice"],
sceneOwnerCandidates: [
{
ownerKey: "character:Alice",
ownerName: "Alice",
},
],
options: {
enabled: true,
topK: 4,
maxRecallNodes: 2,
limit: 6,
neighborLimit: 2,
minimumUsedCandidateCount: 2,
enableMultiIntent: true,
},
});
assert.equal(result.available, true);
assert.equal(result.used, true);
assert.deepEqual(
result.candidateNodes.map((node) => node.id),
["node-alice-memory", "node-vault", "node-archive"],
);
assert.equal(result.diagnostics.filteredCount, 2);
assert.equal(result.diagnostics.searchHits, 2);
assert.equal(result.diagnostics.neighborCount, 1);
const filterCall = triviumClient.calls.find(([name]) => name === "filterWhere");
assert.equal(filterCall?.[1]?.filters?.archived, false);
assert.deepEqual(filterCall?.[1]?.filters?.regionKeys, ["archive"]);
assert.deepEqual(filterCall?.[1]?.filters?.ownerKeys, ["character:Alice"]);
assert.deepEqual(filterCall?.[1]?.filters?.storySegmentIds, ["seg-archive"]);
const searchCall = triviumClient.calls.find(([name]) => name === "search");
assert.ok(Array.isArray(searchCall?.[1]?.candidateIds));
assert.ok(searchCall?.[1]?.candidateIds.includes("node-alice-memory"));
const neighborCall = triviumClient.calls.find(([name]) => name === "neighbors");
assert.deepEqual(neighborCall?.[1]?.nodeIds, ["node-alice-memory", "node-vault"]);
}
{
const { graph, nodes } = createRecallGraph();
const triviumClient = createMockTriviumClient({
failFilter: true,
failSearch: true,
failNeighbors: true,
});
const config = normalizeAuthorityVectorConfig(
{
authorityBaseUrl: "/api/plugins/authority",
authorityVectorFailOpen: true,
},
{ triviumClient },
);
const result = await resolveAuthorityRecallCandidates({
graph,
userMessage: "archive",
recentMessages: [],
embeddingConfig: config,
availableNodes: nodes,
activeRegion: "archive",
activeStoryContext: {
activeSegmentId: "seg-archive",
},
activeRecallOwnerKeys: ["character:Alice"],
sceneOwnerCandidates: [
{
ownerKey: "character:Alice",
ownerName: "Alice",
},
],
options: {
enabled: true,
topK: 4,
maxRecallNodes: 2,
limit: 6,
neighborLimit: 2,
minimumUsedCandidateCount: 2,
},
});
assert.equal(result.available, true);
assert.equal(result.used, false);
assert.deepEqual(result.candidateNodes, []);
assert.match(result.diagnostics.fallbackReason, /authority-candidate-(filter|search|neighbors)-failed/);
}
console.log("authority-recall-candidates tests passed");

View File

@@ -17,11 +17,12 @@ installResolveHooks([
]);
const {
findSimilarNodesByText,
filterAuthorityTriviumNodes,
isAuthorityVectorConfig,
normalizeAuthorityVectorConfig,
syncGraphVectorIndex,
} = await import("../vector/vector-index.js");
queryAuthorityTriviumNeighbors,
} = await import("../vector/authority-vector-primary-adapter.js");
const { findSimilarNodesByText: findSimilarNodesByTextFromIndex, syncGraphVectorIndex: syncGraphVectorIndexFromIndex } = await import("../vector/vector-index.js");
function createAuthorityVectorGraph() {
const graph = createEmptyGraph();
@@ -86,6 +87,24 @@ function createMockTriviumClient({ failBulkUpsert = false } = {}) {
],
};
},
async filterWhere(payload) {
calls.push(["filterWhere", payload]);
return {
items: [
{ externalId: "node-a" },
{ payload: { nodeId: "node-b" } },
],
};
},
async neighbors(payload) {
calls.push(["neighbors", payload]);
return {
neighbors: [
{ fromId: "node-a", toId: "node-b" },
{ fromId: "node-a", toId: "node-c" },
],
};
},
async stat(payload) {
calls.push(["stat", payload]);
return { ok: true };
@@ -103,7 +122,7 @@ assert.equal(isAuthorityVectorConfig(config), true);
{
const { graph, first, second } = createAuthorityVectorGraph();
const triviumClient = createMockTriviumClient();
const result = await syncGraphVectorIndex(graph, config, {
const result = await syncGraphVectorIndexFromIndex(graph, config, {
chatId: "chat-authority-vector",
purge: true,
triviumClient,
@@ -134,13 +153,13 @@ assert.equal(isAuthorityVectorConfig(config), true);
const { graph, first, second } = createAuthorityVectorGraph();
const triviumClient = createMockTriviumClient();
const queryConfig = { ...config, triviumClient };
await syncGraphVectorIndex(graph, queryConfig, {
await syncGraphVectorIndexFromIndex(graph, queryConfig, {
chatId: "chat-authority-vector",
purge: true,
triviumClient,
});
const results = await findSimilarNodesByText(
const results = await findSimilarNodesByTextFromIndex(
graph,
"archive door",
queryConfig,
@@ -158,7 +177,7 @@ assert.equal(isAuthorityVectorConfig(config), true);
{
const { graph } = createAuthorityVectorGraph();
const triviumClient = createMockTriviumClient({ failBulkUpsert: true });
const result = await syncGraphVectorIndex(graph, config, {
const result = await syncGraphVectorIndexFromIndex(graph, config, {
chatId: "chat-authority-vector",
purge: true,
triviumClient,
@@ -171,4 +190,35 @@ assert.equal(isAuthorityVectorConfig(config), true);
assert.match(graph.vectorIndexState.lastWarning, /Authority Trivium 同步失败/);
}
{
const triviumClient = createMockTriviumClient();
const queryConfig = { ...config, triviumClient };
const filteredIds = await filterAuthorityTriviumNodes(queryConfig, {
collectionId: "authority-filter",
chatId: "chat-authority-vector",
limit: 8,
filters: {
archived: false,
ownerKeys: ["character:Alice"],
},
});
assert.deepEqual(filteredIds, ["node-a", "node-b"]);
const filterCall = triviumClient.calls.find(([name]) => name === "filterWhere");
assert.equal(filterCall?.[1]?.collectionId, "authority-filter");
assert.equal(filterCall?.[1]?.filters?.ownerKeys?.[0], "character:Alice");
}
{
const triviumClient = createMockTriviumClient();
const queryConfig = { ...config, triviumClient };
const neighborIds = await queryAuthorityTriviumNeighbors(queryConfig, ["node-a"], {
collectionId: "authority-filter",
chatId: "chat-authority-vector",
limit: 4,
});
assert.deepEqual(neighborIds, ["node-b", "node-c"]);
const neighborCall = triviumClient.calls.find(([name]) => name === "neighbors");
assert.deepEqual(neighborCall?.[1]?.nodeIds, ["node-a"]);
}
console.log("authority-vector-primary tests passed");

View File

@@ -435,6 +435,7 @@ async function rankNodesForTaskContext({
skipReasons: [],
timings: { vector: 0, diffusion: 0 },
};
const activeNodeIds = new Set(activeNodes.map((node) => node.id));
let vectorResults = [];
if (enableVectorPrefilter) {
@@ -446,10 +447,12 @@ async function rankNodesForTaskContext({
{ nodeId: "rule-1", score: 0.9 },
{ nodeId: "rule-2", score: 0.8 },
{ nodeId: "rule-3", score: 0.7 },
].map((item) => ({
...item,
score: item.score * Math.max(0, Number(part.weight) || 0),
}));
]
.filter((item) => activeNodeIds.has(item.nodeId))
.map((item) => ({
...item,
score: item.score * Math.max(0, Number(part.weight) || 0),
}));
groups.push(results);
}
}
@@ -487,7 +490,7 @@ async function rankNodesForTaskContext({
diffusionResults = [
{ nodeId: "rule-2", energy: 1.2 },
{ nodeId: "rule-3", energy: 0.9 },
];
].filter((item) => activeNodeIds.has(item.nodeId));
}
}
diagnostics.diffusionHits = diffusionResults.length;
@@ -566,6 +569,10 @@ const state = {
llmCandidateCount: 0,
llmResponse: { selected_keys: ["R1", "R2"] },
llmOptions: [],
authorityCandidateCalls: [],
authorityCandidateEnabled: false,
authorityCandidateNodeIds: [],
authorityCandidateDiagnostics: null,
};
const graph = createGraph();
@@ -575,6 +582,80 @@ const retrieve = await loadRetrieve({
createPromptNodeReferenceMap,
getPromptNodeLabel,
rankNodesForTaskContext,
async resolveAuthorityRecallCandidates({
availableNodes = [],
activeRegion = "",
activeStoryContext = {},
activeRecallOwnerKeys = [],
options = {},
} = {}) {
state.authorityCandidateCalls.push({
availableNodeIds: availableNodes.map((node) => node.id),
activeRegion,
activeStorySegmentId: String(activeStoryContext?.activeSegmentId || ""),
activeRecallOwnerKeys: [...(activeRecallOwnerKeys || [])],
minimumUsedCandidateCount: Number(options.minimumUsedCandidateCount || 0) || 0,
});
if (!state.authorityCandidateEnabled) {
return {
available: false,
used: false,
candidateNodes: [],
diagnostics: {
provider: "authority-trivium",
candidateCount: 0,
filteredCount: 0,
searchHits: 0,
neighborCount: 0,
queryTexts: [],
fallbackReason: "authority-vector-unavailable",
timings: {
total: 0,
filter: 0,
search: 0,
neighbors: 0,
},
},
};
}
const requestedIds = Array.isArray(state.authorityCandidateNodeIds)
? state.authorityCandidateNodeIds
: [];
const candidateNodes = availableNodes.filter((node) => requestedIds.includes(node.id));
const minimumUsedCandidateCount = Number(options.minimumUsedCandidateCount || 0) || 0;
const used =
candidateNodes.length > 0 &&
candidateNodes.length < availableNodes.length &&
candidateNodes.length >= minimumUsedCandidateCount;
const diagnostics = {
provider: "authority-trivium",
candidateCount: candidateNodes.length,
filteredCount: candidateNodes.length,
searchHits: candidateNodes.length,
neighborCount: 0,
queryTexts: ["authority-candidate-query"],
fallbackReason: used
? ""
: candidateNodes.length === 0
? "authority-candidate-empty"
: candidateNodes.length >= availableNodes.length
? "authority-candidate-not-reduced"
: "authority-candidate-too-small",
timings: {
total: 1,
filter: 0.2,
search: 0.4,
neighbors: 0,
},
...(state.authorityCandidateDiagnostics || {}),
};
return {
available: true,
used,
candidateNodes: used ? candidateNodes : [],
diagnostics,
};
},
STORY_TEMPORAL_BUCKETS: {
CURRENT: "current",
ADJACENT_PAST: "adjacentPast",
@@ -902,6 +983,44 @@ assert.equal(state.diffusionCalls.length, 0);
assert.equal(state.llmCalls.length, 0);
assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-2", "rule-1"]);
state.authorityCandidateCalls.length = 0;
state.authorityCandidateEnabled = true;
state.authorityCandidateNodeIds = ["rule-2"];
state.authorityCandidateDiagnostics = null;
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
const authorityCandidateResult = await retrieve({
graph,
userMessage: "只看规则二",
recentMessages: ["assistant: 请聚焦最新规则。"],
embeddingConfig: {
mode: "authority",
source: "authority-trivium",
failOpen: true,
},
schema,
options: {
topK: 2,
maxRecallNodes: 2,
enableVectorPrefilter: true,
enableGraphDiffusion: false,
enableLLMRecall: false,
authorityCandidateMinCount: 1,
},
settings: {
authorityGraphQueryEnabled: true,
},
});
assert.equal(state.authorityCandidateCalls.length, 1);
assert.deepEqual(state.authorityCandidateCalls[0].availableNodeIds, ["rule-1", "rule-2", "rule-3"]);
assert.equal(authorityCandidateResult.meta.retrieval.authorityCandidateUsed, true);
assert.equal(authorityCandidateResult.meta.retrieval.authorityCandidateCount, 1);
assert.equal(authorityCandidateResult.meta.retrieval.rankingNodeCount, 1);
assert.deepEqual(Array.from(authorityCandidateResult.selectedNodeIds), ["rule-2"]);
state.authorityCandidateEnabled = false;
state.authorityCandidateNodeIds = [];
state.authorityCandidateDiagnostics = null;
state.vectorCalls.length = 0;
await retrieve({
graph,