mirror of
https://github.com/Youzini-afk/ST-Bionic-Memory-Ecology.git
synced 2026-05-15 22:30:38 +08:00
Reorganize modules into layered directories
This commit is contained in:
205
retrieval/diffusion.js
Normal file
205
retrieval/diffusion.js
Normal file
@@ -0,0 +1,205 @@
|
||||
// ST-BME: JS 版 PEDSA 扩散激活引擎
|
||||
// 从 PeroCore 的 Rust CognitiveGraphEngine 移植核心算法到纯 JS
|
||||
// 适配 ST 场景(<1万节点,不需要并行/SIMD)
|
||||
|
||||
/**
|
||||
* PEDSA 扩散激活引擎
|
||||
*
|
||||
* 算法:Parallel Energy-Decay Spreading Activation
|
||||
* 本质:在有向加权图上的能量传播模型
|
||||
*
|
||||
* 核心公式:
|
||||
* E_{t+1}(j) = Σ_{i∈N(j)} E_t(i) × W_ij × D_decay
|
||||
*
|
||||
* 特点(保留自 PeroCore):
|
||||
* - 能量衰减:每步传播乘以衰减因子
|
||||
* - 动态剪枝:每步只保留 Top-K 活跃节点
|
||||
* - 抑制机制:特殊边类型传递负能量
|
||||
* - 能量钳位:限制在 [-2.0, 2.0] 范围
|
||||
*
|
||||
* 与 PeroCore Rust 版的差异:
|
||||
* - 无 Rayon 并行(JS 单线程,ST 场景不需要)
|
||||
* - 无 u16 量化(直接 f64,内存不是瓶颈)
|
||||
* - 无 SIMD(普通数组运算)
|
||||
*/
|
||||
|
||||
/**
 * Edge-type marker for inhibitory edges. During spreading activation an edge
 * whose `edgeType` equals this value passes negative energy to its target.
 */
const INHIBIT_EDGE_TYPE = 255;
|
||||
|
||||
/**
 * Default tuning parameters for the spreading-activation engine.
 * Frozen so the shared defaults can never be mutated by accident.
 */
const DEFAULT_OPTIONS = Object.freeze({
  maxSteps: 2, // maximum number of diffusion steps
  decayFactor: 0.6, // decay multiplier applied on every hop
  topK: 100, // maximum number of active nodes kept after each step
  minEnergy: 0.01, // magnitudes below this are treated as inactive
  maxEnergy: 2.0, // upper energy bound
  minEnergy_clamp: -2.0, // lower energy bound (inhibition)
  teleportAlpha: 0.0, // PPR-style pull-back probability towards the seeds
  inhibitMultiplier: 2.0, // extra gain applied to inhibitory edges
});
|
||||
|
||||
/**
 * Run PEDSA spreading activation over a directed, weighted graph.
 *
 * On every step each active node pushes
 * `energy × strength × decayFactor × (1 − teleportAlpha)` to each outgoing
 * neighbour. Edges whose `edgeType` is INHIBIT_EDGE_TYPE instead push
 * `−|energy| × strength × decayFactor × inhibitMultiplier`. The per-step sums
 * are clamped, entries whose magnitude is below `minEnergy` are dropped, seeds
 * are optionally blended back towards their initial energy (teleport), and
 * only the `topK` entries with the largest magnitude survive.
 *
 * @param {Map<string, Array<{targetId: string, strength: number, edgeType: number}>>} adjacencyMap
 *   Adjacency list: nodeId → outgoing edges. A value without a `get` method is
 *   treated as a graph with no edges.
 * @param {Array<{id: string, energy: number}>} seedNodes
 *   Initial seeds and their energy. Seeds sharing an id are summed (clamped).
 * @param {object} [options]
 *   Overrides for DEFAULT_OPTIONS. `null`/`undefined` values are ignored so
 *   they cannot erase a default.
 * @returns {Map<string, number>}
 *   nodeId → the energy with the largest magnitude seen for that node across
 *   the seed state and all steps (positive = activated, negative = inhibited).
 */
export function propagateActivation(adjacencyMap, seedNodes, options = {}) {
  // A plain `{...defaults, ...options}` lets `{ minEnergy: undefined }` wipe
  // the default, which silently rejects every seed; copy only real values.
  const opts = { ...DEFAULT_OPTIONS };
  for (const [key, value] of Object.entries(options ?? {})) {
    if (value != null) opts[key] = value;
  }
  const teleportAlpha = clamp01(opts.teleportAlpha);
  const inhibitMultiplier = Number(opts.inhibitMultiplier) || 1;

  /** @type {Map<string, number>} */
  let currentEnergy = new Map();
  /** @type {Map<string, number>} */
  const initialEnergy = new Map();

  for (const seed of seedNodes || []) {
    if (!seed?.id) continue;
    const clamped = clampEnergy(Number(seed.energy) || 0, opts);
    if (!(Math.abs(clamped) >= opts.minEnergy)) continue;
    const next = clampEnergy((currentEnergy.get(seed.id) || 0) + clamped, opts);
    currentEnergy.set(seed.id, next);
    initialEnergy.set(seed.id, next);
  }

  // Accumulated result: starts as the seed state.
  /** @type {Map<string, number>} */
  const result = new Map(currentEnergy);

  // Without a usable adjacency map there is nothing to diffuse into.
  if (typeof adjacencyMap?.get !== "function") return result;

  for (let step = 0; step < opts.maxSteps; step++) {
    /** @type {Map<string, number>} */
    const nextEnergy = new Map();

    // Push energy from every active node to its neighbours.
    for (const [nodeId, energy] of currentEnergy) {
      const neighbors = adjacencyMap.get(nodeId);
      if (!Array.isArray(neighbors) || neighbors.length === 0) continue;

      for (const neighbor of neighbors) {
        if (!neighbor?.targetId) continue;
        const strength = Number(neighbor.strength) || 0;
        // NOTE(review): the inhibitory branch is not damped by
        // (1 - teleportAlpha); kept as in the original — confirm intended.
        const propagated =
          neighbor.edgeType === INHIBIT_EDGE_TYPE
            ? -Math.abs(energy) * strength * opts.decayFactor * inhibitMultiplier
            : energy * strength * opts.decayFactor * (1 - teleportAlpha);

        const existing = nextEnergy.get(neighbor.targetId) || 0;
        nextEnergy.set(neighbor.targetId, existing + propagated);
      }
    }

    // Clamp, then drop entries that are too weak to matter.
    for (const [nodeId, energy] of nextEnergy) {
      const clamped = clampEnergy(energy, opts);
      if (Math.abs(clamped) < opts.minEnergy) {
        nextEnergy.delete(nodeId);
      } else {
        nextEnergy.set(nodeId, clamped);
      }
    }

    // Teleport: blend every seed back towards its initial energy.
    if (teleportAlpha > 0) {
      for (const [nodeId, seedEnergy] of initialEnergy) {
        const current = nextEnergy.get(nodeId) || 0;
        const teleported =
          (1 - teleportAlpha) * current + teleportAlpha * seedEnergy;
        const clamped = clampEnergy(teleported, opts);
        if (Math.abs(clamped) >= opts.minEnergy) {
          nextEnergy.set(nodeId, clamped);
        } else {
          nextEnergy.delete(nodeId);
        }
      }
    }

    // Dynamic pruning: keep only the Top-K entries by magnitude.
    if (nextEnergy.size > opts.topK) {
      const sorted = [...nextEnergy.entries()].sort(
        (a, b) => Math.abs(b[1]) - Math.abs(a[1]),
      );
      nextEnergy.clear();
      for (let i = 0; i < opts.topK && i < sorted.length; i++) {
        nextEnergy.set(sorted[i][0], sorted[i][1]);
      }
    }

    // Keep, per node, the energy with the largest magnitude seen so far.
    for (const [nodeId, energy] of nextEnergy) {
      const existing = result.get(nodeId) || 0;
      if (Math.abs(energy) > Math.abs(existing)) {
        result.set(nodeId, energy);
      }
    }

    currentEnergy = nextEnergy;

    // Nothing left to propagate: stop early.
    if (currentEnergy.size === 0) break;
  }

  return result;
}
|
||||
|
||||
/**
 * Clamp an energy value into `[opts.minEnergy_clamp, opts.maxEnergy]`.
 * NaN input is treated as 0 before clamping, because a NaN that escapes here
 * would poison every later sum and comparison.
 *
 * @param {number} energy
 * @param {{minEnergy_clamp: number, maxEnergy: number}} opts
 * @returns {number}
 */
function clampEnergy(energy, opts) {
  const numeric = Number(energy);
  const value = Number.isNaN(numeric) ? 0 : numeric;
  return Math.max(opts.minEnergy_clamp, Math.min(opts.maxEnergy, value));
}
|
||||
|
||||
/**
 * Coerce a value to a number and clamp it into [0, 1].
 * Anything that does not coerce to a non-zero number becomes 0.
 *
 * @param {*} value
 * @returns {number}
 */
function clamp01(value) {
  const numeric = Number(value) || 0;
  return Math.max(0, Math.min(1, numeric));
}
|
||||
|
||||
/**
 * Convenience wrapper: run the diffusion from a seed list and return the
 * activated nodes as a ranked list.
 *
 * Only nodes with strictly positive energy are returned. They are ordered by
 * energy, highest first; ties are broken by comparing node ids as strings.
 *
 * @param {Map} adjacencyMap - adjacency list
 * @param {Array<{id: string, energy: number}>} seeds - seed nodes
 * @param {object} [options]
 * @returns {Array<{nodeId: string, energy: number}>}
 */
export function diffuseAndRank(adjacencyMap, seeds, options = {}) {
  const energyMap = propagateActivation(adjacencyMap, seeds, options);

  const ranked = [];
  for (const [nodeId, energy] of energyMap) {
    if (energy > 0) ranked.push({ nodeId, energy });
  }

  ranked.sort((a, b) => {
    if (b.energy !== a.energy) return b.energy - a.energy;
    return String(a.nodeId).localeCompare(String(b.nodeId));
  });
  return ranked;
}
|
||||
114
retrieval/dynamics.js
Normal file
114
retrieval/dynamics.js
Normal file
@@ -0,0 +1,114 @@
|
||||
// ST-BME: 记忆动力学模块
|
||||
// 实现访问强化、时间衰减、混合评分 — 来自 PeroCore 的核心创新
|
||||
|
||||
/**
 * Access reinforcement: call when a node is recalled or injected.
 * - accessCount += 1
 * - importance += 0.1 (capped at 10; a missing or non-numeric importance
 *   starts from 5)
 * - lastAccessTime is set to the current time
 *
 * An importance of 0 is a real value and is bumped to 0.1, not reset to 5.
 * Non-object input is ignored.
 *
 * @param {object} node - mutated in place
 */
export function reinforceAccess(node) {
  if (!node || typeof node !== "object") return;

  const raw = node.importance;
  // Coerce explicitly: `"3" + 0.1` would otherwise string-concatenate.
  const numeric = raw == null || raw === "" ? NaN : Number(raw);
  const importance = Number.isFinite(numeric) ? numeric : 5;

  node.accessCount = (node.accessCount || 0) + 1;
  node.importance = Math.min(10, importance + 0.1);
  node.lastAccessTime = Date.now();
}
|
||||
|
||||
/**
 * Time-decay factor using logarithmic (not exponential) decay:
 *   factor = 0.8 + 0.2 / (1 + ln(1 + Δt_days))
 *
 * Old memories therefore fade slowly and never drop below 0.8:
 * - Δt = 0 days   → 1.0
 * - Δt = 1 day    → ≈ 0.92
 * - Δt = 7 days   → ≈ 0.86
 * - Δt = 30 days  → ≈ 0.85
 * - Δt = 365 days → ≈ 0.83
 *
 * Timestamps in the future count as Δt = 0. If the elapsed time cannot be
 * computed (NaN, e.g. a missing `createdTime`) the factor is 1.0 rather than
 * NaN, so one bad timestamp cannot turn a score into NaN.
 *
 * @param {number} createdTime - creation timestamp (ms)
 * @param {number} [now] - current timestamp (ms)
 * @returns {number} decay factor in [0.8, 1.0]
 */
export function timeDecayFactor(createdTime, now = Date.now()) {
  const elapsedMs = now - createdTime;
  const deltaDays = Number.isNaN(elapsedMs)
    ? 0
    : Math.max(0, elapsedMs / (1000 * 60 * 60 * 24));
  return 0.8 + 0.2 / (1 + Math.log(1 + deltaDays));
}
|
||||
|
||||
/**
 * Hybrid scoring formula:
 *   FinalScore =
 *     (Graph×α + Vector×β + Lexical×δ + Importance×γ) / (α+β+γ+δ) × TimeDecay
 *
 * Default weights: α=0.6, β=0.3, γ=0.1, δ=0. Every component is normalised
 * into [0, 1] first (NaN counts as 0). Negative weights are treated as 0 in
 * both numerator and denominator, so the score can never be negative. A
 * non-finite weight falls back to its default.
 *
 * @param {object} params
 * @param {number} [params.graphScore=0] - diffusion energy; divided by 2 and clamped to [0, 1]
 * @param {number} [params.vectorScore=0] - vector similarity, clamped to [0, 1]
 * @param {number} [params.lexicalScore=0] - lexical match score, clamped to [0, 1]
 * @param {number} [params.importance=5] - node importance; divided by 10 and clamped to [0, 1]
 * @param {number} [params.createdTime] - creation timestamp (ms), defaults to now
 * @param {object} [weights] - graphWeight / vectorWeight / importanceWeight / lexicalWeight
 * @returns {number} final score
 */
