// Mirror of https://github.com/Youzini-afk/ST-Bionic-Memory-Ecology.git
// Synced 2026-05-15 22:30:38 +08:00
import assert from "node:assert/strict";
|
||
import fs from "node:fs/promises";
|
||
import path from "node:path";
|
||
import { fileURLToPath } from "node:url";
|
||
import vm from "node:vm";
|
||
|
||
/**
 * Load the real `retrieve` function from ../retriever.js into an isolated
 * vm context, with all of its collaborators replaced by the given stubs.
 *
 * @param {object} stubs - Functions/values to expose as globals inside the
 *   sandbox (graph helpers, vector search, LLM client, ...).
 * @returns {Promise<Function>} the sandboxed `retrieve` entry point.
 */
async function loadRetrieve(stubs) {
  // Resolve ../retriever.js relative to this test file.
  const __dirname = path.dirname(fileURLToPath(import.meta.url));
  const retrieverPath = path.resolve(__dirname, "../retriever.js");
  const source = await fs.readFile(retrieverPath, "utf8");

  // Strip ESM syntax so the module can run in a bare vm context: drop the
  // import statements, demote the exported entry point to a plain function,
  // then publish it on the context global via `this`.
  const stripped = source
    .replace(/^import[\s\S]*?from\s+["'][^"']+["'];\r?\n/gm, "")
    .replace("export async function retrieve", "async function retrieve");
  const transformed = `${stripped}\nthis.retrieve = retrieve;\n`;

  // Sandbox globals: silence console output, expose every stub by name.
  const sandbox = {
    console: { log() {}, error() {}, warn() {} },
    ...stubs,
  };
  const context = vm.createContext(sandbox);
  new vm.Script(transformed).runInContext(context);
  return context.retrieve;
}
/**
 * Build the fixture graph used by the rule-based scenarios: three active
 * "rule" nodes with descending importance and no edges. Node N is created
 * at time N and covers sequence range [N, N].
 *
 * @returns {{nodes: object[], edges: object[]}}
 */
function createGraph() {
  const makeRule = (n, title, importance) => ({
    id: `rule-${n}`,
    type: "rule",
    importance,
    createdTime: n,
    archived: false,
    fields: { title },
    seqRange: [n, n],
  });
  return {
    nodes: [
      makeRule(1, "规则一", 9),
      makeRule(2, "规则二", 7),
      makeRule(3, "规则三", 3),
    ],
    edges: [],
  };
}
/**
 * Minimal graph-accessor stubs over a fixture graph. Every accessor prefers
 * an explicitly passed graph-like `target` and falls back to the captured
 * `graph` when the target (or its collection) is absent.
 *
 * @param {{nodes: object[], edges: object[]}} graph - default graph.
 * @returns {object} helper bundle consumed by the retriever.
 */
function createGraphHelpers(graph) {
  const nodesOf = (target) => target?.nodes || graph.nodes;
  const edgesOf = (target) => target?.edges || graph.edges;
  return {
    // Non-archived nodes, optionally restricted to a single node type.
    getActiveNodes(target, type = null) {
      return nodesOf(target).filter(
        (node) => !node.archived && (!type || node.type === type),
      );
    },
    // Node lookup by id; null when not found.
    getNode(target, id) {
      const hit = nodesOf(target).find((node) => node.id === id);
      return hit || null;
    },
    // Every edge touching nodeId on either endpoint.
    getNodeEdges(target, nodeId) {
      return edgesOf(target).filter(
        (edge) => edge.fromId === nodeId || edge.toId === nodeId,
      );
    },
    // Temporal adjacency is unused in these tests; always empty.
    buildTemporalAdjacencyMap() {
      return new Map();
    },
  };
}
// Single-type schema used by the rule-based scenarios below.
const schema = [{ id: "rule", label: "规则", alwaysInject: false }];

// Shared mutable recorder for stub invocations; reset between scenarios.
const state = {
  vectorCalls: [],      // { topK, message } per findSimilarNodesByText call
  diffusionCalls: [],   // { seeds, options } per diffuseAndRank call
  llmCalls: [],         // userPrompt per callLLMForJSON call
  llmCandidateCount: 0, // count of "["-prefixed lines in the last LLM prompt
  llmResponse: { selected_ids: ["rule-2", "rule-1"] }, // canned LLM reply
  llmOptions: [],       // full params object per callLLMForJSON call
};

