Harden recall flow and JSON task prompts

This commit is contained in:
Youzini-afk
2026-03-28 20:38:57 +08:00
parent 30fdeaac1a
commit 67e6e29bb2
12 changed files with 618 additions and 200 deletions

View File

@@ -28,7 +28,7 @@ function extractSnippet(startMarker, endMarker) {
const persistencePrelude = extractSnippet(
'const MODULE_NAME = "st_bme";',
"function clearInjectionState() {",
"function clearInjectionState(options = {}) {",
);
const persistenceCore = extractSnippet(
"function loadGraphFromChat(options = {}) {",

View File

@@ -237,7 +237,7 @@ function createGenerationRecallHarness() {
);
context.runRecall = async (options = {}) => {
context.runRecallCalls.push({ ...options });
return true;
return { status: "completed", didRecall: true, ok: true };
};
return context;
});
@@ -1296,6 +1296,31 @@ async function testGenerationRecallDifferentKeyCanRunAgain() {
);
}
/**
 * A recall run that reports status "skipped" must still be recorded as a
 * finished transaction for its key, so the later onBeforeCombinePrompts hook
 * does not loop back and run recall a second time for the same message.
 */
async function testGenerationRecallSkippedStateDoesNotLoopToBeforeCombine() {
  const ctx = await createGenerationRecallHarness();
  ctx.chat = [{ is_user: true, mes: "同一条但本次跳过" }];

  // Stub recall so this generation reports a skip instead of a completed run.
  ctx.runRecall = async (options = {}) => {
    ctx.runRecallCalls.push({ ...options });
    return {
      status: "skipped",
      didRecall: false,
      ok: false,
      reason: "测试跳过",
    };
  };

  await ctx.result.onGenerationAfterCommands("normal", {}, false);
  await ctx.result.onBeforeCombinePrompts();

  // Recall ran exactly once and the skip was captured as one transaction.
  assert.equal(ctx.runRecallCalls.length, 1);
  assert.equal(ctx.result.generationRecallTransactions.size, 1);

  const [firstTransaction] = ctx.result.generationRecallTransactions.values();
  assert.equal(firstTransaction.hookStates.GENERATION_AFTER_COMMANDS, "skipped");
}
async function testRerollUsesBatchBoundaryRollbackAndPersistsState() {
const harness = await createRerollHarness();
harness.chat = [
@@ -1644,6 +1669,7 @@ await testProcessedHistoryAdvanceRequiresCompleteStrongSuccess();
await testGenerationRecallTransactionDedupesDoubleHookBySameKey();
await testGenerationRecallBeforeCombineRunsStandalone();
await testGenerationRecallDifferentKeyCanRunAgain();
await testGenerationRecallSkippedStateDoesNotLoopToBeforeCombine();
await testRerollUsesBatchBoundaryRollbackAndPersistsState();
await testRerollRejectsMissingRecoveryPoint();
await testRerollFallsBackToDirectExtractForUnprocessedFloor();

View File

@@ -51,6 +51,7 @@ const extractPromptBuild = await buildTaskPrompt(settings, "extract", {
currentRange: "1 ~ 2",
});
const extractPayload = buildTaskLlmPayload(extractPromptBuild, "fallback-user");
assert.equal(extractPayload.systemPrompt, "");
assert.equal(extractPayload.userPrompt, "");
assert.equal(
extractPayload.promptMessages.filter((message) => message.role === "user").length,
@@ -86,6 +87,7 @@ const recallPromptBuild = await buildTaskPrompt(settings, "recall", {
graphStats: "candidate_count=2",
});
const recallPayload = buildTaskLlmPayload(recallPromptBuild, "fallback-user");
assert.equal(recallPayload.systemPrompt, "");
assert.equal(recallPayload.userPrompt, "");
assert.equal(
recallPayload.promptMessages.filter((message) => message.role === "user").length,

View File

@@ -258,6 +258,7 @@ try {
};
const payload = buildTaskLlmPayload(promptBuild, "unused fallback");
assert.equal(payload.systemPrompt, "");
const result = await llm.callLLMForJSON({
systemPrompt: payload.systemPrompt,
userPrompt: payload.userPrompt,

View File

@@ -84,6 +84,8 @@ const state = {
diffusionCalls: [],
llmCalls: [],
llmCandidateCount: 0,
llmResponse: { selected_ids: ["rule-2", "rule-1"] },
llmOptions: [],
};
const graph = createGraph();
@@ -164,12 +166,26 @@ const retrieve = await loadRetrieve({
{ nodeId: "rule-3", energy: 0.9 },
];
},
async callLLMForJSON({ userPrompt }) {
async callLLMForJSON(params = {}) {
const { userPrompt = "" } = params;
state.llmOptions.push({ ...params });
state.llmCalls.push(userPrompt);
state.llmCandidateCount = userPrompt
.split("\n")
.filter((line) => line.trim().startsWith("[")).length;
return { selected_ids: ["rule-2", "rule-1"] };
if (params.returnFailureDetails) {
if (state.llmResponse?.ok === false) {
return state.llmResponse;
}
return {
ok: true,
data: state.llmResponse,
errorType: "",
failureReason: "",
attempts: 1,
};
}
return state.llmResponse;
},
getSTContextForPrompt() {
return {};
@@ -201,7 +217,9 @@ assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-2", "rule-1"]
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
state.llmCandidateCount = 0;
state.llmResponse = { selected_ids: ["rule-2", "rule-1"] };
const llmPoolResult = await retrieve({
graph,
userMessage: "请根据规则给出结论",
@@ -227,10 +245,12 @@ assert.equal(llmPoolResult.meta.retrieval.vectorMergedHits, 3);
assert.equal(llmPoolResult.meta.retrieval.diversityApplied, true);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolBeforeDpp, 3);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolAfterDpp, 2);
assert.equal(state.llmOptions[0].returnFailureDetails, true);
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
await retrieve({
graph,
userMessage: "规则一和规则二有什么关联",
@@ -261,4 +281,89 @@ assert.equal(state.diffusionCalls[0].options.topK, 7);
assert.equal(state.diffusionCalls[0].options.teleportAlpha, 0.15);
assert.equal(noStageResult.meta.retrieval.llm.status, "disabled");
// Reset the call recorders left over from the previous scenario.
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
// Script an LLM failure (invalid JSON) so retrieval must fall back.
state.llmResponse = {
ok: false,
errorType: "invalid-json",
failureReason: "输出不是有效 JSON请严格返回紧凑 JSON 对象",
};
// Scenario: LLM recall is enabled but the call fails — retrieval should
// report a fallback in its meta instead of surfacing the error.
const fallbackResult = await retrieve({
graph,
userMessage: "LLM 这次会坏掉",
recentMessages: ["用户:请回忆相关规则"],
embeddingConfig: {},
schema,
options: {
topK: 4,
maxRecallNodes: 2,
enableVectorPrefilter: true,
enableGraphDiffusion: false,
enableLLMRecall: true,
llmCandidatePool: 2,
},
});
// The meta must expose fallback status, a human-readable reason, and the
// failure type that triggered it.
assert.equal(fallbackResult.meta.retrieval.llm.status, "fallback");
assert.match(fallbackResult.meta.retrieval.llm.reason, /有效 JSON|回退到评分排序/);
assert.equal(fallbackResult.meta.retrieval.llm.fallbackType, "invalid-json");
// Minimal three-node scene graph (event -> character, event -> location)
// used to exercise the maxRecallNodes cap below.
const sceneGraph = {
nodes: [
{
id: "event-1",
type: "event",
importance: 10,
createdTime: 1,
archived: false,
fields: { title: "事件一" },
seqRange: [1, 1],
},
{
id: "character-1",
type: "character",
importance: 6,
createdTime: 2,
archived: false,
fields: { name: "Alice" },
seqRange: [1, 1],
},
{
id: "location-1",
type: "location",
importance: 5,
createdTime: 3,
archived: false,
fields: { title: "大厅" },
seqRange: [1, 1],
},
],
edges: [
{ fromId: "event-1", toId: "character-1", relation: "mentions" },
{ fromId: "event-1", toId: "location-1", relation: "occurs_at" },
],
};
// No type is always-injected, so the cap alone decides what survives.
const sceneSchema = [
{ id: "event", label: "事件", alwaysInject: false },
{ id: "character", label: "角色", alwaysInject: false },
{ id: "location", label: "地点", alwaysInject: false },
];
// Scenario: every recall stage disabled — maxRecallNodes: 1 must still cap
// the selection to a single node.
const cappedResult = await retrieve({
graph: sceneGraph,
userMessage: "只看这一个场景",
recentMessages: [],
embeddingConfig: {},
schema: sceneSchema,
options: {
topK: 3,
maxRecallNodes: 1,
enableVectorPrefilter: false,
enableGraphDiffusion: false,
enableLLMRecall: false,
enableProbRecall: false,
},
});
assert.equal(cappedResult.selectedNodeIds.length, 1);
console.log("retrieval-config tests passed");