export function hybridScore(
  {
    graphScore = 0,
    vectorScore = 0,
    lexicalScore = 0,
    importance = 5,
    createdTime = Date.now(),
  } = {},
  weights = {},
) {
  const toUnit = (value) => {
    const numeric = Number(value);
    return Number.isNaN(numeric) ? 0 : Math.max(0, Math.min(1, numeric));
  };
  const toWeight = (value, fallback) => {
    const numeric = Number(value ?? fallback);
    return Math.max(0, Number.isFinite(numeric) ? numeric : fallback);
  };

  const source = weights ?? {};
  const alpha = toWeight(source.graphWeight, 0.6);
  const beta = toWeight(source.vectorWeight, 0.3);
  const gamma = toWeight(source.importanceWeight, 0.1);
  const delta = toWeight(source.lexicalWeight, 0);

  // Normalise every component into [0, 1].
  const normGraph = toUnit(graphScore / 2.0); // energy range [-2, 2] → [0, 1]
  const normVec = toUnit(vectorScore);
  const normLexical = toUnit(lexicalScore);
  const normImportance = toUnit(importance / 10.0);

  // Floor the denominator so all-zero weights give 0 instead of dividing by 0.
  const totalWeight = Math.max(1e-6, alpha + beta + gamma + delta);

  const baseScore =
    (normGraph * alpha +
      normVec * beta +
      normLexical * delta +
      normImportance * gamma) /
    totalWeight;

  return baseScore * timeDecayFactor(createdTime);
}
|
||||
|
||||
/**
 * Edge-weight decay: edges that were not activated recently lose strength,
 * activated ones gain a little.
 * - activated edge: strength += decayRate × 0.5, capped at 1.0
 * - other edge:     strength −= decayRate, floored at 0.1 (never reaches 0)
 *
 * Edges are mutated in place. Entries that are null, or whose strength does
 * not coerce to a finite number, are left untouched instead of being
 * overwritten with NaN.
 *
 * @param {object[]} edges
 * @param {Set<string>|Iterable<string>} [activatedEdgeIds] - ids of edges that
 *   were recently activated; any iterable of ids is accepted
 * @param {number} [decayRate=0.02] - amount of decay per call
 */
export function decayEdgeWeights(edges, activatedEdgeIds = new Set(), decayRate = 0.02) {
  if (edges == null) return;

  let activated;
  if (typeof activatedEdgeIds?.has === "function") {
    activated = activatedEdgeIds;
  } else if (typeof activatedEdgeIds?.[Symbol.iterator] === "function") {
    activated = new Set(activatedEdgeIds);
  } else {
    activated = new Set();
  }

  for (const edge of edges) {
    if (!edge) continue;
    const strength = Number(edge.strength);
    if (!Number.isFinite(strength)) continue;

    edge.strength = activated.has(edge.id)
      ? Math.min(1.0, strength + decayRate * 0.5) // slight reinforcement
      : Math.max(0.1, strength - decayRate); // slight decay
  }
}
|
||||
|
||||
/**
 * Apply access reinforcement to every node in a list.
 * A missing list and null/undefined entries are skipped.
 *
 * @param {object[]} nodes - the recalled nodes
 */
export function reinforceAccessBatch(nodes) {
  for (const node of nodes ?? []) {
    if (node) reinforceAccess(node);
  }
}
|
||||
240
retrieval/injector.js
Normal file
240
retrieval/injector.js
Normal file
@@ -0,0 +1,240 @@
|
||||
// ST-BME: Prompt 注入模块
|
||||
// 将检索结果格式化为表格注入到 LLM 上下文中
|
||||
|
||||
import { getSchemaType } from "../graph/schema.js";
|
||||
import { normalizeMemoryScope } from "../graph/memory-scope.js";
|
||||
|
||||
/** Node type → recall bucket, used when the retriever supplied no grouping. */
const RECALL_BUCKET_BY_TYPE = new Map([
  ["character", "state"],
  ["location", "state"],
  ["event", "episodic"],
  ["thread", "episodic"],
  ["reflection", "reflective"],
  ["synopsis", "reflective"],
  ["rule", "rule"],
]);

/** Split recalled nodes into the five display buckets by node type. */
function bucketRecallNodesByType(recallNodes) {
  const buckets = { state: [], episodic: [], reflective: [], rule: [], other: [] };
  for (const node of recallNodes) {
    const key = RECALL_BUCKET_BY_TYPE.get(node?.type) ?? "other";
    buckets[key].push(node);
  }
  return buckets;
}

/**
 * Convert a retrieval result into the text that is injected into the prompt.
 *
 * When `scopeBuckets` yields any output, only the scope sections are returned.
 * Otherwise the result falls back to a "[Memory - Core]" section followed by a
 * "[Memory - Recalled]" section split into buckets. A node id is rendered at
 * most once across the whole output.
 *
 * A missing result, or missing/non-array `coreNodes` / `recallNodes`, is
 * treated as empty rather than throwing.
 *
 * @param {object} retrievalResult - return value of retriever.retrieve()
 * @param {object[]} schema - node type schema
 * @returns {string} injection text ("" when there is nothing to inject)
 */
export function formatInjection(retrievalResult, schema) {
  const { coreNodes, recallNodes, groupedRecallNodes, scopeBuckets } =
    retrievalResult ?? {};
  const core = Array.isArray(coreNodes) ? coreNodes : [];
  const recall = Array.isArray(recallNodes) ? recallNodes : [];
  const parts = [];
  const appended = new Set();

  if (scopeBuckets && typeof scopeBuckets === "object") {
    const sections = [
      ["[Memory - Character POV]", scopeBuckets.characterPov],
      [
        "[Memory - User POV / Not Character Facts]",
        scopeBuckets.userPov,
        "这些是用户/玩家侧主观记忆，不等于角色已知事实；只能作为关系、承诺、情绪和长期互动背景参考。",
      ],
      ["[Memory - Objective / Current Region]", scopeBuckets.objectiveCurrentRegion],
      ["[Memory - Objective / Global]", scopeBuckets.objectiveGlobal],
    ];
    for (const [title, nodes, note] of sections) {
      appendScopeSection(parts, title, nodes, schema, appended, note);
    }

    if (parts.length > 0) {
      return parts.join("\n");
    }
  }

  // ---------- Core (always-on) section ----------
  if (core.length > 0) {
    parts.push("[Memory - Core]");

    for (const [typeId, nodes] of groupByType(core)) {
      const typeDef = getSchemaType(schema, typeId);
      if (!typeDef) continue;

      const table = formatTable(nodes, typeDef, appended);
      if (table) parts.push(table);
    }
  }

  // ---------- Recalled section ----------
  if (recall.length > 0) {
    parts.push("");
    parts.push("[Memory - Recalled]");

    const buckets = groupedRecallNodes || bucketRecallNodesByType(recall);

    appendBucket(parts, "当前状态记忆", buckets.state, schema, appended);
    appendBucket(parts, "情景事件记忆", buckets.episodic, schema, appended);
    appendBucket(parts, "反思与长期锚点", buckets.reflective, schema, appended);
    appendBucket(parts, "规则与约束", buckets.rule, schema, appended);
    appendBucket(parts, "其他关联记忆", buckets.other, schema, appended);
  }

  return parts.join("\n");
}
|
||||
|
||||
/**
 * Append one scope section (title, optional note, one table per node type) to
 * `parts`.
 *
 * The tables are rendered first and the section is written only if at least
 * one table has content. A section whose nodes were all rendered earlier, or
 * whose types are unknown to the schema, therefore leaves no dangling title
 * or note behind. A blank separator line is added when `parts` already has
 * content.
 *
 * @param {string[]} parts - output lines, mutated in place
 * @param {string} title - section heading
 * @param {object[]} nodes - nodes belonging to this scope
 * @param {object[]} schema - node type schema
 * @param {Set<string>} appended - ids already rendered; updated by formatTable
 * @param {string} [note] - optional line placed right under the title
 */
function appendScopeSection(parts, title, nodes, schema, appended, note = "") {
  if (!Array.isArray(nodes) || nodes.length === 0) return;

  const tables = [];
  for (const [typeId, groupedNodes] of groupByType(nodes)) {
    const typeDef = getSchemaType(schema, typeId);
    if (!typeDef) continue;
    const table = formatTable(groupedNodes, typeDef, appended);
    if (table) tables.push(table);
  }
  if (tables.length === 0) return;

  if (parts.length > 0) {
    parts.push("");
  }
  parts.push(title);
  if (note) {
    parts.push(note);
  }
  parts.push(...tables);
}
|
||||
|
||||
/**
 * Group nodes by their `type`, keeping first-seen order of both the types and
 * the nodes inside each group. Null/undefined entries are skipped.
 *
 * @param {object[]} nodes
 * @returns {Map<*, object[]>} type → nodes of that type
 */
function groupByType(nodes) {
  const groups = new Map();
  for (const node of nodes ?? []) {
    if (node == null) continue;
    const bucket = groups.get(node.type);
    if (bucket) {
      bucket.push(node);
    } else {
      groups.set(node.type, [node]);
    }
  }
  return groups;
}
|
||||
|
||||
/**
 * Append one recall bucket ("## title" followed by one table per node type)
 * to `parts`.
 *
 * The heading is written only when at least one table has content, so a
 * bucket whose nodes were all rendered earlier, or whose types are unknown to
 * the schema, leaves no empty heading behind.
 *
 * @param {string[]} parts - output lines, mutated in place
 * @param {string} title - bucket heading (without the "## " prefix)
 * @param {object[]} nodes - nodes in this bucket
 * @param {object[]} schema - node type schema
 * @param {Set<string>} appended - ids already rendered; updated by formatTable
 */
function appendBucket(parts, title, nodes, schema, appended) {
  if (!nodes || nodes.length === 0) return;

  const tables = [];
  for (const [typeId, groupedNodes] of groupByType(nodes)) {
    const typeDef = getSchemaType(schema, typeId);
    if (!typeDef) continue;

    const table = formatTable(groupedNodes, typeDef, appended);
    if (table) tables.push(table);
  }
  if (tables.length === 0) return;

  parts.push(`## ${title}`, ...tables);
}
|
||||
|
||||
/**
 * Render nodes of one type as a Markdown table, preceded by a
 * "<tableName>:" line.
 *
 * - Nodes without an id, and nodes whose id is already in `appended`, are
 *   skipped; every rendered candidate id is added to `appended`.
 * - Only schema columns that hold data for at least one node are shown.
 *   Derived columns (see buildDerivedColumns) come first.
 * - Cells are flattened to one line (CR, LF and CRLF become a space),
 *   truncated to 200 characters, and only then are pipes escaped. Truncating
 *   first means an escape sequence or surrogate pair is never cut in half.
 *
 * @param {object[]} nodes
 * @param {object} typeDef - schema entry; `columns` and `tableName` are read
 * @param {Set<string>} [appended] - ids already rendered, mutated in place
 * @returns {string} the table, or "" when there is nothing to show
 */
function formatTable(nodes, typeDef, appended = new Set()) {
  if (!Array.isArray(nodes) || nodes.length === 0) return "";

  const uniqueNodes = nodes.filter((node) => {
    if (!node?.id || appended.has(node.id)) return false;
    appended.add(node.id);
    return true;
  });
  if (uniqueNodes.length === 0) return "";

  const MAX_CELL_CHARS = 200;
  const hasValue = (node, name) => {
    const value = node.fields?.[name];
    return value != null && value !== "";
  };
  const toCell = (value) => {
    const flat = String(value ?? "").replace(/\r\n|\r|\n/g, " ");
    const clipped =
      flat.length <= MAX_CELL_CHARS
        ? flat
        : Array.from(flat).slice(0, MAX_CELL_CHARS).join("");
    return clipped.replace(/\|/g, "\\|");
  };

  // Columns to display: only those with real data.
  const schemaCols = Array.isArray(typeDef?.columns) ? typeDef.columns : [];
  const activeCols = schemaCols.filter((col) =>
    uniqueNodes.some((node) => hasValue(node, col.name)),
  );
  const allCols = [...buildDerivedColumns(uniqueNodes, typeDef), ...activeCols];
  if (allCols.length === 0) return "";

  const header = `| ${allCols.map((col) => col.name).join(" | ")} |`;
  const separator = `| ${allCols.map(() => "---").join(" | ")} |`;

  const rows = uniqueNodes.map((node) => {
    const cells = allCols.map((col) =>
      toCell(
        typeof col.getValue === "function"
          ? col.getValue(node)
          : node.fields?.[col.name],
      ),
    );
    return `| ${cells.join(" | ")} |`;
  });

  return `${typeDef.tableName}:\n${header}\n${separator}\n${rows.join("\n")}`;
}
|
||||
|
||||
/**
 * Build the extra, computed columns for a table.
 *
 * Only the "pov_memory" type has one: an "owner" column whose text is derived
 * from the node's normalized memory scope. Every other type gets no derived
 * columns.
 *
 * @param {object[]} nodes - nodes of the table (currently not inspected)
 * @param {object} typeDef - schema entry; only `id` is read
 * @returns {Array<{name: string, getValue: (node: object) => string}>}
 */
function buildDerivedColumns(nodes, typeDef) {
  if (typeDef?.id !== "pov_memory") {
    return [];
  }

  const describeOwner = (node) => {
    const scope = normalizeMemoryScope(node?.scope);
    const ownerLabel = scope.ownerName || scope.ownerId || "未命名";
    switch (scope.ownerType) {
      case "user":
        return `用户: ${ownerLabel}`;
      case "character":
        return `角色: ${ownerLabel}`;
      default:
        return `POV: ${ownerLabel}`;
    }
  };

  return [{ name: "owner", getValue: describeOwner }];
}
|
||||
|
||||
/**
 * Rough token estimate for an injection text.
 * Heuristic: characters in U+4E00–U+9FFF count as 1 token per 2 characters,
 * everything else as 1 token per 4 characters; the sum is rounded up.
 *
 * Non-string input is converted with String() instead of throwing;
 * null/undefined/empty input yields 0.
 *
 * @param {string} injectionText
 * @returns {number} estimated token count
 */
