Harden recall flow and JSON task prompts

This commit is contained in:
Youzini-afk
2026-03-28 20:38:57 +08:00
parent 30fdeaac1a
commit 67e6e29bb2
12 changed files with 618 additions and 200 deletions

View File

@@ -84,6 +84,8 @@ const state = {
diffusionCalls: [],
llmCalls: [],
llmCandidateCount: 0,
llmResponse: { selected_ids: ["rule-2", "rule-1"] },
llmOptions: [],
};
const graph = createGraph();
@@ -164,12 +166,26 @@ const retrieve = await loadRetrieve({
{ nodeId: "rule-3", energy: 0.9 },
];
},
async callLLMForJSON({ userPrompt }) {
async callLLMForJSON(params = {}) {
const { userPrompt = "" } = params;
state.llmOptions.push({ ...params });
state.llmCalls.push(userPrompt);
state.llmCandidateCount = userPrompt
.split("\n")
.filter((line) => line.trim().startsWith("[")).length;
return { selected_ids: ["rule-2", "rule-1"] };
if (params.returnFailureDetails) {
if (state.llmResponse?.ok === false) {
return state.llmResponse;
}
return {
ok: true,
data: state.llmResponse,
errorType: "",
failureReason: "",
attempts: 1,
};
}
return state.llmResponse;
},
getSTContextForPrompt() {
return {};
@@ -201,7 +217,9 @@ assert.deepEqual(Array.from(noStageResult.selectedNodeIds), ["rule-2", "rule-1"]
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
state.llmCandidateCount = 0;
state.llmResponse = { selected_ids: ["rule-2", "rule-1"] };
const llmPoolResult = await retrieve({
graph,
userMessage: "请根据规则给出结论",
@@ -227,10 +245,12 @@ assert.equal(llmPoolResult.meta.retrieval.vectorMergedHits, 3);
assert.equal(llmPoolResult.meta.retrieval.diversityApplied, true);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolBeforeDpp, 3);
assert.equal(llmPoolResult.meta.retrieval.candidatePoolAfterDpp, 2);
assert.equal(state.llmOptions[0].returnFailureDetails, true);
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
await retrieve({
graph,
userMessage: "规则一和规则二有什么关联",
@@ -261,4 +281,89 @@ assert.equal(state.diffusionCalls[0].options.topK, 7);
assert.equal(state.diffusionCalls[0].options.teleportAlpha, 0.15);
assert.equal(noStageResult.meta.retrieval.llm.status, "disabled");
state.vectorCalls.length = 0;
state.diffusionCalls.length = 0;
state.llmCalls.length = 0;
state.llmOptions.length = 0;
state.llmResponse = {
ok: false,
errorType: "invalid-json",
failureReason: "输出不是有效 JSON请严格返回紧凑 JSON 对象",
};
const fallbackResult = await retrieve({
graph,
userMessage: "LLM 这次会坏掉",
recentMessages: ["用户:请回忆相关规则"],
embeddingConfig: {},
schema,
options: {
topK: 4,
maxRecallNodes: 2,
enableVectorPrefilter: true,
enableGraphDiffusion: false,
enableLLMRecall: true,
llmCandidatePool: 2,
},
});
assert.equal(fallbackResult.meta.retrieval.llm.status, "fallback");
assert.match(fallbackResult.meta.retrieval.llm.reason, /有效 JSON|回退到评分排序/);
assert.equal(fallbackResult.meta.retrieval.llm.fallbackType, "invalid-json");
const sceneGraph = {
nodes: [
{
id: "event-1",
type: "event",
importance: 10,
createdTime: 1,
archived: false,
fields: { title: "事件一" },
seqRange: [1, 1],
},
{
id: "character-1",
type: "character",
importance: 6,
createdTime: 2,
archived: false,
fields: { name: "Alice" },
seqRange: [1, 1],
},
{
id: "location-1",
type: "location",
importance: 5,
createdTime: 3,
archived: false,
fields: { title: "大厅" },
seqRange: [1, 1],
},
],
edges: [
{ fromId: "event-1", toId: "character-1", relation: "mentions" },
{ fromId: "event-1", toId: "location-1", relation: "occurs_at" },
],
};
const sceneSchema = [
{ id: "event", label: "事件", alwaysInject: false },
{ id: "character", label: "角色", alwaysInject: false },
{ id: "location", label: "地点", alwaysInject: false },
];
const cappedResult = await retrieve({
graph: sceneGraph,
userMessage: "只看这一个场景",
recentMessages: [],
embeddingConfig: {},
schema: sceneSchema,
options: {
topK: 3,
maxRecallNodes: 1,
enableVectorPrefilter: false,
enableGraphDiffusion: false,
enableLLMRecall: false,
enableProbRecall: false,
},
});
assert.equal(cappedResult.selectedNodeIds.length, 1);
console.log("retrieval-config tests passed");