const graph = createGraph();
const helpers = createGraphHelpers(graph);
// Load the real retrieve() from ../retriever.js with every collaborator
// stubbed. The stubs record their inputs into `state` so the scenarios
// below can assert which recall stages ran and with what arguments.
const retrieve = await loadRetrieve({
  ...helpers,
  // Prompt assembly is irrelevant here; return an empty system prompt.
  buildTaskPrompt() {
    return { systemPrompt: "" };
  },
  // Task-regex post-processing: identity pass-through.
  applyTaskRegex(_settings, _taskType, _stage, text) {
    return text;
  },
  // Multi-intent splitting: split on the conjunction "和"; otherwise
  // signal a single intent by returning no segments.
  splitIntentSegments(text) {
    if (String(text).includes("和")) {
      return String(text).split("和").map((item) => item.trim());
    }
    return [];
  },
  // Merge per-query vector hit groups: keep the best score per node and
  // count raw (pre-dedupe) hits.
  mergeVectorResults(groups, limit) {
    const merged = new Map();
    let rawHitCount = 0;
    for (const group of groups) {
      for (const item of group) {
        rawHitCount += 1;
        const existing = merged.get(item.nodeId);
        if (!existing || item.score > existing.score) {
          merged.set(item.nodeId, item);
        }
      }
    }
    return {
      rawHitCount,
      results: [...merged.values()].slice(0, limit),
    };
  },
  collectSupplementalAnchorNodeIds() {
    return [];
  },
  // A node can anchor recall when it carries a title or a name.
  isEligibleAnchorNode(node) {
    return Boolean(node?.fields?.title || node?.fields?.name);
  },
  createCooccurrenceIndex() {
    return { map: new Map(), source: "batchJournal", batchCount: 0, pairCount: 0 };
  },
  // Co-occurrence boost: no-op copy of the base scores.
  applyCooccurrenceBoost(baseScores) {
    return { scores: new Map(baseScores), boostedNodes: [] };
  },
  // Diversity-sampling stub: take the first k candidates and reverse them,
  // so tests can tell from the output order that this stage ran.
  applyDiversitySampling(candidates, { k }) {
    return {
      applied: true,
      reason: "",
      selected: candidates.slice(0, k).reverse(),
      beforeCount: candidates.length,
      afterCount: Math.min(k, candidates.length),
    };
  },
  async runResidualRecall() {
    return { triggered: false, hits: [], skipReason: "residual-disabled-test" };
  },
  // Deterministic additive scoring keeps ranking assertions simple.
  hybridScore: ({
    graphScore = 0,
    vectorScore = 0,
    lexicalScore = 0,
    importance = 0,
  }) => graphScore + vectorScore + lexicalScore + importance,
  reinforceAccessBatch() {},
  validateVectorConfig() {
    return { valid: true };
  },
  // Vector search stub: record the query, return a fixed ranked hit list.
  async findSimilarNodesByText(_graph, message, _embeddingConfig, topK) {
    state.vectorCalls.push({ topK, message });
    return [
      { nodeId: "rule-1", score: 0.9 },
      { nodeId: "rule-2", score: 0.8 },
      { nodeId: "rule-3", score: 0.7 },
    ];
  },
  // Graph-diffusion stub: record seeds/options, return fixed energies.
  diffuseAndRank(_adjacencyMap, seeds, options) {
    state.diffusionCalls.push({ seeds, options });
    return [
      { nodeId: "rule-2", energy: 1.2 },
      { nodeId: "rule-3", energy: 0.9 },
    ];
  },
  // LLM stub: record the prompt and full options, count candidate lines
  // (those starting with "["), and reply with the canned state.llmResponse,
  // wrapped in a success envelope when returnFailureDetails is requested.
  async callLLMForJSON(params = {}) {
    const { userPrompt = "" } = params;
    state.llmOptions.push({ ...params });
    state.llmCalls.push(userPrompt);
    state.llmCandidateCount = userPrompt
      .split("\n")
      .filter((line) => line.trim().startsWith("[")).length;
    if (params.returnFailureDetails) {
      if (state.llmResponse?.ok === false) {
        return state.llmResponse;
      }
      return {
        ok: true,
        data: state.llmResponse,
        errorType: "",
        failureReason: "",
        attempts: 1,
      };
    }
    return state.llmResponse;
  },
  getSTContextForPrompt() {
    return {};
  },
});
// Scenario 1: all optional stages disabled — retrieve must not touch the
// vector, diffusion, or LLM stubs; selection comes from base scoring plus
// the diversity stub, whose reversal explains rule-2 before rule-1.
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
const noStageResult = await retrieve({
  graph,
  userMessage: "只看当前规则",
  recentMessages: [],
  embeddingConfig: {},
  schema,
  options: {
    topK: 2,
    maxRecallNodes: 2,
    enableVectorPrefilter: false,
    enableGraphDiffusion: false,
    enableLLMRecall: false,
  },
});
assert.equal(state.vectorCalls.length, 0);
assert.equal(state.diffusionCalls.length, 0);
assert.equal(state.llmCalls.length, 0);
assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-2", "rule-1"]);
// Scenario 2: context query blend. With enableContextQueryBlend on, the
// vector stage is queried with the current message first and then recent
// history with role tags stripped; the history line duplicating the current
// message does not produce an extra query (presumably deduped inside
// retrieve — confirm against retriever.js).
state.vectorCalls.length = 0;
await retrieve({
  graph,
  userMessage: "他后来怎么做?",
  recentMessages: [
    "[assistant]: 他提到了规则二的限制",
    "[user]: 我们先看规则一",
    "[user]: 他后来怎么做?",
  ],
  embeddingConfig: {},
  schema,
  options: {
    topK: 4,
    maxRecallNodes: 2,
    enableVectorPrefilter: true,
    enableGraphDiffusion: false,
    enableLLMRecall: false,
    enableMultiIntent: false,
    enableContextQueryBlend: true,
  },
});
assert.deepEqual(
  state.vectorCalls.map((item) => item.message),
  ["他后来怎么做?", "他提到了规则二的限制", "我们先看规则一"],
);
// Scenario 3: LLM recall with llmCandidatePool: 2. Vector search runs for
// the user message and the (role-stripped) recent message; the LLM prompt
// must contain exactly 2 candidate lines, and the canned selected_ids
// determine the final order and the "llm" status in meta.
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
state.llmCandidateCount = 0;
state.llmResponse = { selected_ids: ["rule-2", "rule-1"] };
const llmPoolResult = await retrieve({
  graph,
  userMessage: "请根据规则给出结论",
  recentMessages: ["用户:现在该怎么做?"],
  embeddingConfig: {},
  schema,
  options: {
    topK: 4,
    maxRecallNodes: 2,
    enableVectorPrefilter: true,
    enableGraphDiffusion: false,
    enableLLMRecall: true,
    llmCandidatePool: 2,
  },
});
assert.deepEqual(state.vectorCalls, [
  { topK: 4, message: "请根据规则给出结论" },
  { topK: 4, message: "现在该怎么做?" },
]);
assert.equal(state.diffusionCalls.length, 0);
assert.equal(state.llmCandidateCount, 2);
assert.deepEqual(Array.from(llmPoolResult.selectedNodeIds), ["rule-2", "rule-1"]);
assert.equal(llmPoolResult.meta.retrieval.llm.status, "llm");
assert.equal(llmPoolResult.meta.retrieval.llm.candidatePool, 2);
// 3 stub hits per query, deduped by mergeVectorResults to 3 unique nodes.
assert.equal(llmPoolResult.meta.retrieval.vectorMergedHits, 3);
assert.equal(llmPoolResult.meta.retrieval.diversityApplied, true);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolBeforeDpp, 3);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolAfterDpp, 2);
// The LLM must be invoked with failure details requested, bounded retries,
// and a completion-token cap.
assert.equal(state.llmOptions[0].returnFailureDetails, true);
assert.equal(state.llmOptions[0].maxRetries, 2);
assert.equal(state.llmOptions[0].maxCompletionTokens, 512);
// Scenario 4: multi-intent + graph diffusion. "规则一和规则二有什么关联"
// splits on "和" into intent segments, yielding three vector queries in
// total; diffusion runs once with the configured diffusionTopK and
// teleportAlpha forwarded into its options.
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
await retrieve({
  graph,
  userMessage: "规则一和规则二有什么关联",
  recentMessages: [],
  embeddingConfig: {},
  schema,
  options: {
    topK: 3,
    maxRecallNodes: 2,
    enableVectorPrefilter: true,
    enableGraphDiffusion: true,
    diffusionTopK: 7,
    enableLLMRecall: false,
    enableMultiIntent: true,
    multiIntentMaxSegments: 4,
    enableTemporalLinks: true,
    temporalLinkStrength: 0.2,
    teleportAlpha: 0.15,
  },
});
assert.equal(state.vectorCalls.length, 3);
assert.deepEqual(
  state.vectorCalls.map((item) => item.topK),
  [3, 3, 3],
);
assert.equal(state.diffusionCalls.length, 1);
assert.equal(state.diffusionCalls[0].options.topK, 7);
assert.equal(state.diffusionCalls[0].options.teleportAlpha, 0.15);
// Late assertion for scenario 1: with LLM recall off, meta reports
// status "disabled".
assert.equal(noStageResult.meta.retrieval.llm.status, "disabled");
// Scenario 5: LLM failure → fallback. The stub reports an invalid-JSON
// failure, so retrieve must fall back to score ordering and surface the
// failure type and a human-readable reason in meta.
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
state.llmResponse = {
  ok: false,
  errorType: "invalid-json",
  failureReason: "输出不是有效 JSON,请严格返回紧凑 JSON 对象",
};
const fallbackResult = await retrieve({
  graph,
  userMessage: "LLM 这次会坏掉",
  recentMessages: ["用户:请回忆相关规则"],
  embeddingConfig: {},
  schema,
  options: {
    topK: 4,
    maxRecallNodes: 2,
    enableVectorPrefilter: true,
    enableGraphDiffusion: false,
    enableLLMRecall: true,
    llmCandidatePool: 2,
  },
});
assert.equal(fallbackResult.meta.retrieval.llm.status, "fallback");
assert.match(fallbackResult.meta.retrieval.llm.reason, /有效 JSON|回退到评分排序/);
assert.equal(fallbackResult.meta.retrieval.llm.fallbackType, "invalid-json");
// Scenario 6: maxRecallNodes caps the final selection. A three-node scene
// graph (event / character / location, connected by two edges) retrieved
// with maxRecallNodes: 1 must yield exactly one selected node.
const sceneGraph = {
  nodes: [
    {
      id: "event-1",
      type: "event",
      importance: 10,
      createdTime: 1,
      archived: false,
      fields: { title: "事件一" },
      seqRange: [1, 1],
    },
    {
      id: "character-1",
      type: "character",
      importance: 6,
      createdTime: 2,
      archived: false,
      fields: { name: "Alice" },
      seqRange: [1, 1],
    },
    {
      id: "location-1",
      type: "location",
      importance: 5,
      createdTime: 3,
      archived: false,
      fields: { title: "大厅" },
      seqRange: [1, 1],
    },
  ],
  edges: [
    { fromId: "event-1", toId: "character-1", relation: "mentions" },
    { fromId: "event-1", toId: "location-1", relation: "occurs_at" },
  ],
};
const sceneSchema = [
  { id: "event", label: "事件", alwaysInject: false },
  { id: "character", label: "角色", alwaysInject: false },
  { id: "location", label: "地点", alwaysInject: false },
];
const cappedResult = await retrieve({
  graph: sceneGraph,
  userMessage: "只看这一个场景",
  recentMessages: [],
  embeddingConfig: {},
  schema: sceneSchema,
  options: {
    topK: 3,
    maxRecallNodes: 1,
    enableVectorPrefilter: false,
    enableGraphDiffusion: false,
    enableLLMRecall: false,
    enableProbRecall: false,
  },
});
assert.equal(cappedResult.selectedNodeIds.length, 1);
// Scenario 7: lexical boost. Two otherwise-identical characters; with
// vector/LLM/diversity stages off, mentioning "Alice" in the user message
// should lexically boost char-1 ahead of char-2, and meta should report
// the boost count and top lexical hit.
const lexicalGraph = {
  nodes: [
    {
      id: "char-1",
      type: "character",
      importance: 1,
      createdTime: 1,
      archived: false,
      fields: { name: "Alice", summary: "常驻角色" },
      seqRange: [1, 1],
    },
    {
      id: "char-2",
      type: "character",
      importance: 1,
      createdTime: 1,
      archived: false,
      fields: { name: "Bob", summary: "常驻角色" },
      seqRange: [1, 1],
    },
  ],
  edges: [],
};
const lexicalSchema = [{ id: "character", label: "角色", alwaysInject: false }];
const lexicalResult = await retrieve({
  graph: lexicalGraph,
  userMessage: "Alice 现在怎么样了",
  recentMessages: [],
  embeddingConfig: {},
  schema: lexicalSchema,
  options: {
    topK: 2,
    maxRecallNodes: 1,
    enableVectorPrefilter: false,
    enableGraphDiffusion: false,
    enableLLMRecall: false,
    enableDiversitySampling: false,
    enableLexicalBoost: true,
  },
});
assert.deepEqual(Array.from(lexicalResult.selectedNodeIds), ["char-1"]);
// No recent messages were supplied, so query blending must stay inactive.
assert.equal(lexicalResult.meta.retrieval.queryBlendActive, false);
assert.equal(lexicalResult.meta.retrieval.lexicalBoostedNodes, 1);
assert.equal(lexicalResult.meta.retrieval.lexicalTopHits[0]?.nodeId, "char-1");

console.log("retrieval-config tests passed");