export function estimateTokens(injectionText) {
  if (!injectionText) return 0;

  const text =
    typeof injectionText === "string" ? injectionText : String(injectionText);
  const cnChars = (text.match(/[\u4e00-\u9fff]/g) || []).length;
  const otherChars = text.length - cnChars;
  return Math.ceil(cnChars / 2 + otherChars / 4);
}
|
||||
566
retrieval/recall-controller.js
Normal file
566
retrieval/recall-controller.js
Normal file
@@ -0,0 +1,566 @@
|
||||
// ST-BME: 召回输入解析与注入控制器(纯函数)
|
||||
|
||||
import { debugLog } from "../runtime/debug-logging.js";
|
||||
|
||||
/**
 * Build the list of recent context lines handed to recall.
 *
 * Walks the chat from newest to oldest, skips system messages, and formats up
 * to `limit` messages with `runtime.formatRecallContextLine`. The returned
 * list is in chronological order. If a synthetic user message is supplied and
 * is not already the last line, it is appended as "[user]: <text>" and the
 * oldest lines are dropped so the list never exceeds `limit`.
 *
 * @param {object[]} chat - chat messages
 * @param {number} limit - maximum number of lines; anything that is not a
 *   positive number (including NaN) yields an empty list
 * @param {string} [syntheticUserMessage] - user text not yet present in chat
 * @param {object} runtime - must provide `formatRecallContextLine` and
 *   `normalizeRecallInputText`
 * @returns {string[]} recent context lines, oldest first
 */
export function buildRecallRecentMessagesController(
  chat,
  limit,
  syntheticUserMessage = "",
  runtime,
) {
  // `!(limit > 0)` also rejects NaN, which `limit <= 0` lets through.
  if (!Array.isArray(chat) || !(limit > 0)) return [];

  // Collect newest-first, then reverse once instead of unshifting every time.
  const newestFirst = [];
  for (
    let index = chat.length - 1;
    index >= 0 && newestFirst.length < limit;
    index--
  ) {
    const message = chat[index];
    if (message?.is_system) continue;
    newestFirst.push(runtime.formatRecallContextLine(message));
  }
  const recentMessages = newestFirst.reverse();

  const normalizedSynthetic =
    runtime.normalizeRecallInputText(syntheticUserMessage);
  if (!normalizedSynthetic) return recentMessages;

  const syntheticLine = `[user]: ${normalizedSynthetic}`;
  if (recentMessages[recentMessages.length - 1] !== syntheticLine) {
    recentMessages.push(syntheticLine);
    while (recentMessages.length > limit) {
      recentMessages.shift();
    }
  }

  return recentMessages;
}
|
||||
|
||||
/**
 * Map a recall input source id to its display label.
 *
 * @param {string} source - one of "send-intent", "chat-tail-user",
 *   "message-sent", "chat-last-user"
 * @returns {string} the label, or "未知" for any other value
 */
export function getRecallUserMessageSourceLabelController(source) {
  switch (source) {
    case "send-intent":
      return "发送意图";
    case "chat-tail-user":
      return "当前用户楼层";
    case "message-sent":
      return "已发送用户楼层";
    case "chat-last-user":
      return "历史最后用户楼层";
    default:
      return "未知";
  }
}
|
||||
|
||||
/**
 * Decide which user text recall should run on and describe where it came from.
 *
 * If `override` carries a non-empty `userMessage` / `overrideUserMessage`,
 * that text wins and the source metadata is taken from the override
 * (`locked*` fields first, then the plain and `override*` variants).
 * Otherwise the first non-empty candidate is used, in this order:
 *   1. fresh pending send intent        → "send-intent"
 *   2. last non-system message, if user → "chat-tail-user"
 *   3. fresh last-sent user message     → "message-sent"
 *   4. latest user message in the chat  → "chat-last-user"
 *
 * A synthetic user line is passed to `runtime.buildRecallRecentMessages` when
 * the chosen text may not be in the chat yet: always for a send intent, and
 * for a sent message only if it differs from the latest user message.
 *
 * NOTE(review): when `override` has no text, its `generationType` and
 * `targetUserMessageIndex` are ignored and "normal" / null are reported —
 * confirm this is intended.
 *
 * @param {object[]} chat - chat messages
 * @param {number} recentContextMessageLimit - forwarded to buildRecallRecentMessages
 * @param {object|null} [override] - explicit input and source metadata
 * @param {object} runtime - host callbacks and the two input records
 * @returns {{userMessage: string, generationType: string,
 *   targetUserMessageIndex: (number|null), source: string, sourceLabel: string,
 *   reason: string, sourceCandidates: object[], recentMessages: *}}
 */
export function resolveRecallInputController(
  chat,
  recentContextMessageLimit,
  override = null,
  runtime,
) {
  const overrideText = runtime.normalizeRecallInputText(
    override?.userMessage || override?.overrideUserMessage || "",
  );

  if (overrideText) {
    const includeSynthetic = override?.includeSyntheticUserMessage !== false;
    return {
      userMessage: overrideText,
      generationType: String(override?.generationType || "normal"),
      targetUserMessageIndex: Number.isFinite(override?.targetUserMessageIndex)
        ? override.targetUserMessageIndex
        : null,
      source: String(
        override?.lockedSource ||
          override?.source ||
          override?.overrideSource ||
          "override",
      ),
      sourceLabel: String(
        override?.lockedSourceLabel ||
          override?.sourceLabel ||
          override?.overrideSourceLabel ||
          "发送前拦截",
      ),
      reason: String(
        override?.lockedReason ||
          override?.reason ||
          override?.overrideReason ||
          "override-bound",
      ),
      // Shallow-copy candidates so callers cannot mutate the override's list.
      sourceCandidates: Array.isArray(override?.sourceCandidates)
        ? override.sourceCandidates.map((candidate) => ({ ...candidate }))
        : [],
      recentMessages: runtime.buildRecallRecentMessages(
        chat,
        recentContextMessageLimit,
        includeSynthetic ? overrideText : "",
      ),
    };
  }

  const latestUserMessage = runtime.getLatestUserChatMessage(chat);
  const latestUserText = runtime.normalizeRecallInputText(
    latestUserMessage?.mes || "",
  );
  const lastNonSystemMessage = runtime.getLastNonSystemChatMessage(chat);
  const tailUserText = lastNonSystemMessage?.is_user
    ? runtime.normalizeRecallInputText(lastNonSystemMessage?.mes || "")
    : "";
  const pendingIntentText = runtime.isFreshRecallInputRecord(
    runtime.pendingRecallSendIntent,
  )
    ? runtime.pendingRecallSendIntent.text
    : "";
  const sentUserText = runtime.isFreshRecallInputRecord(
    runtime.lastRecallSentUserMessage,
  )
    ? runtime.lastRecallSentUserMessage.text
    : "";

  let userMessage = "";
  let source = "";
  let syntheticUserMessage = "";

  if (pendingIntentText) {
    userMessage = pendingIntentText;
    source = "send-intent";
    syntheticUserMessage = pendingIntentText;
  } else if (tailUserText) {
    userMessage = tailUserText;
    source = "chat-tail-user";
  } else if (sentUserText) {
    userMessage = sentUserText;
    source = "message-sent";
    // Only synthesise a line when the chat does not already end with it.
    if (!latestUserText || latestUserText !== sentUserText) {
      syntheticUserMessage = sentUserText;
    }
  } else if (latestUserText) {
    userMessage = latestUserText;
    source = "chat-last-user";
  }

  return {
    userMessage,
    generationType: "normal",
    targetUserMessageIndex: null,
    source,
    sourceLabel: runtime.getRecallUserMessageSourceLabel(source),
    reason: userMessage ? `${source || "unknown"}-selected` : "no-recall-input",
    sourceCandidates: [],
    recentMessages: runtime.buildRecallRecentMessages(
      chat,
      recentContextMessageLimit,
      syntheticUserMessage,
    ),
  };
}
|
||||
|
||||
/**
 * Turn a retrieval result into injection text, deliver it, and record
 * everything that happened.
 *
 * Steps, in order:
 *   1. format the injection text via the runtime and store it;
 *   2. if non-empty, log the token estimate and (optionally) persist a record;
 *   3. when `recallInput.deliveryMode` is "immediate" (the default), apply the
 *      prompt injection; otherwise leave it pending;
 *   4. record a "recall" injection snapshot, update the last-recall
 *      bookkeeping and save the graph;
 *   5. publish a status line and, for an LLM fallback, show a warning toast at
 *      most once every 15 seconds.
 *
 * `result.stats` and `result.selectedNodeIds` are optional: missing values are
 * read as `{}` / `[]` instead of throwing.
 *
 * @param {object} settings - `recallEnableLLM` is read here; the whole object
 *   is forwarded to `runtime.applyModuleInjectionPrompt`
 * @param {object} recallInput - resolved recall input plus hookName/deliveryMode
 * @param {string[]} recentMessages - context lines used for this recall
 * @param {object} result - retrieval result
 * @param {object} runtime - host callbacks
 * @returns {{injectionText: string, retrievalMeta: object, llmMeta: object,
 *   transport: object, deliveryMode: string}}
 */
export function applyRecallInjectionController(
  settings,
  recallInput,
  recentMessages,
  result,
  runtime,
) {
  const injectionText = String(
    runtime.formatInjection(result, runtime.getSchema()) ?? "",
  ).trim();
  runtime.setLastInjectionContent(injectionText);

  const stats = result?.stats || {};
  const retrievalMeta = result?.meta?.retrieval || {};
  const llmMeta = retrievalMeta.llm || {
    status: settings.recallEnableLLM ? "unknown" : "disabled",
    reason: settings.recallEnableLLM ? "未提供 LLM 状态" : "LLM 精排已关闭",
    candidatePool: 0,
  };
  const deliveryMode =
    String(recallInput?.deliveryMode || "immediate").trim() || "immediate";
  const isImmediate = deliveryMode === "immediate";

  if (injectionText) {
    const tokens = runtime.estimateTokens(injectionText);
    debugLog(
      `[ST-BME] 注入 ${tokens} 估算 tokens, Core=${stats.coreCount}, Recall=${stats.recallCount}`,
    );
    runtime.persistRecallInjectionRecord?.({
      recallInput,
      result,
      injectionText,
      tokenEstimate: tokens,
    });
  }

  // Deferred delivery keeps this placeholder; immediate delivery replaces it
  // with whatever the runtime reports (if anything).
  let injectionTransport = {
    applied: false,
    source: "deferred",
    mode: "deferred",
  };
  if (isImmediate) {
    injectionTransport =
      runtime.applyModuleInjectionPrompt(injectionText, settings) ||
      injectionTransport;
  }

  runtime.recordInjectionSnapshot("recall", {
    taskType: "recall",
    source: recallInput?.source,
    sourceLabel: recallInput?.sourceLabel,
    reason: recallInput?.reason || "",
    sourceCandidates: Array.isArray(recallInput?.sourceCandidates)
      ? recallInput.sourceCandidates.map((candidate) => ({ ...candidate }))
      : [],
    hookName: recallInput?.hookName,
    recentMessages,
    selectedNodeIds: result?.selectedNodeIds || [],
    retrievalMeta,
    llmMeta,
    stats,
    injectionText,
    deliveryMode,
    applicationMode: isImmediate ? "injection" : "pending-rewrite",
    rewrite: {
      applied: false,
      path: "",
      field: "",
      reason: isImmediate
        ? "immediate-injection"
        : "awaiting-generation-payload-rewrite",
    },
    transport: injectionTransport,
  });

  runtime.setCurrentGraphLastRecallResult(result?.selectedNodeIds);
  runtime.updateLastRecalledItems(result?.selectedNodeIds || []);
  runtime.saveGraphToChat({ reason: "recall-result-updated" });

  let llmLabel = "召回完成";
  if (llmMeta.status === "llm") {
    llmLabel = "LLM 精排完成";
  } else if (llmMeta.status === "fallback") {
    llmLabel = "LLM 回退评分";
  } else if (llmMeta.status === "disabled") {
    llmLabel = "仅评分排序";
  }

  const hookLabel = runtime.getRecallHookLabel(recallInput?.hookName);
  const statusDetail = [
    hookLabel,
    recallInput?.sourceLabel,
    isImmediate ? "即时注入" : "等待本轮 rewrite",
    `ctx ${recentMessages.length}`,
    `vector ${retrievalMeta.vectorHits ?? 0}`,
    retrievalMeta.vectorMergedHits
      ? `merged ${retrievalMeta.vectorMergedHits}`
      : "",
    `diffusion ${retrievalMeta.diffusionHits ?? 0}`,
    retrievalMeta.candidatePoolAfterDpp
      ? `dpp ${retrievalMeta.candidatePoolAfterDpp}`
      : "",
    `llm pool ${llmMeta.candidatePool ?? 0}`,
    `recall ${stats.recallCount ?? 0}`,
  ]
    .filter(Boolean)
    .join(" · ");

  runtime.setLastRecallStatus(
    llmLabel,
    statusDetail,
    llmMeta.status === "fallback" ? "warning" : "success",
    {
      syncRuntime: true,
      toastKind: "",
    },
  );

  if (llmMeta.status === "fallback") {
    const now = Date.now();
    // Throttle the toast so repeated fallbacks do not spam the user.
    if (now - runtime.getLastRecallFallbackNoticeAt() > 15000) {
      runtime.setLastRecallFallbackNoticeAt(now);
      runtime.toastr.warning(
        llmMeta.reason || "LLM 精排未成功，已改用评分排序并继续注入记忆",
        "ST-BME 召回提示",
        { timeOut: 4500 },
      );
    }
  }

  return {
    injectionText,
    retrievalMeta,
    llmMeta,
    transport: injectionTransport,
    deliveryMode,
  };
}
|
||||
|
||||
export async function runRecallController(runtime, options = {}) {
|
||||
if (runtime.getIsRecalling()) {
|
||||
runtime.abortRecallStageWithReason("旧召回已取消,正在启动新的召回");
|
||||
const settle = await runtime.waitForActiveRecallToSettle();
|
||||
if (!settle.settled && runtime.getIsRecalling()) {
|
||||
runtime.setLastRecallStatus(
|
||||
"召回忙",
|
||||
"上一轮召回仍在清理,请稍后重试",
|
||||
"warning",
|
||||
{
|
||||
syncRuntime: true,
|
||||
},
|
||||
);
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "上一轮召回仍在清理",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const hasGraph = !!runtime.getCurrentGraph();
|
||||
if (!hasGraph) {
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "当前无图谱",
|
||||
});
|
||||
}
|
||||
|
||||
const settings = runtime.getSettings();
|
||||
if (!settings.enabled || !settings.recallEnabled) {
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "召回功能未启用",
|
||||
});
|
||||
}
|
||||
const isReadableForRecall =
|
||||
typeof runtime.isGraphReadableForRecall === "function"
|
||||
? runtime.isGraphReadableForRecall()
|
||||
: runtime.isGraphReadable();
|
||||
if (!isReadableForRecall) {
|
||||
const reason = runtime.getGraphMutationBlockReason("召回");
|
||||
runtime.setLastRecallStatus("等待图谱加载", reason, "warning", {
|
||||
syncRuntime: true,
|
||||
});
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason,
|
||||
});
|
||||
}
|
||||
if (runtime.isGraphMetadataWriteAllowed()) {
|
||||
if (!(await runtime.recoverHistoryIfNeeded("pre-recall"))) {
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "历史恢复未就绪",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const context = runtime.getContext();
|
||||
const chat = context.chat;
|
||||
if (!chat || chat.length === 0) {
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "当前聊天为空",
|
||||
});
|
||||
}
|
||||
|
||||
const runId = runtime.nextRecallRunSequence();
|
||||
let recallPromise = null;
|
||||
recallPromise = (async () => {
|
||||
runtime.setIsRecalling(true);
|
||||
const recallController = runtime.beginStageAbortController("recall");
|
||||
const recallSignal = recallController.signal;
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
recallController.abort(
|
||||
options.signal.reason || runtime.createAbortError("宿主已终止生成"),
|
||||
);
|
||||
} else {
|
||||
options.signal.addEventListener(
|
||||
"abort",
|
||||
() =>
|
||||
recallController.abort(
|
||||
options.signal.reason ||
|
||||
runtime.createAbortError("宿主已终止生成"),
|
||||
),
|
||||
{ once: true },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await runtime.ensureVectorReadyIfNeeded("pre-recall", recallSignal);
|
||||
const recentContextMessageLimit = runtime.clampInt(
|
||||
settings.recallLlmContextMessages,
|
||||
4,
|
||||
0,
|
||||
20,
|
||||
);
|
||||
const recallInput = runtime.resolveRecallInput(
|
||||
chat,
|
||||
recentContextMessageLimit,
|
||||
options,
|
||||
);
|
||||
const userMessage = recallInput.userMessage;
|
||||
const recentMessages = recallInput.recentMessages;
|
||||
|
||||
if (!userMessage) {
|
||||
return runtime.createRecallRunResult("skipped", {
|
||||
reason: "当前没有可用于召回的用户输入",
|
||||
});
|
||||
}
|
||||
|
||||
recallInput.hookName = options.hookName || "";
|
||||
recallInput.deliveryMode =
|
||||
String(options.deliveryMode || "immediate").trim() || "immediate";
|
||||
|
||||
debugLog("[ST-BME] 开始召回", {
|
||||
source: recallInput.source,
|
||||
sourceLabel: recallInput.sourceLabel,
|
||||
hookName: recallInput.hookName,
|
||||
userMessageLength: userMessage.length,
|
||||
recentMessages: recentMessages.length,
|
||||
runId,
|
||||
});
|
||||
runtime.setLastRecallStatus(
|
||||
"召回中",
|
||||
[
|
||||
runtime.getRecallHookLabel(recallInput.hookName),
|
||||
`来源 ${recallInput.sourceLabel}`,
|
||||
`上下文 ${recentMessages.length} 条`,
|
||||
`当前用户消息长度 ${userMessage.length}`,
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(" · "),
|
||||
"running",
|
||||
{ syncRuntime: true },
|
||||
);
|
||||
if (recallInput.source === "send-intent") {
|
||||
runtime.setPendingRecallSendIntent(runtime.createRecallInputRecord());
|
||||
}
|
||||
|
||||
const cachedRecallPayload =
|
||||
options.cachedRecallPayload &&
|
||||
typeof options.cachedRecallPayload === "object"
|
||||
? options.cachedRecallPayload
|
||||
: null;
|
||||
if (cachedRecallPayload?.result) {
|
||||
// Cached planner handoff is already the authoritative source for this
|
||||
// generation, so any leftover send-intent snapshot must be cleared to
|
||||
// avoid leaking stale input into a later fallback recall path.
|
||||
runtime.setPendingRecallSendIntent?.(runtime.createRecallInputRecord());
|
||||
const cachedResult = cachedRecallPayload.result;
|
||||
const recentMessages = Array.isArray(cachedRecallPayload.recentMessages)
|
||||
? cachedRecallPayload.recentMessages.map((item) => String(item || ""))
|
||||
: recallInput.recentMessages;
|
||||
const applied = runtime.applyRecallInjection(
|
||||
settings,
|
||||
recallInput,
|
||||
recentMessages,
|
||||
cachedResult,
|
||||
);
|
||||
runtime.consumePlannerRecallHandoff?.(cachedRecallPayload.chatId, {
|
||||
handoffId: cachedRecallPayload.handoffId,
|
||||
});
|
||||
return runtime.createRecallRunResult("completed", {
|
||||
reason: cachedRecallPayload.reason || "planner-handoff-reused",
|
||||
selectedNodeIds: cachedResult.selectedNodeIds || [],
|
||||
injectionText: applied?.injectionText || "",
|
||||
retrievalMeta: applied?.retrievalMeta || {},
|
||||
llmMeta: applied?.llmMeta || {},
|
||||
transport: applied?.transport || {
|
||||
applied: false,
|
||||
source: "none",
|
||||
mode: "none",
|
||||
},
|
||||
deliveryMode:
|
||||
applied?.deliveryMode ||
|
||||
String(recallInput?.deliveryMode || "immediate").trim() ||
|
||||
"immediate",
|
||||
source: recallInput?.source || cachedRecallPayload.source || "",
|
||||
sourceLabel:
|
||||
recallInput?.sourceLabel || cachedRecallPayload.sourceLabel || "",
|
||||
hookName: recallInput?.hookName || "",
|
||||
sourceCandidates: Array.isArray(recallInput?.sourceCandidates)
|
||||
? recallInput.sourceCandidates.map((candidate) => ({
|
||||
...candidate,
|
||||
}))
|
||||
: [],
|
||||
stats: cachedResult?.stats || {},
|
||||
});
|
||||
}
|
||||
|
||||
const result = await runtime.retrieve({
|
||||
graph: runtime.getCurrentGraph(),
|
||||
userMessage,
|
||||
recentMessages,
|
||||
embeddingConfig: runtime.getEmbeddingConfig(),
|
||||
schema: runtime.getSchema(),
|
||||
signal: recallSignal,
|
||||
settings,
|
||||
onStreamProgress: ({ previewText, receivedChars }) => {
|
||||
const preview =
|
||||
previewText?.length > 60
|
||||
? "…" + previewText.slice(-60)
|
||||
: previewText || "";
|
||||
runtime.setLastRecallStatus(
|
||||
"AI 生成中",
|
||||
`${preview} [${receivedChars}字]`,
|
||||
"running",
|
||||
{ syncRuntime: true, noticeMarquee: true },
|
||||
);
|
||||
},
|
||||
options: runtime.buildRecallRetrieveOptions(settings, context),
|
||||
});
|
||||
|
||||
const applied = runtime.applyRecallInjection(
|
||||
settings,
|
||||
recallInput,
|
||||
recentMessages,
|
||||
result,
|
||||
);
|
||||
return runtime.createRecallRunResult("completed", {
|
||||
reason: "召回完成",
|
||||
selectedNodeIds: result.selectedNodeIds || [],
|
||||
injectionText: applied?.injectionText || "",
|
||||
retrievalMeta: applied?.retrievalMeta || {},
|
||||
llmMeta: applied?.llmMeta || {},
|
||||
transport: applied?.transport || {
|
||||
applied: false,
|
||||
source: "none",
|
||||
mode: "none",
|
||||
},
|
||||
deliveryMode:
|
||||
applied?.deliveryMode ||
|
||||
String(recallInput?.deliveryMode || "immediate").trim() ||
|
||||
"immediate",
|
||||
source: recallInput?.source || "",
|
||||
sourceLabel: recallInput?.sourceLabel || "",
|
||||
hookName: recallInput?.hookName || "",
|
||||
sourceCandidates: Array.isArray(recallInput?.sourceCandidates)
|
||||
? recallInput.sourceCandidates.map((candidate) => ({ ...candidate }))
|
||||
: [],
|
||||
stats: result?.stats || {},
|
||||
});
|
||||
} catch (e) {
|
||||
if (runtime.isAbortError(e)) {
|
||||
runtime.setLastRecallStatus(
|
||||
"召回已终止",
|
||||
e?.message || "已手动终止当前召回",
|
||||
"warning",
|
||||
{
|
||||
syncRuntime: true,
|
||||
},
|
||||
);
|
||||
return runtime.createRecallRunResult("aborted", {
|
||||
reason: e?.message || "召回已终止",
|
||||
});
|
||||
}
|
||||
runtime.console.error("[ST-BME] 召回失败:", e);
|
||||
const message = e?.message || String(e);
|
||||
runtime.setLastRecallStatus("召回失败", message, "error", {
|
||||
syncRuntime: true,
|
||||
toastKind: "",
|
||||
});
|
||||
runtime.toastr.error(`召回失败: ${message}`);
|
||||
return runtime.createRecallRunResult("failed", {
|
||||
reason: message,
|
||||
});
|
||||
} finally {
|
||||
runtime.finishStageAbortController("recall", recallController);
|
||||
runtime.setIsRecalling(false);
|
||||
if (runtime.getActiveRecallPromise() === recallPromise) {
|
||||
runtime.setActiveRecallPromise(null);
|
||||
}
|
||||
runtime.refreshPanelLiveState();
|
||||
}
|
||||
})();
|
||||
|
||||
runtime.setActiveRecallPromise(recallPromise);
|
||||
return await recallPromise;
|
||||
}
|
||||
186
retrieval/recall-persistence.js
Normal file
186
retrieval/recall-persistence.js
Normal file
@@ -0,0 +1,186 @@
|
||||
// ST-BME: pure helper functions for persisted recall records
|
||||
|
||||
// Property name under `message.extra` where a persisted recall record is stored.
export const BME_RECALL_EXTRA_KEY = "bme_recall";
|
||||
// Version stamped on newly built records; also the fallback when a stored
// record carries no finite numeric `version`.
export const BME_RECALL_VERSION = 1;
|
||||
|
||||
/**
 * Pass through a non-blank string timestamp, otherwise produce "now".
 *
 * The string is returned exactly as given (no trimming, no format check).
 *
 * @param {*} value - Candidate timestamp.
 * @returns {string} `value` itself, or the current time in ISO-8601 form.
 */
function toIsoString(value) {
  const hasText = typeof value === "string" && value.trim().length > 0;
  return hasText ? value : new Date().toISOString();
}
|
||||
|
||||
/**
 * Copy an array into a fresh array of trimmed, non-empty strings.
 *
 * Falsy entries (including 0) are treated as empty and dropped; everything
 * else is stringified and trimmed.
 *
 * @param {*} value - Expected to be an array; anything else yields [].
 * @returns {string[]} New array of cleaned strings.
 */
function cloneStringArray(value) {
  if (!Array.isArray(value)) return [];

  const cleaned = [];
  for (const entry of value) {
    const text = String(entry || "").trim();
    if (text) cleaned.push(text);
  }
  return cleaned;
}
|
||||
|
||||
/**
 * Shallow-copy a non-array object.
 *
 * @param {*} value - Candidate record.
 * @returns {object|null} A shallow copy, or null when `value` is falsy,
 *   not an object, or an array. Nested values are shared with the input.
 */
function cloneRecord(value) {
  const isRecord =
    Boolean(value) && typeof value === "object" && !Array.isArray(value);
  return isRecord ? { ...value } : null;
}
|
||||
|
||||
/**
 * Read and normalize the recall record stored on a chat message.
 *
 * The record is looked up at `chat[userMessageIndex].extra[BME_RECALL_EXTRA_KEY]`.
 * Every field of the returned object is coerced to a predictable type.
 *
 * @param {Array<object>} chat - Chat message list.
 * @param {number} userMessageIndex - Index of the message to read.
 * @returns {object|null} Normalized record, or null when the inputs are
 *   invalid, no record object is stored, or its injection text is blank.
 */
export function readPersistedRecallFromUserMessage(chat, userMessageIndex) {
  const hasValidTarget =
    Array.isArray(chat) && Number.isFinite(userMessageIndex);
  if (!hasValidTarget) return null;

  const stored = cloneRecord(
    chat[userMessageIndex]?.extra?.[BME_RECALL_EXTRA_KEY],
  );
  if (!stored) return null;

  const injectionText = String(stored.injectionText || "").trim();
  if (injectionText === "") return null;

  // Number() result when finite, otherwise the supplied fallback.
  const numberOr = (raw, fallback) => {
    const parsed = Number(raw);
    return Number.isFinite(parsed) ? parsed : fallback;
  };

  return {
    version: numberOr(stored.version, BME_RECALL_VERSION),
    injectionText,
    selectedNodeIds: cloneStringArray(stored.selectedNodeIds),
    recallInput: String(stored.recallInput || ""),
    recallSource: String(stored.recallSource || ""),
    hookName: String(stored.hookName || ""),
    tokenEstimate: numberOr(stored.tokenEstimate, 0),
    createdAt: toIsoString(stored.createdAt),
    updatedAt: toIsoString(stored.updatedAt),
    generationCount: Math.max(
      0,
      Number.parseInt(stored.generationCount, 10) || 0,
    ),
    manuallyEdited: Boolean(stored.manuallyEdited),
  };
}
|
||||
|
||||
/**
 * Build a fresh recall record from a recall payload.
 *
 * Only `createdAt` is carried over from `existingRecord`; every other field
 * comes from `payload`, and `generationCount` always starts at 0.
 *
 * @param {object} [payload] - Source values (injectionText, selectedNodeIds,
 *   recallInput, recallSource, hookName, tokenEstimate, manuallyEdited, nowIso).
 * @param {object|null} [existingRecord] - Previous record, if any.
 * @returns {object} Record stamped with BME_RECALL_VERSION.
 */
export function buildPersistedRecallRecord(payload = {}, existingRecord = null) {
  const timestamp = toIsoString(payload.nowIso);
  const prior = cloneRecord(existingRecord) ?? {};
  const injectionText = String(payload.injectionText || "").trim();

  const rawTokenEstimate = Number(payload.tokenEstimate);
  const tokenEstimate = Number.isFinite(rawTokenEstimate) ? rawTokenEstimate : 0;

  return {
    version: BME_RECALL_VERSION,
    injectionText,
    selectedNodeIds: cloneStringArray(payload.selectedNodeIds),
    recallInput: String(payload.recallInput || ""),
    recallSource: String(payload.recallSource || ""),
    hookName: String(payload.hookName || ""),
    tokenEstimate,
    createdAt: toIsoString(prior.createdAt || timestamp),
    updatedAt: timestamp,
    generationCount: 0,
    manuallyEdited: Boolean(payload.manuallyEdited),
  };
}
|
||||
|
||||
/**
 * Store a recall record on a user message.
 *
 * A shallow copy of `record` is written to
 * `chat[userMessageIndex].extra[BME_RECALL_EXTRA_KEY]`; `extra` is created
 * when missing.
 *
 * @param {Array<object>} chat - Chat message list (mutated on success).
 * @param {number} userMessageIndex - Index of the target message.
 * @param {object} record - Record to store; needs non-blank `injectionText`.
 * @returns {boolean} True when written; false when the inputs are invalid,
 *   the target is not a user message, or the record has no injection text.
 */
export function writePersistedRecallToUserMessage(chat, userMessageIndex, record) {
  if (!Array.isArray(chat) || !Number.isFinite(userMessageIndex)) return false;

  const target = chat[userMessageIndex];
  if (!target?.is_user) return false;

  const snapshot = cloneRecord(record);
  const text = snapshot ? String(snapshot.injectionText || "").trim() : "";
  if (text === "") return false;

  if (!target.extra) {
    target.extra = {};
  }
  target.extra[BME_RECALL_EXTRA_KEY] = snapshot;
  return true;
}
|
||||
|
||||
/**
 * Delete the stored recall record from a message, if one is present.
 *
 * @param {Array<object>} chat - Chat message list (mutated on success).
 * @param {number} userMessageIndex - Index of the target message.
 * @returns {boolean} True when a record key existed and was deleted.
 */
export function removePersistedRecallFromUserMessage(chat, userMessageIndex) {
  if (!Array.isArray(chat) || !Number.isFinite(userMessageIndex)) return false;

  const extra = chat[userMessageIndex]?.extra;
  const canHoldRecord = Boolean(extra) && typeof extra === "object";
  if (!canHoldRecord || !(BME_RECALL_EXTRA_KEY in extra)) return false;

  delete extra[BME_RECALL_EXTRA_KEY];
  return true;
}
|
||||
|
||||
/**
 * Set the `manuallyEdited` flag on a stored recall record and refresh its
 * `updatedAt` timestamp.
 *
 * @param {Array<object>} chat - Chat message list (mutated on success).
 * @param {number} userMessageIndex - Index of the message holding the record.
 * @param {boolean} [manuallyEdited=true] - New flag value.
 * @param {string} [nowIso] - Timestamp to record; defaults to the current time.
 * @returns {object|null} The updated record, or null when no record could be
 *   read or the write was rejected.
 */
export function markPersistedRecallManualEdit(
  chat,
  userMessageIndex,
  manuallyEdited = true,
  nowIso = new Date().toISOString(),
) {
  const existing = readPersistedRecallFromUserMessage(chat, userMessageIndex);
  if (!existing) return null;

  const updated = {
    ...existing,
    manuallyEdited: Boolean(manuallyEdited),
    updatedAt: toIsoString(nowIso),
  };
  const saved = writePersistedRecallToUserMessage(
    chat,
    userMessageIndex,
    updated,
  );
  return saved ? updated : null;
}
|
||||
|
||||
/**
 * Increment the `generationCount` of a stored recall record by one.
 *
 * @param {Array<object>} chat - Chat message list (mutated on success).
 * @param {number} userMessageIndex - Index of the message holding the record.
 * @returns {object|null} The updated record, or null when no record could be
 *   read or the write was rejected.
 */
export function bumpPersistedRecallGenerationCount(chat, userMessageIndex) {
  const existing = readPersistedRecallFromUserMessage(chat, userMessageIndex);
  if (!existing) return null;

  const priorCount = Math.max(0, Number(existing.generationCount || 0));
  const updated = { ...existing, generationCount: priorCount + 1 };

  const saved = writePersistedRecallToUserMessage(
    chat,
    userMessageIndex,
    updated,
  );
  return saved ? updated : null;
}
|
||||
|
||||
/**
 * Find the index of the user message a generation should bind its recall to.
 *
 * The chat is scanned from the end. For the "normal" generation type,
 * messages flagged `is_system` are skipped; for every other type the last
 * message flagged `is_user` wins regardless of `is_system`.
 *
 * @param {Array<object>} chat - Chat message list.
 * @param {{generationType?: string}} [options] - Blank or missing type is
 *   treated as "normal".
 * @returns {number|null} Index of the matching message, or null when none.
 */
export function resolveGenerationTargetUserMessageIndex(
  chat,
  { generationType = "normal" } = {},
) {
  if (!Array.isArray(chat) || chat.length === 0) return null;

  const mode = String(generationType || "normal").trim() || "normal";

  // "normal" keeps walking back to the last non-system user message. Stopping
  // at the final message instead would return null whenever that message is
  // not from the user (typically the assistant turn that was just appended),
  // so persistence could never bind back to this round's user message and
  // `hasRecordForLatest` would stay false.
  const skipSystem = mode === "normal";

  for (let cursor = chat.length - 1; cursor >= 0; cursor -= 1) {
    const entry = chat[cursor];
    if (skipSystem && entry?.is_system) continue;
    if (entry?.is_user) return cursor;
  }
  return null;
}
|
||||
|
||||
/**
 * Decide which injection text to use: a fresh recall result, a persisted
 * record, or nothing.
 *
 * Fresh wins when its status is "completed", `didRecall` is truthy and its
 * trimmed injection text is non-empty. Otherwise a persisted record with
 * non-empty trimmed text is used.
 *
 * @param {{freshRecallResult?: object|null, persistedRecord?: object|null}} [sources]
 * @returns {{source: "fresh"|"persisted"|"none", injectionText: string, record: object|null}}
 *   `record` is the persisted record only when `source` is "persisted".
 */
export function resolveFinalRecallInjectionSource({
  freshRecallResult = null,
  persistedRecord = null,
} = {}) {
  const freshText = String(freshRecallResult?.injectionText || "").trim();
  const freshUsable =
    freshRecallResult?.status === "completed" &&
    Boolean(freshRecallResult?.didRecall) &&
    freshText !== "";
  if (freshUsable) {
    return { source: "fresh", injectionText: freshText, record: null };
  }

  const storedText = String(persistedRecord?.injectionText || "").trim();
  if (storedText !== "") {
    return {
      source: "persisted",
      injectionText: storedText,
      record: persistedRecord,
    };
  }

  return { source: "none", injectionText: "", record: null };
}
|
||||
795
retrieval/retrieval-enhancer.js
Normal file
795
retrieval/retrieval-enhancer.js
Normal file
@@ -0,0 +1,795 @@
|
||||
import { embedText, searchSimilar } from "../vector/embedding.js";
|
||||
import { getNode } from "../graph/graph.js";
|
||||
import { isDirectVectorConfig } from "../vector/vector-index.js";
|
||||
|
||||
// Node types that isEligibleAnchorNode rejects as co-occurrence anchors.
const COOCCURRENCE_EXCLUDED_TYPES = new Set(["event", "synopsis", "reflection"]);
|
||||
|
||||
// Memo for createCooccurrenceIndex, keyed by graph object; each entry holds
// the cache key it was built for and the resulting index ({ key, value }).
const cooccurrenceCache = new WeakMap();
|
||||
|
||||
/**
 * Split user text into distinct intent segments.
 *
 * The text is cut on runs of punctuation/newlines and on a fixed list of
 * Chinese connective words, then trimmed, filtered by length, de-duplicated
 * and capped.
 *
 * @param {*} text - Input text; falsy values are treated as empty.
 * @param {{maxSegments?: number, minLength?: number}} [options]
 *   `maxSegments` caps the result (at least 1 is always allowed);
 *   `minLength` is the minimum trimmed length for a segment to be kept.
 * @returns {string[]} Unique segments in order of first appearance.
 */
export function splitIntentSegments(
  text,
  { maxSegments = 4, minLength = 3 } = {},
) {
  const source = String(text || "").trim();
  if (source === "") return [];

  const delimiter =
    /[,,。.;;!!??\n]+|(?:顺便|另外|还有|对了|然后|而且|并且|同时)/;

  const pieces = [];
  for (const part of source.split(delimiter)) {
    const piece = part.trim();
    if (piece.length >= minLength) pieces.push(piece);
  }

  const cap = Math.max(1, maxSegments);
  return uniqueStrings(pieces).slice(0, cap);
}
|
||||
|
||||
/**
 * Merge several vector-search result lists into one ranked list.
 *
 * Hits are de-duplicated by `nodeId`, keeping the hit with the highest
 * score (the first one seen wins ties). The merged list is sorted by score
 * descending, with `nodeId` (localeCompare) as tie-break.
 *
 * @param {Array<Array<{nodeId: string, score: number}>>} [resultGroups]
 *   Result lists; non-array groups and hits without a truthy `nodeId` are
 *   ignored. A null value is treated as no groups.
 * @param {number} [limit=Infinity] - Maximum number of results. Non-finite
 *   values mean "no limit"; negative values are clamped to 0.
 * @returns {{rawHitCount: number, results: Array<object>}} `rawHitCount`
 *   counts every accepted hit before de-duplication; each result is a copy
 *   of the winning hit with `score` coerced to a number (0 if not numeric).
 */
export function mergeVectorResults(resultGroups = [], limit = Infinity) {
  const bestByNode = new Map();
  let rawHitCount = 0;

  for (const group of resultGroups ?? []) {
    if (!Array.isArray(group)) continue;
    for (const item of group) {
      if (!item?.nodeId) continue;
      rawHitCount += 1;
      const score = Number(item.score) || 0;
      const existing = bestByNode.get(item.nodeId);
      if (!existing || score > existing.score) {
        bestByNode.set(item.nodeId, { ...item, score });
      }
    }
  }

  const ranked = [...bestByNode.values()].sort((a, b) => {
    if (b.score !== a.score) return b.score - a.score;
    return String(a.nodeId).localeCompare(String(b.nodeId));
  });

  // A negative end index would make slice() count from the tail and return
  // "all but the last N" results, so clamp the cap at zero.
  const cap = Number.isFinite(limit) ? Math.max(0, limit) : ranked.length;

  return {
    rawHitCount,
    results: ranked.slice(0, cap),
  };
}
|
||||
|
||||
/**
 * Tell whether a node may serve as a co-occurrence anchor.
 *
 * A node qualifies when it exists, is not archived, its type is not in
 * COOCCURRENCE_EXCLUDED_TYPES, and it has at least one anchor term.
 *
 * @param {object|null|undefined} node - Graph node.
 * @returns {boolean} True when the node qualifies.
 */
export function isEligibleAnchorNode(node) {
  const usable =
    Boolean(node) &&
    !node.archived &&
    !COOCCURRENCE_EXCLUDED_TYPES.has(node.type);
  return usable && getAnchorTerms(node).length > 0;
}
|
||||
|
||||
/**
 * Collect the anchor terms of a node from `fields.name` and `fields.title`.
 *
 * @param {object|null|undefined} node - Graph node.
 * @returns {string[]} Trimmed string values of at least two characters, name
 *   first, then title.
 */
export function getAnchorTerms(node) {
  const terms = [];
  for (const raw of [node?.fields?.name, node?.fields?.title]) {
    if (typeof raw !== "string") continue;
    const term = raw.trim();
    if (term.length >= 2) terms.push(term);
  }
  return terms;
}
|
||||
|
||||
/**
 * Pick extra anchor node ids from vector hits, skipping primary anchors.
 *
 * Hits are visited in order; each one is resolved through getNode and kept
 * when it is an eligible anchor node that has not been seen yet.
 *
 * @param {object} graph - Graph passed to getNode.
 * @param {Array<{nodeId: string}>} [vectorResults] - Hits in priority order.
 * @param {Iterable<string>} [primaryAnchorIds] - Ids to exclude.
 * @param {number} [maxCount=5] - Maximum number of ids to return.
 * @returns {string[]} Selected node ids, without duplicates.
 */
export function collectSupplementalAnchorNodeIds(
  graph,
  vectorResults = [],
  primaryAnchorIds = [],
  maxCount = 5,
) {
  const picked = [];
  const visited = new Set(primaryAnchorIds || []);

  for (const hit of vectorResults) {
    if (picked.length >= maxCount) break;

    const candidate = getNode(graph, hit?.nodeId);
    if (!isEligibleAnchorNode(candidate)) continue;
    if (visited.has(candidate.id)) continue;

    visited.add(candidate.id);
    picked.push(candidate.id);
  }

  return picked;
}
|
||||
|
||||
/**
 * Build (or reuse) a co-occurrence index over eligible anchor nodes.
 *
 * Preferred source is `graph.batchJournal`: for each journal entry with a
 * finite `processedRange`, the eligible nodes whose `seqRange` overlaps that
 * range are sorted with compareBySeqDesc, capped, and every pair among them
 * is counted once. When no journal entry yields a batch, the fallback counts
 * every pair of eligible nodes by the size of their `seqRange` overlap.
 *
 * Results are memoized per graph object. The cache key is derived from the
 * journal length, node count, `historyState.lastProcessedAssistantFloor`,
 * `maxAnchorsPerBatch` and the eligible node ids, so changes that leave all
 * of those untouched reuse the cached index. When `graph` is not an object
 * the index is still built but not cached.
 *
 * @param {object|null|undefined} graph - Graph holding `batchJournal`,
 *   `nodes` and `historyState`.
 * @param {{maxAnchorsPerBatch?: number, eligibleNodes?: Array<object>|null}} [options]
 *   `eligibleNodes` is filtered with isEligibleAnchorNode; a non-array value
 *   means "no nodes". At least 2 nodes are always allowed per batch.
 * @returns {{map: Map<string, Array<{nodeId: string, count: number}>>,
 *   source: "batchJournal"|"seqRange", batchCount: number, pairCount: number}}
 */
export function createCooccurrenceIndex(
  graph,
  { maxAnchorsPerBatch = 10, eligibleNodes = null } = {},
) {
  const nodes = Array.isArray(eligibleNodes)
    ? eligibleNodes.filter(isEligibleAnchorNode)
    : [];
  const eligibleNodeKey = nodes
    .map((node) => node.id)
    .sort()
    .join("|");
  const cacheKey = [
    graph?.batchJournal?.length || 0,
    graph?.nodes?.length || 0,
    graph?.historyState?.lastProcessedAssistantFloor ?? -1,
    maxAnchorsPerBatch,
    eligibleNodeKey,
  ].join(":");

  // WeakMap keys must be objects: set() throws a TypeError for null or
  // undefined, which the rest of this function otherwise tolerates.
  const cacheable =
    (typeof graph === "object" && graph !== null) ||
    typeof graph === "function";
  const cached = cacheable ? cooccurrenceCache.get(graph) : undefined;
  if (cached?.key === cacheKey) {
    return cached.value;
  }

  const index = new Map();
  let pairCount = 0;
  let batchCount = 0;

  if (nodes.length >= 2 && Array.isArray(graph?.batchJournal)) {
    const perBatchCap = Math.max(2, maxAnchorsPerBatch);
    for (const journal of graph.batchJournal) {
      const range = Array.isArray(journal?.processedRange)
        ? journal.processedRange
        : null;
      if (!range || !Number.isFinite(range[0]) || !Number.isFinite(range[1])) {
        continue;
      }

      const batchNodes = nodes
        .filter((node) => rangesOverlap(node.seqRange, range))
        .sort(compareBySeqDesc)
        .slice(0, perBatchCap);
      if (batchNodes.length < 2) continue;

      batchCount += 1;
      pairCount += appendPairs(index, batchNodes, 1);
    }
  }

  const usedJournal = batchCount > 0;
  if (!usedJournal) {
    // Nothing was added above (no batch qualified), so `index` is still
    // empty here; fall back to pairwise seqRange overlap as the weight.
    for (let i = 0; i < nodes.length; i++) {
      for (let j = i + 1; j < nodes.length; j++) {
        const overlap = rangeOverlapSize(nodes[i].seqRange, nodes[j].seqRange);
        if (overlap <= 0) continue;
        addCooccurrence(index, nodes[i].id, nodes[j].id, overlap);
        addCooccurrence(index, nodes[j].id, nodes[i].id, overlap);
        pairCount += 1;
      }
    }
  }

  const result = {
    map: normalizeCooccurrenceMap(index),
    source: usedJournal ? "batchJournal" : "seqRange",
    batchCount,
    pairCount,
  };
  if (cacheable) {
    cooccurrenceCache.set(graph, { key: cacheKey, value: result });
  }
  return result;
}
|
||||
|
||||
/**
 * Add a co-occurrence bonus to the scores of nodes that co-occur with
 * anchor nodes.
 *
 * For each anchor, the first `maxNeighbors` entries of its neighbor list
 * (at least one) receive
 * `max(0, anchorScore) * ln(1 + max(0, count)) * max(0, scale)`.
 * Zero bonuses are skipped.
 *
 * @param {Iterable<[string, number]>|null} baseScores - Starting scores
 *   (nodeId -> score); copied, never mutated.
 * @param {Map<string, number>|null} anchorWeights - Anchor nodeId -> weight.
 *   A null/undefined value is treated as "no anchors".
 * @param {{map?: Map<string, Array<{nodeId: string, count: number}>>}|null} cooccurrenceIndex
 *   Index whose `map` lists neighbors per anchor; ignored unless `map` is a Map.
 * @param {{scale?: number, maxNeighbors?: number}} [options]
 * @returns {{scores: Map<string, number>, boostedNodes: Array<{anchorId: string,
 *   nodeId: string, count: number, bonus: number}>}}
 */
export function applyCooccurrenceBoost(
  baseScores,
  anchorWeights,
  cooccurrenceIndex,
  { scale = 0.1, maxNeighbors = 10 } = {},
) {
  const nextScores = new Map(baseScores || []);
  const boostedNodes = [];
  const neighborMap =
    cooccurrenceIndex?.map instanceof Map ? cooccurrenceIndex.map : new Map();

  // The other two inputs tolerate missing values; do the same for anchors
  // instead of throwing on `.entries()`.
  const anchorEntries = anchorWeights == null ? [] : anchorWeights.entries();

  // Loop-invariant factors.
  const scaleFactor = Math.max(0, Number(scale) || 0);
  const neighborCap = Math.max(1, maxNeighbors);

  for (const [anchorId, anchorScore] of anchorEntries) {
    const anchorFactor = Math.max(0, Number(anchorScore) || 0);
    const neighbors = (neighborMap.get(anchorId) || []).slice(0, neighborCap);

    for (const item of neighbors) {
      const bonus =
        anchorFactor *
        Math.log(1 + Math.max(0, Number(item.count) || 0)) *
        scaleFactor;
      if (!bonus) continue;

      nextScores.set(item.nodeId, (nextScores.get(item.nodeId) || 0) + bonus);
      boostedNodes.push({
        anchorId,
        nodeId: item.nodeId,
        count: item.count,
        bonus,
      });
    }
  }

  return {
    scores: nextScores,
    boostedNodes,
  };
}
|
||||
|
||||
/**
 * Greedily pick `k` candidates that balance score and mutual dissimilarity.
 *
 * Each candidate i gets a quality q_i = max(score_i, 1e-10) ^ qualityWeight,
 * and the kernel entry for a pair is q_i * cos(v_i, v_j) * q_j (vectors are
 * unit-normalized first). At every step the unselected candidate with the
 * largest remaining diagonal value is taken, then the diagonal is reduced
 * via an incremental Cholesky-style row update.
 *
 * @param {Array<number[]>} [candidateVecs] - Embedding per candidate.
 * @param {number[]} [candidateScores] - Score per candidate (same order).
 * @param {number} k - Number of candidates to pick. Fractions are floored;
 *   NaN selects nothing; values >= the candidate count select everything.
 * @param {number} [qualityWeight=1] - Exponent applied to scores (negative
 *   values are treated as 0).
 * @returns {number[]} Selected candidate indexes in pick order; when all
 *   candidates are selected they are returned in input order.
 */
export function dppGreedySelect(
  candidateVecs = [],
  candidateScores = [],
  k,
  qualityWeight = 1,
) {
  const total = Math.min(candidateVecs.length, candidateScores.length);
  // Floor so the loop bound matches the number of allocated `chol` rows; a
  // fractional target would otherwise run one step past the last row.
  const target = Math.floor(Math.max(0, Math.min(k, total))) || 0;
  if (target >= total) {
    return Array.from({ length: total }, (_, index) => index);
  }

  const normalized = candidateVecs.map((vector) => normalizeVector(vector));
  const exponent = Math.max(0, qualityWeight);
  const q = candidateScores.map((score) =>
    Math.pow(Math.max(Number(score) || 0, 1e-10), exponent),
  );
  const diag = q.map((value) => value * value + 1e-8);
  const chol = Array.from({ length: target }, () => Array(total).fill(0));
  const selected = [];
  // O(1) membership test instead of scanning `selected` for every candidate.
  const taken = Array(total).fill(false);

  for (let step = 0; step < target; step++) {
    let bestIndex = -1;
    let bestValue = Number.NEGATIVE_INFINITY;

    for (let i = 0; i < total; i++) {
      if (taken[i]) continue;
      if (diag[i] > bestValue) {
        bestValue = diag[i];
        bestIndex = i;
      }
    }

    if (bestIndex === -1) break;
    taken[bestIndex] = true;
    selected.push(bestIndex);

    // No update needed after the final pick, and a near-zero pivot would
    // blow up the 1/sqrt below.
    if (step === target - 1 || diag[bestIndex] < 1e-10) {
      continue;
    }

    const pivotVector = normalized[bestIndex];
    const row = normalized.map(
      (vector, index) => q[bestIndex] * dot(pivotVector, vector) * q[index],
    );
    for (let prev = 0; prev < step; prev++) {
      const pivot = chol[prev][bestIndex];
      for (let index = 0; index < total; index++) {
        row[index] -= pivot * chol[prev][index];
      }
    }

    // Computed before the loop below overwrites diag[bestIndex].
    const inv = 1 / Math.sqrt(diag[bestIndex]);
    for (let index = 0; index < total; index++) {
      chol[step][index] = row[index] * inv;
      diag[index] = Math.max(0, diag[index] - chol[step][index] ** 2);
    }
  }

  return selected;
}
|
||||
|
||||
/**
 * Reduce a ranked candidate list to `k` entries using dppGreedySelect,
 * falling back to the first `k` candidates when sampling is not possible.
 *
 * @param {Array<{node: {embedding: number[]}, finalScore: number}>} [candidates]
 *   Ranked candidates.
 * @param {{k?: number, qualityWeight?: number}} [options] - `k` is floored
 *   and at least 1.
 * @returns {{applied: boolean, reason: string, selected: Array<object>,
 *   beforeCount: number, afterCount: number}} `applied` is false, with a
 *   reason code, when the pool is not larger than `k`, any candidate lacks a
 *   non-empty embedding array, or the selection returned a different number
 *   of items than requested.
 */
export function applyDiversitySampling(
  candidates = [],
  { k, qualityWeight = 1 } = {},
) {
  const target = Math.max(1, Math.floor(Number(k) || 0));

  // Shared shape for every "sampling not applied" outcome.
  const fallback = (reason) => ({
    applied: false,
    reason,
    selected: candidates.slice(0, target),
    beforeCount: candidates.length,
    afterCount: Math.min(candidates.length, target),
  });

  if (candidates.length <= target) {
    return fallback("candidate-pool-too-small");
  }

  const allEmbedded = candidates.every(
    (item) =>
      Array.isArray(item?.node?.embedding) && item.node.embedding.length > 0,
  );
  if (!allEmbedded) {
    return fallback("candidate-embeddings-missing");
  }

  const vectors = candidates.map((item) => item.node.embedding);
  const scores = candidates.map((item) => item.finalScore);
  const selected = dppGreedySelect(vectors, scores, target, qualityWeight)
    .map((index) => candidates[index])
    .filter(Boolean);

  if (selected.length !== target) {
    return fallback("dpp-selection-incomplete");
  }

  return {
    applied: true,
    reason: "",
    selected,
    beforeCount: candidates.length,
    afterCount: selected.length,
  };
}
|
||||
|
||||
/**
 * Describe a query relative to topics factored out of entity vectors.
 *
 * Entity vectors are unit-normalized and made non-negative (absolute
 * values), then factored by nmfMultiplicativeUpdate into `k` topic rows.
 * The query (absolute values) is scored against each topic by dot product
 * and the scores are turned into weights with softmax.
 *
 * @param {number[]} queryVec - Query embedding.
 * @param {Array<number[]>} entityVecs - Entity embeddings used as basis.
 * @param {{nTopics?: number, maxIter?: number, tolerance?: number}} [options]
 *   `nTopics` is floored and clamped to [1, number of usable entity vectors].
 * @returns {{semanticDepth: number, topicCoverage: number, novelty: number,
 *   topTopics: number[]}}
 *   - semanticDepth: 1 minus the normalized entropy of the topic weights.
 *   - topicCoverage: number of topics with weight above 0.5 / k.
 *   - novelty: norm of (query - weighted topic mix) relative to the query norm.
 *   - topTopics: the full topic weight vector.
 *   With fewer than two usable entity vectors or an empty query the result
 *   is { semanticDepth: 0, topicCoverage: 0, novelty: 1, topTopics: [] }.
 */
export function nmfQueryAnalysis(
  queryVec,
  entityVecs,
  { nTopics = 15, maxIter = 100, tolerance = 1e-4 } = {},
) {
  const basis = normalizeMatrix(entityVecs);
  const query = vectorAbs(queryVec);
  if (basis.length < 2 || query.length === 0) {
    return { semanticDepth: 0, topicCoverage: 0, novelty: 1, topTopics: [] };
  }

  const k = Math.min(Math.max(1, Math.floor(nTopics)), basis.length);
  const nonNegative = basis.map((vector) => vectorAbs(vector));
  const { h } = nmfMultiplicativeUpdate(nonNegative, k, maxIter, tolerance);
  const topics = softmax(h.map((topic) => dot(query, topic)));

  // Shannon entropy of the topic weights; near-zero weights are skipped.
  let weightedLogSum = 0;
  for (const weight of topics) {
    if (weight > 1e-10) {
      weightedLogSum = weightedLogSum + weight * Math.log(weight);
    }
  }
  const entropy = -weightedLogSum;
  const maxEntropy = k > 1 ? Math.log(k) : 1;
  const semanticDepth = 1 - entropy / maxEntropy;

  let topicCoverage = 0;
  for (const weight of topics) {
    if (weight > 0.5 / k) topicCoverage += 1;
  }

  // Weighted topic mix, accumulated per dimension in topic order.
  const reconstruction = query.map((_, dim) =>
    topics.reduce((sum, weight, topicIndex) => sum + weight * h[topicIndex][dim], 0),
  );

  const residualNorm = l2Norm(subtractVectors(query, reconstruction));
  const novelty = residualNorm / Math.max(l2Norm(query), 1e-10);

  return {
    semanticDepth,
    topicCoverage,
    novelty,
    topTopics: topics,
  };
}
|
||||
|
||||
/**
 * Approximate a query as a sparse combination of entity vectors and return
 * what is left over.
 *
 * Entity vectors are unit-normalized; the query is only coerced to numbers.
 * Coefficients are found with an accelerated proximal-gradient (FISTA-style)
 * loop: a gradient step on the Gram system, soft-thresholding by
 * `lambda * step`, then a momentum extrapolation. The step size is the
 * inverse of the largest absolute row sum of the Gram matrix.
 *
 * @param {number[]} queryVec - Query embedding.
 * @param {Array<number[]>} entityVecs - Basis embeddings.
 * @param {{lambda?: number, maxIter?: number}} [options] - L1 weight and
 *   iteration count.
 * @returns {{alpha: number[], residual: number[], residualNorm: number}}
 *   Coefficients per usable entity vector, `query - reconstruction`, and its
 *   L2 norm. With no usable entity vectors or an empty query, `alpha` is []
 *   and the residual is the query itself; with an all-zero Gram matrix,
 *   `alpha` is all zeros and the residual is the query itself.
 */
export function sparseCodeResidual(
  queryVec,
  entityVecs,
  { lambda = 0.1, maxIter = 80 } = {},
) {
  const query = normalizeVector(queryVec, false);
  const entities = normalizeMatrix(entityVecs);
  const count = entities.length;

  // Result used whenever no coding can be performed.
  const untouched = (coefficients) => ({
    alpha: coefficients,
    residual: [...query],
    residualNorm: l2Norm(query),
  });

  if (count === 0 || query.length === 0) {
    return untouched([]);
  }

  // Gram matrix (entity . entity) and entity . query correlations.
  const gram = Array.from({ length: count }, () => Array(count).fill(0));
  const correlations = Array(count).fill(0);
  entities.forEach((entity, i) => {
    correlations[i] = dot(entity, query);
    for (let j = i; j < count; j++) {
      const inner = dot(entity, entities[j]);
      gram[i][j] = inner;
      gram[j][i] = inner;
    }
  });

  let lipschitz = 0;
  for (const gramRow of gram) {
    let absRowSum = 0;
    for (const entry of gramRow) {
      absRowSum += Math.abs(entry);
    }
    lipschitz = Math.max(lipschitz, absRowSum);
  }
  if (lipschitz < 1e-10) {
    return untouched(Array(count).fill(0));
  }

  const step = 1 / lipschitz;
  const threshold = lambda * step;
  let alpha = Array(count).fill(0);
  let lookahead = [...alpha];
  let t = 1;

  for (let round = 0; round < maxIter; round++) {
    const gramTimesY = matVecMul(gram, lookahead);
    const stepped = lookahead.map(
      (value, index) => value - step * (gramTimesY[index] - correlations[index]),
    );
    const updated = softThreshold(stepped, threshold);

    const nextT = (1 + Math.sqrt(1 + 4 * t * t)) / 2;
    const momentum = (t - 1) / nextT;
    lookahead = updated.map(
      (value, index) => value + momentum * (value - alpha[index]),
    );
    alpha = updated;
    t = nextT;
  }

  const reconstruction = Array(query.length).fill(0);
  alpha.forEach((coefficient, i) => {
    if (Math.abs(coefficient) < 1e-10) return;
    for (let dim = 0; dim < query.length; dim++) {
      reconstruction[dim] += coefficient * entities[i][dim];
    }
  });

  const residual = subtractVectors(query, reconstruction);
  return {
    alpha,
    residual,
    residualNorm: l2Norm(residual),
  };
}
|
||||
|
||||
/**
 * Search for nodes that match the part of a query the already-recalled
 * basis nodes do not explain.
 *
 * Steps, each of which can end the run early with a `skipReason`:
 *  1. Require a direct vector config (isDirectVectorConfig).
 *  2. Keep up to `basisLimit` (at least 2) basis nodes with embeddings;
 *     need at least two.
 *  3. Embed the query text.
 *  4. nmfQueryAnalysis: continue only when novelty >= `noveltyThreshold`.
 *  5. sparseCodeResidual: continue only when the residual norm exceeds
 *     `residualThreshold`.
 *  6. searchSimilar with the residual vector over candidate nodes that have
 *     embeddings and are not basis nodes.
 *
 * @param {object} params
 * @param {string} params.queryText - Text to embed.
 * @param {object} params.graph - Graph used to resolve hit node ids.
 * @param {object} params.embeddingConfig - Embedding configuration.
 * @param {Array<object>} [params.basisNodes] - Nodes already recalled.
 * @param {Array<object>} [params.candidateNodes] - Nodes to search.
 * @param {number} [params.basisLimit=24]
 * @param {number} [params.nTopics=15]
 * @param {number} [params.noveltyThreshold=0.4]
 * @param {number} [params.residualThreshold=0.3]
 * @param {number} [params.residualTopK=5]
 * @param {AbortSignal} [params.signal] - Forwarded to embedText.
 * @returns {Promise<{triggered: boolean, hits: Array<object>, skipReason: string,
 *   nmf?: object, sparse?: object}>} `triggered` is true once both thresholds
 *   passed; each hit is a searchSimilar item plus its resolved `node`.
 */
export async function runResidualRecall({
  queryText,
  graph,
  embeddingConfig,
  basisNodes = [],
  candidateNodes = [],
  basisLimit = 24,
  nTopics = 15,
  noveltyThreshold = 0.4,
  residualThreshold = 0.3,
  residualTopK = 5,
  signal,
}) {
  if (!isDirectVectorConfig(embeddingConfig)) {
    return {
      triggered: false,
      hits: [],
      skipReason: "residual-direct-mode-required",
    };
  }

  const hasEmbedding = (node) =>
    Array.isArray(node?.embedding) && node.embedding.length > 0;

  const filteredBasis = basisNodes
    .filter((node) => hasEmbedding(node))
    .slice(0, Math.max(2, basisLimit));
  if (filteredBasis.length < 2) {
    return {
      triggered: false,
      hits: [],
      skipReason: "residual-basis-insufficient",
    };
  }

  const queryVec = await embedText(queryText, embeddingConfig, { signal });
  if (!queryVec || queryVec.length === 0) {
    return {
      triggered: false,
      hits: [],
      skipReason: "residual-query-embedding-missing",
    };
  }

  // Built once and shared by both analyses below; neither mutates its input.
  const basisEmbeddings = filteredBasis.map((node) => node.embedding);

  const nmfResult = nmfQueryAnalysis(queryVec, basisEmbeddings, { nTopics });
  if (
    !Number.isFinite(nmfResult.novelty) ||
    nmfResult.novelty < noveltyThreshold
  ) {
    return {
      triggered: false,
      hits: [],
      nmf: nmfResult,
      skipReason: "residual-novelty-below-threshold",
    };
  }

  const sparse = sparseCodeResidual(queryVec, basisEmbeddings);
  if (
    !Number.isFinite(sparse.residualNorm) ||
    sparse.residualNorm <= residualThreshold
  ) {
    return {
      triggered: false,
      hits: [],
      nmf: nmfResult,
      sparse,
      skipReason: "residual-norm-below-threshold",
    };
  }

  // Set lookup keeps the exclusion check O(1) per candidate.
  const basisIds = new Set(filteredBasis.map((node) => node.id));
  const searchableCandidates = (candidateNodes || [])
    .filter((node) => hasEmbedding(node) && !basisIds.has(node.id))
    .map((node) => ({
      nodeId: node.id,
      embedding: node.embedding,
    }));

  if (searchableCandidates.length === 0) {
    return {
      triggered: true,
      hits: [],
      nmf: nmfResult,
      sparse,
      skipReason: "residual-search-space-empty",
    };
  }

  const hits = searchSimilar(sparse.residual, searchableCandidates, residualTopK)
    .map((item) => ({
      ...item,
      node: getNode(graph, item.nodeId),
    }))
    .filter((item) => item.node);

  return {
    triggered: true,
    hits,
    nmf: nmfResult,
    sparse,
    skipReason: hits.length > 0 ? "" : "residual-no-hit",
  };
}
|
||||
|
||||
/**
 * Drop falsy entries and duplicates, keeping first-seen order.
 *
 * @param {Array<*>} [items] - Values to de-duplicate.
 * @returns {Array<*>} New array of unique truthy values.
 */
function uniqueStrings(items = []) {
  const seen = new Set();
  for (const item of items) {
    if (item) seen.add(item);
  }
  return Array.from(seen);
}
|
||||
|
||||
/**
 * Convert a nested count map into sorted neighbor lists.
 *
 * @param {Map<string, Map<string, number>>} index - nodeId -> (neighborId -> count).
 * @returns {Map<string, Array<{nodeId: string, count: number}>>} Same keys;
 *   each list is sorted by count descending, then by nodeId (localeCompare).
 */
function normalizeCooccurrenceMap(index) {
  const byCountThenId = (left, right) =>
    right.count !== left.count
      ? right.count - left.count
      : String(left.nodeId).localeCompare(String(right.nodeId));

  const normalized = new Map();
  index.forEach((neighborCounts, nodeId) => {
    const neighbors = [];
    for (const [neighborId, count] of neighborCounts.entries()) {
      neighbors.push({ nodeId: neighborId, count });
    }
    normalized.set(nodeId, neighbors.sort(byCountThenId));
  });
  return normalized;
}
|
||||
|
||||
/**
 * Record a symmetric co-occurrence for every unordered pair of nodes.
 *
 * @param {Map<string, Map<string, number>>} index - Count map to update.
 * @param {Array<{id: string}>} nodes - Nodes that occurred together.
 * @param {number} increment - Amount added in both directions per pair.
 * @returns {number} Number of unordered pairs processed.
 */
function appendPairs(index, nodes, increment) {
  let pairs = 0;
  for (const [position, first] of nodes.entries()) {
    for (const second of nodes.slice(position + 1)) {
      addCooccurrence(index, first.id, second.id, increment);
      addCooccurrence(index, second.id, first.id, increment);
      pairs += 1;
    }
  }
  return pairs;
}
|
||||
|
||||
/**
 * Add `increment` to the directed count fromId -> toId, creating the inner
 * map on first use.
 *
 * @param {Map<string, Map<string, number>>} index - Count map to update.
 * @param {string} fromId - Source node id.
 * @param {string} toId - Neighbor node id.
 * @param {number} increment - Amount to add.
 * @returns {void}
 */
function addCooccurrence(index, fromId, toId, increment) {
  let neighbors;
  if (index.has(fromId)) {
    neighbors = index.get(fromId);
  } else {
    neighbors = new Map();
    index.set(fromId, neighbors);
  }
  const previous = neighbors.get(toId) || 0;
  neighbors.set(toId, previous + increment);
}
|
||||
|
||||
/**
 * Tell whether two ranges share at least one position.
 *
 * @param {Array<number>} a - First range.
 * @param {Array<number>} b - Second range.
 * @returns {boolean} True when rangeOverlapSize reports a positive overlap.
 */
function rangesOverlap(a, b) {
  const shared = rangeOverlapSize(a, b);
  return shared > 0;
}
|
||||
|
||||
/**
 * Count the positions shared by two inclusive ranges.
 *
 * @param {Array<number>} a - First range.
 * @param {Array<number>} b - Second range.
 * @returns {number} Size of the inclusive intersection, or 0 when the ranges
 *   are disjoint or either one cannot be normalized.
 */
function rangeOverlapSize(a, b) {
  const first = normalizeRange(a);
  const second = normalizeRange(b);
  if (!first || !second) return 0;

  const [firstStart, firstEnd] = first;
  const [secondStart, secondEnd] = second;
  const low = Math.max(firstStart, secondStart);
  const high = Math.min(firstEnd, secondEnd);
  if (high >= low) {
    return high - low + 1;
  }
  return 0;
}
|
||||
|
||||
/**
 * Coerce a two-element range into ascending numeric form.
 *
 * @param {*} range - Expected `[start, end]`; extra elements are ignored.
 * @returns {[number, number]|null} `[min, max]`, or null when the input is
 *   not an array of at least two entries or either bound is not finite after
 *   Number() conversion.
 */
function normalizeRange(range) {
  if (!Array.isArray(range) || range.length < 2) return null;

  const bounds = [Number(range[0]), Number(range[1])];
  if (!bounds.every((bound) => Number.isFinite(bound))) return null;

  return [Math.min(...bounds), Math.max(...bounds)];
}
|
||||
|
||||
/**
 * Sort comparator: newest sequence first, then highest importance first.
 *
 * The sequence of a node is `seqRange[1]`, falling back to `seq`, then 0.
 * Missing importance counts as 0. Nullish operands are tolerated in both
 * comparisons.
 *
 * @param {object|null|undefined} a - First node.
 * @param {object|null|undefined} b - Second node.
 * @returns {number} Negative when `a` sorts first, positive when `b` does.
 */
function compareBySeqDesc(a, b) {
  const seqA = a?.seqRange?.[1] ?? a?.seq ?? 0;
  const seqB = b?.seqRange?.[1] ?? b?.seq ?? 0;
  if (seqB !== seqA) return seqB - seqA;
  // Same null-safe access as the sequence lookup above, so a tie between
  // nullish operands does not throw.
  return (b?.importance || 0) - (a?.importance || 0);
}
|
||||
|
||||
/**
 * Element-wise absolute value, treating non-numeric entries as 0.
 *
 * @param {Array<*>} [vector] - Input values.
 * @returns {number[]} New vector of non-negative numbers.
 */
function vectorAbs(vector = []) {
  const toMagnitude = (entry) => {
    const numeric = Number(entry) || 0;
    return Math.abs(numeric);
  };
  return vector.map((entry) => toMagnitude(entry));
}
|
||||
|
||||
/**
 * Coerce a vector to numbers and optionally scale it to unit length.
 *
 * @param {Array<*>} [vector] - Input values; non-numeric entries become 0.
 * @param {boolean} [useUnitNorm=true] - When falsy, only the coercion is done.
 * @returns {number[]} New vector. With unit normalization, a vector whose
 *   norm is below 1e-10 becomes all zeros.
 */
function normalizeVector(vector = [], useUnitNorm = true) {
  const numeric = vector.map((entry) => Number(entry) || 0);
  if (!useUnitNorm) return numeric;

  const magnitude = l2Norm(numeric);
  return magnitude < 1e-10
    ? numeric.map(() => 0)
    : numeric.map((component) => component / magnitude);
}
|
||||
|
||||
/**
 * Keep the non-empty array rows of a matrix and unit-normalize each one.
 *
 * @param {Array<*>} [vectors] - Candidate rows.
 * @returns {Array<number[]>} Normalized rows, in input order.
 */
function normalizeMatrix(vectors = []) {
  return vectors.reduce((rows, row) => {
    if (Array.isArray(row) && row.length > 0) {
      rows.push(normalizeVector(row));
    }
    return rows;
  }, []);
}
|
||||
|
||||
/**
 * Dot product over the common prefix of two vectors.
 *
 * @param {Array<*>} [a] - First vector.
 * @param {Array<*>} [b] - Second vector.
 * @returns {number} Sum of pairwise products up to the shorter length;
 *   non-numeric entries count as 0.
 */
function dot(a = [], b = []) {
  const span = Math.min(a.length, b.length);
  let total = 0;
  let position = 0;
  while (position < span) {
    const left = Number(a[position]) || 0;
    const right = Number(b[position]) || 0;
    total += left * right;
    position += 1;
  }
  return total;
}
|
||||
|
||||
/**
 * Euclidean length of a vector.
 *
 * @param {number[]} [vector] - Numeric vector (entries are not coerced).
 * @returns {number} Square root of the sum of squares; 0 for an empty vector.
 */
function l2Norm(vector = []) {
  let sumOfSquares = 0;
  vector.forEach((component) => {
    sumOfSquares += component * component;
  });
  return Math.sqrt(sumOfSquares);
}
|
||||
|
||||
/**
 * Element-wise `a - b`, padding the shorter vector with zeros.
 *
 * @param {Array<*>} [a] - Minuend.
 * @param {Array<*>} [b] - Subtrahend.
 * @returns {number[]} New vector as long as the longer input; non-numeric
 *   or missing entries count as 0.
 */
function subtractVectors(a = [], b = []) {
  const size = Math.max(a.length, b.length);
  const difference = new Array(size);
  for (let position = 0; position < size; position += 1) {
    const minuend = Number(a[position]) || 0;
    const subtrahend = Number(b[position]) || 0;
    difference[position] = minuend - subtrahend;
  }
  return difference;
}
|
||||
|
||||
/**
 * Multiply a matrix (array of rows) by a vector.
 *
 * @param {Array<number[]>} [matrix] - Rows of the matrix.
 * @param {number[]} [vector] - Vector to multiply with.
 * @returns {number[]} One dot product per row.
 */
function matVecMul(matrix = [], vector = []) {
  return matrix.map(function multiplyRow(row) {
    return dot(row, vector);
  });
}
|
||||
|
||||
/**
 * Shrink every entry toward zero by `threshold`.
 *
 * @param {number[]} [vector] - Input values.
 * @param {number} [threshold=0] - Shrink amount.
 * @returns {number[]} New vector: entries whose magnitude is at most
 *   `threshold` become 0, the rest keep their sign and lose `threshold` in
 *   magnitude.
 */
function softThreshold(vector = [], threshold = 0) {
  return vector.map((entry) => {
    const magnitude = Math.abs(entry);
    return magnitude <= threshold
      ? 0
      : Math.sign(entry) * (magnitude - threshold);
  });
}
|
||||
|
||||
/**
 * Softmax with max-subtraction for numerical stability.
 *
 * @param {number[]} [values] - Raw scores.
 * @returns {number[]} Weights proportional to exp(value - max); [] for an
 *   empty input. A zero total is replaced by 1 before dividing.
 */
function softmax(values = []) {
  if (values.length === 0) return [];

  const peak = Math.max(...values);
  const shifted = [];
  let sum = 0;
  for (const value of values) {
    const weight = Math.exp(value - peak);
    shifted.push(weight);
    sum += weight;
  }

  const denominator = sum || 1;
  return shifted.map((weight) => weight / denominator);
}
|
||||
|
||||
/**
 * Approximates `matrix` (m × d) as the product `w · h` using multiplicative
 * updates, where `w` is m × k and `h` is k × d.
 *
 * Both factors start from a deterministic pseudo-random initialisation
 * (fixed seed 42) scaled from the mean entry of `matrix`. Each iteration
 * updates `h` in place and then `w` in place:
 *
 *   h[i][dim]     *= (wᵀ·V)[i][dim]     / ((wᵀ·w·h)[i][dim]     + 1e-10)
 *   w[row][topic] *= (V·hᵀ)[row][topic] / ((w·h·hᵀ)[row][topic] + 1e-10)
 *
 * After every 10th iteration the relative reconstruction error
 * sqrt(‖V − w·h‖² / ‖V‖²) is computed and the loop stops early once it
 * drops below `tolerance`.
 *
 * @param {number[][]} matrix - Input matrix V; the column count is read
 *   from its first row. Presumably expected to be non-negative — this is
 *   not checked here.
 * @param {number} k - Number of components (inner dimension).
 * @param {number} maxIter - Maximum number of update iterations.
 * @param {number} tolerance - Relative-error threshold for early stopping.
 * @returns {{w: number[][], h: number[][]}} The two factor matrices.
 */
function nmfMultiplicativeUpdate(matrix, k, maxIter, tolerance) {
  const rowCount = matrix.length;
  const colCount = matrix[0]?.length || 0;
  const EPSILON = 1e-10;

  // Initial magnitude: derived from the mean entry so that w·h starts
  // on roughly the same scale as the input.
  const addUp = (acc, value) => acc + value;
  const grandTotal = matrix.reduce((sum, row) => sum + row.reduce(addUp, 0), 0);
  const mean = grandTotal / Math.max(1, rowCount * colCount) || 0.01;
  const scale = Math.max(Math.sqrt(mean / Math.max(1, k)), 0.01);

  const nextRandom = createDeterministicRandom(42);
  const jitteredEntry = () =>
    Math.abs(scale + scale * 0.5 * (nextRandom() - 0.5)) + 1e-6;
  const seededMatrix = (rows, cols) =>
    Array.from({ length: rows }, () =>
      Array.from({ length: cols }, jitteredEntry),
    );
  const zeros = (rows, cols) =>
    Array.from({ length: rows }, () => new Array(cols).fill(0));

  // w is drawn before h so the random sequence is consumed in a fixed order.
  const w = seededMatrix(rowCount, k);
  const h = seededMatrix(k, colCount);

  // True when the relative reconstruction error is below the tolerance.
  const hasConverged = () => {
    let errorSq = 0;
    let inputSq = 0;
    for (let r = 0; r < rowCount; r++) {
      for (let c = 0; c < colCount; c++) {
        let approx = 0;
        for (let t = 0; t < k; t++) {
          approx += w[r][t] * h[t][c];
        }
        const delta = matrix[r][c] - approx;
        errorSq += delta * delta;
        inputSq += matrix[r][c] * matrix[r][c];
      }
    }
    return inputSq > 0 && Math.sqrt(errorSq / inputSq) < tolerance;
  };

  for (let iteration = 0; iteration < maxIter; iteration++) {
    // --- h update: numerator wᵀ·V, Gram matrix wᵀ·w ---
    const wtV = zeros(k, colCount);
    const wtW = zeros(k, k);

    for (let a = 0; a < k; a++) {
      for (let c = 0; c < colCount; c++) {
        let acc = 0;
        for (let r = 0; r < rowCount; r++) {
          acc += w[r][a] * matrix[r][c];
        }
        wtV[a][c] = acc;
      }
    }
    for (let a = 0; a < k; a++) {
      for (let b = 0; b < k; b++) {
        let acc = 0;
        for (let r = 0; r < rowCount; r++) {
          acc += w[r][a] * w[r][b];
        }
        wtW[a][b] = acc;
      }
    }

    // In place: later rows of h see the already-updated earlier rows.
    for (let a = 0; a < k; a++) {
      for (let c = 0; c < colCount; c++) {
        let denom = 0;
        for (let t = 0; t < k; t++) {
          denom += wtW[a][t] * h[t][c];
        }
        h[a][c] *= wtV[a][c] / (denom + EPSILON);
      }
    }

    // --- w update: numerator V·hᵀ, Gram matrix h·hᵀ (using the new h) ---
    const vHt = zeros(rowCount, k);
    const hHt = zeros(k, k);

    for (let r = 0; r < rowCount; r++) {
      for (let t = 0; t < k; t++) {
        let acc = 0;
        for (let c = 0; c < colCount; c++) {
          acc += matrix[r][c] * h[t][c];
        }
        vHt[r][t] = acc;
      }
    }
    for (let a = 0; a < k; a++) {
      for (let b = 0; b < k; b++) {
        let acc = 0;
        for (let c = 0; c < colCount; c++) {
          acc += h[a][c] * h[b][c];
        }
        hHt[a][b] = acc;
      }
    }

    // In place: later columns of a row see the already-updated earlier ones.
    for (let r = 0; r < rowCount; r++) {
      for (let t = 0; t < k; t++) {
        let denom = 0;
        for (let inner = 0; inner < k; inner++) {
          denom += w[r][inner] * hHt[inner][t];
        }
        w[r][t] *= vHt[r][t] / (denom + EPSILON);
      }
    }

    // The convergence check is only run on every 10th iteration.
    const isCheckpoint = iteration % 10 === 9;
    if (isCheckpoint && hasConverged()) {
      break;
    }
  }

  return { w, h };
}
|
||||
|
||||
/**
 * Creates a seeded, repeatable pseudo-random number generator.
 *
 * The generator is a 32-bit linear congruential generator
 * (state = 1664525 · state + 1013904223, truncated to an unsigned 32-bit
 * integer). Each call returns the new state divided by 0xffffffff, so
 * results lie in the closed interval [0, 1].
 *
 * @param {number} seed - Initial state; converted to an unsigned 32-bit integer.
 * @returns {() => number} Function that yields the next pseudo-random value.
 */
function createDeterministicRandom(seed) {
  const MULTIPLIER = 1664525;
  const INCREMENT = 1013904223;
  let state = seed >>> 0;

  return function next() {
    state = (MULTIPLIER * state + INCREMENT) >>> 0;
    return state / 0xffffffff;
  };
}
|
||||
1837
retrieval/retriever.js
Normal file
1837
retrieval/retriever.js
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user