diff --git a/README.md b/README.md index aa03155..79ffee9 100644 --- a/README.md +++ b/README.md @@ -289,6 +289,7 @@ ST-BME/ ├── retriever.js # 向量候选、图扩散、混合评分、召回 ├── injector.js # 召回结果格式化注入 ├── runtime-state.js # 运行时状态:楼层 hash、dirty 标记、恢复日志 +├── recall-persistence.js # 持久召回记录(message.extra.bme_recall) ├── vector-index.js # 向量索引管理(backend / direct 双模式) ├── embedding.js # 直连 Embedding API 封装 ├── llm.js # 记忆 LLM 请求封装 @@ -311,6 +312,7 @@ ST-BME/ - **图谱数据** → `chat_metadata.st_bme_graph`(跟随聊天保存) - **插件设置** → SillyTavern 的 `extension_settings.st_bme` - **向量索引** → 后端模式走酒馆 API;直连模式存在节点内 +- **召回持久注入** → `chat[x].extra.bme_recall`(消息级) ### 事件挂载 @@ -337,6 +339,44 @@ ST-BME/ 每层内进一步按用途分桶:当前状态 / 情景事件 / 反思锚点 / 规则约束。 +### 持久召回注入(`message.extra.bme_recall`) + +从本版本开始,召回注入支持消息级持久化,存放在对应用户楼层: + +- 路径:`chat[x].extra.bme_recall` +- 主要字段: + - `version` + - `injectionText` + - `selectedNodeIds` + - `recallInput` + - `recallSource` + - `hookName` + - `tokenEstimate` + - `createdAt` / `updatedAt` + - `generationCount`(**仅**在该持久注入被实际用作生成回退时递增) + - `manuallyEdited`(仅表示来源是否为人工编辑) + +注入优先级(避免双重注入): + +1. **本轮有新召回成功**:仅使用新召回注入(临时注入),并覆盖写入目标用户楼层的 `bme_recall`。 +2. **本轮无新召回结果**:仅从“当前生成对应的用户楼层”读取 `bme_recall` 作为回退注入。 +3. **两者都无**:清空注入。 + +> `manuallyEdited` 不参与优先级判断,不会强制覆盖系统召回。 + +消息级 UI: + +- 带有 `bme_recall` 的用户气泡会显示 🧠 badge。 +- 点击 badge 可进行:查看详情 / 手动编辑 / 删除 / 重新召回。 +- 手动编辑后会将 `manuallyEdited=true`。 +- 重新召回成功后会覆盖记录并重置 `manuallyEdited=false`。 +- 删除会移除该楼层的持久召回记录。 + +兼容性说明: + +- 旧聊天(无 `extra` 或无 `bme_recall`)会自动按“无持久记录”处理,不会报错。 +- badge 依赖酒馆消息 DOM 的楼层索引属性;若第三方主题重写消息结构,可能需要额外适配。 + --- ## ⚠️ 已知限制 diff --git a/event-binding.js b/event-binding.js index 8e95f00..197b3ff 100644 --- a/event-binding.js +++ b/event-binding.js @@ -127,12 +127,14 @@ export function onChatChangedController(runtime) { runtime.clearInjectionState(); runtime.clearRecallInputTracking(); runtime.installSendIntentHooks(); + runtime.refreshPersistedRecallMessageUi?.(); } export function onChatLoadedController(runtime) { runtime.syncGraphLoadFromLiveContext({ source: "chat-loaded", }); + runtime.refreshPersistedRecallMessageUi?.(); } export function onMessageSentController(runtime, messageId) { @@ -143,6 +145,7 @@ export function onMessageSentController(runtime, messageId) { if (!message?.is_user) return; runtime.recordRecallSentUserMessage(messageId, message.mes || ""); + runtime.refreshPersistedRecallMessageUi?.(); } export function onMessageDeletedController( @@ -156,16 +159,19 @@ export function onMessageDeletedController( chatLengthOrMessageId, meta, ); + runtime.refreshPersistedRecallMessageUi?.(); } export function onMessageEditedController(runtime, messageId, meta = null) { runtime.invalidateRecallAfterHistoryMutation("消息已编辑"); runtime.scheduleHistoryMutationRecheck("message-edited", messageId, meta); + runtime.refreshPersistedRecallMessageUi?.(); } export function onMessageSwipedController(runtime, messageId, meta = null) { runtime.invalidateRecallAfterHistoryMutation("已切换楼层 swipe"); runtime.scheduleHistoryMutationRecheck("message-swiped", messageId, meta); + runtime.refreshPersistedRecallMessageUi?.(); } export async function onGenerationAfterCommandsController( @@ -183,7 +189,7 @@ export async function onGenerationAfterCommandsController( params, chat, ); - if (!recallOptions?.overrideUserMessage) return; + if (!recallOptions) return; const recallContext = runtime.createGenerationRecallContext({ hookName: "GENERATION_AFTER_COMMANDS", @@ -211,6 +217,11 @@ export async function onGenerationAfterCommandsController( recallContext.hookName, runtime.getGenerationRecallHookStateFromResult(recallResult), ); + + runtime.applyFinalRecallInjectionForGeneration({ + generationType: recallContext.generationType, + freshRecallResult: recallResult, + }); } export async function onBeforeCombinePromptsController(runtime) { @@ -244,6 +255,11 @@ export async function onBeforeCombinePromptsController(runtime) { recallContext.hookName, runtime.getGenerationRecallHookStateFromResult(recallResult), ); + + runtime.applyFinalRecallInjectionForGeneration({ + generationType: recallContext.generationType, + freshRecallResult: recallResult, + }); } export function onMessageReceivedController(runtime) { @@ -280,4 +296,5 @@ export function onMessageReceivedController(runtime) { }); }); } + runtime.refreshPersistedRecallMessageUi?.(); } diff --git a/index.js b/index.js index 05a76e7..3d1f405 100644 --- a/index.js +++ b/index.js @@ -186,6 +186,16 @@ import { resolveDirtyFloorFromMutationMeta, rollbackAffectedJournals, } from "./chat-history.js"; +import { + buildPersistedRecallRecord, + bumpPersistedRecallGenerationCount, + markPersistedRecallManualEdit, + readPersistedRecallFromUserMessage, + removePersistedRecallFromUserMessage, + resolveFinalRecallInjectionSource, + resolveGenerationTargetUserMessageIndex, + writePersistedRecallToUserMessage, +} from "./recall-persistence.js"; // 操控面板模块(动态加载,防止加载失败崩溃整个扩展) let _panelModule = null; @@ -447,6 +457,7 @@ let skipBeforeCombineRecallUntil = 0; let lastPreGenerationRecallKey = ""; let lastPreGenerationRecallAt = 0; const generationRecallTransactions = new Map(); +let persistedRecallUiRefreshTimer = null; const GENERATION_RECALL_TRANSACTION_TTL_MS = 15000; const stageNoticeHandles = { extraction: null, @@ -943,6 +954,349 @@ function recordRecallSentUserMessage(messageId, text, source = "message-sent") { } } +function getMessageRecallRecord(messageIndex) { + const chat = getContext()?.chat; + return readPersistedRecallFromUserMessage(chat, messageIndex); +} + +function persistRecallInjectionRecord({ + recallInput = {}, + result = {}, + injectionText = "", + tokenEstimate = 0, +} = {}) { + const chat = getContext()?.chat; + if (!Array.isArray(chat)) return null; + + const generationType = String(recallInput?.generationType || "normal").trim() || "normal"; + let resolvedTargetIndex = Number.isFinite(recallInput?.targetUserMessageIndex) + ? recallInput.targetUserMessageIndex + : resolveGenerationTargetUserMessageIndex(chat, { generationType }); + + if ( + !Number.isFinite(resolvedTargetIndex) && + Number.isFinite(lastRecallSentUserMessage?.messageId) && + chat[lastRecallSentUserMessage.messageId]?.is_user + ) { + resolvedTargetIndex = lastRecallSentUserMessage.messageId; + } + + if (!Number.isFinite(resolvedTargetIndex)) return null; + + const record = buildPersistedRecallRecord( + { + injectionText, + selectedNodeIds: result?.selectedNodeIds || [], + recallInput: String(recallInput?.userMessage || ""), + recallSource: String(recallInput?.source || ""), + hookName: String(recallInput?.hookName || ""), + tokenEstimate, + manuallyEdited: false, + }, + readPersistedRecallFromUserMessage(chat, resolvedTargetIndex), + ); + if (!writePersistedRecallToUserMessage(chat, resolvedTargetIndex, record)) { + return null; + } + + triggerChatMetadataSave(getContext(), { immediate: false }); + return { + index: resolvedTargetIndex, + record, + }; +} + +function removeMessageRecallRecord(messageIndex) { + const chat = getContext()?.chat; + if (!Array.isArray(chat)) return false; + const removed = removePersistedRecallFromUserMessage(chat, messageIndex); + if (removed) { + triggerChatMetadataSave(getContext(), { immediate: false }); + } + return removed; +} + +function editMessageRecallRecord(messageIndex, nextInjectionText) { + const chat = getContext()?.chat; + if (!Array.isArray(chat)) return null; + const current = readPersistedRecallFromUserMessage(chat, messageIndex); + if (!current) return null; + + const normalizedText = normalizeRecallInputText(nextInjectionText); + if (!normalizedText) return null; + const nowIso = new Date().toISOString(); + const nextRecord = { + ...current, + injectionText: normalizedText, + tokenEstimate: estimateTokens(normalizedText), + updatedAt: nowIso, + }; + if (!writePersistedRecallToUserMessage(chat, messageIndex, nextRecord)) { + return null; + } + const edited = markPersistedRecallManualEdit(chat, messageIndex, true, nowIso); + if (!edited) return null; + + triggerChatMetadataSave(getContext(), { immediate: false }); + return edited; +} + +function applyFinalRecallInjectionForGeneration({ + generationType = "normal", + freshRecallResult = null, +} = {}) { + const chat = getContext()?.chat; + if (!Array.isArray(chat)) { + applyModuleInjectionPrompt("", getSettings()); + return { source: "none", targetUserMessageIndex: null, usedText: "" }; + } + + let targetUserMessageIndex = resolveGenerationTargetUserMessageIndex(chat, { + generationType, + }); + if ( + !Number.isFinite(targetUserMessageIndex) && + Number.isFinite(lastRecallSentUserMessage?.messageId) && + chat[lastRecallSentUserMessage.messageId]?.is_user + ) { + targetUserMessageIndex = lastRecallSentUserMessage.messageId; + } + + const persistedRecord = Number.isFinite(targetUserMessageIndex) + ? readPersistedRecallFromUserMessage(chat, targetUserMessageIndex) + : null; + const resolved = resolveFinalRecallInjectionSource({ + freshRecallResult, + persistedRecord, + }); + + if (resolved.source === "persisted") { + applyModuleInjectionPrompt(resolved.injectionText || "", getSettings()); + } else if (resolved.source === "none") { + applyModuleInjectionPrompt("", getSettings()); + } + + if (resolved.source === "persisted" && Number.isFinite(targetUserMessageIndex)) { + bumpPersistedRecallGenerationCount(chat, targetUserMessageIndex); + triggerChatMetadataSave(getContext(), { immediate: false }); + } + + if (resolved.source === "fresh") { + runtimeStatus = createUiStatus( + "召回已注入", + "本轮已使用最新召回结果", + "success", + ); + } else if (resolved.source === "persisted") { + lastInjectionContent = resolved.injectionText || ""; + runtimeStatus = createUiStatus("召回回退", "已使用消息楼层持久化注入", "info"); + } else { + lastInjectionContent = ""; + runtimeStatus = createUiStatus("待命", "当前无有效注入内容", "idle"); + } + refreshPanelLiveState(); + schedulePersistedRecallMessageUiRefresh(); + + return { + source: resolved.source, + isFallback: resolved.source === "persisted", + targetUserMessageIndex, + usedText: resolved.injectionText || "", + }; +} + +function resolveMessageIndexFromElement(messageElement, fallbackIndex = null) { + if (!messageElement) return Number.isFinite(fallbackIndex) ? fallbackIndex : null; + + const candidates = [ + messageElement.getAttribute?.("mesid"), + messageElement.getAttribute?.("data-mesid"), + messageElement.getAttribute?.("data-message-id"), + messageElement.dataset?.mesid, + messageElement.dataset?.messageId, + ]; + for (const candidate of candidates) { + const parsed = Number.parseInt(candidate, 10); + if (Number.isFinite(parsed)) return parsed; + } + + return Number.isFinite(fallbackIndex) ? fallbackIndex : null; +} + +function buildMessageRecallBadgeTitle(messageIndex, record) { + const lines = [`ST-BME 持久召回 · 楼层 ${messageIndex}`]; + if (record?.manuallyEdited) { + lines.push("来源:手动编辑"); + } + if (Number.isFinite(record?.generationCount)) { + lines.push(`已回退使用:${record.generationCount} 次`); + } + if (record?.updatedAt) { + lines.push(`更新:${record.updatedAt}`); + } + return lines.join("\n"); +} + +function refreshPersistedRecallMessageUi() { + const context = getContext(); + const chat = context?.chat; + if (!Array.isArray(chat) || typeof document?.getElementById !== "function") return; + + const chatRoot = document.getElementById("chat"); + if (!chatRoot) return; + + chatRoot + .querySelectorAll?.(".st-bme-recall-badge") + ?.forEach?.((badge) => badge.remove()); + + const messageElements = Array.from(chatRoot.querySelectorAll(".mes")); + for (let fallbackIndex = 0; fallbackIndex < messageElements.length; fallbackIndex++) { + const messageElement = messageElements[fallbackIndex]; + const messageIndex = resolveMessageIndexFromElement(messageElement, fallbackIndex); + if (!Number.isFinite(messageIndex)) continue; + + const message = chat[messageIndex]; + if (!message?.is_user) continue; + + const record = readPersistedRecallFromUserMessage(chat, messageIndex); + if (!record?.injectionText) continue; + + const badge = document.createElement("button"); + badge.type = "button"; + badge.className = "st-bme-recall-badge"; + badge.textContent = "🧠"; + badge.title = buildMessageRecallBadgeTitle(messageIndex, record); + badge.dataset.messageIndex = String(messageIndex); + badge.style.marginInlineStart = "6px"; + badge.style.padding = "0 4px"; + badge.style.borderRadius = "10px"; + badge.style.border = "1px solid var(--SmartThemeBorderColor, #666)"; + badge.style.background = "var(--SmartThemeQuoteColor, rgba(120, 120, 120, 0.18))"; + badge.style.cursor = "pointer"; + badge.style.fontSize = "12px"; + badge.style.lineHeight = "1.4"; + + badge.addEventListener("click", (event) => { + event.preventDefault(); + event.stopPropagation(); + void onMessageRecallBadgeClick(messageIndex); + }); + + const anchor = + messageElement.querySelector(".mes_buttons") || + messageElement.querySelector(".mes_title") || + messageElement.querySelector(".mes_header") || + messageElement; + anchor.appendChild(badge); + } +} + +function schedulePersistedRecallMessageUiRefresh(delayMs = 16) { + clearTimeout(persistedRecallUiRefreshTimer); + persistedRecallUiRefreshTimer = setTimeout(() => { + persistedRecallUiRefreshTimer = null; + refreshPersistedRecallMessageUi(); + }, Math.max(0, Number.parseInt(delayMs, 10) || 0)); +} + +function showMessageRecallDetail(messageIndex, record) { + const details = [ + `楼层: ${messageIndex}`, + `来源: ${record.recallSource || "unknown"}`, + `Hook: ${record.hookName || "-"}`, + `tokenEstimate: ${record.tokenEstimate || 0}`, + `generationCount: ${record.generationCount || 0}`, + `manuallyEdited: ${record.manuallyEdited ? "true" : "false"}`, + `updatedAt: ${record.updatedAt || "-"}`, + "", + "注入内容:", + record.injectionText || "(empty)", + ].join("\n"); + globalThis.alert?.(details); +} + +async function rerunRecallForMessage(messageIndex) { + const chat = getContext()?.chat; + const message = Array.isArray(chat) ? chat[messageIndex] : null; + if (!message?.is_user) { + toastr.info("仅用户消息支持重新召回"); + return null; + } + + const userMessage = normalizeRecallInputText(message.mes || ""); + if (!userMessage) { + toastr.info("该楼层内容为空,无法重新召回"); + return null; + } + + const result = await runRecall({ + overrideUserMessage: userMessage, + overrideSource: "message-floor-rerecall", + overrideSourceLabel: `用户楼层 ${messageIndex}`, + generationType: "history", + targetUserMessageIndex: messageIndex, + includeSyntheticUserMessage: false, + hookName: "MESSAGE_RECALL_BADGE_RERUN", + }); + applyFinalRecallInjectionForGeneration({ + generationType: "history", + freshRecallResult: result, + }); + return result; +} + +async function onMessageRecallBadgeClick(messageIndex) { + const record = getMessageRecallRecord(messageIndex); + if (!record) { + toastr.info("该楼层暂无持久召回记录"); + schedulePersistedRecallMessageUiRefresh(); + return; + } + + const choiceRaw = globalThis.prompt?.( + [ + `ST-BME 持久召回(楼层 ${messageIndex})`, + "1 查看详情", + "2 手动编辑", + "3 删除", + "4 重新召回", + "请输入序号:", + ].join("\n"), + "1", + ); + const choice = String(choiceRaw || "").trim().toLowerCase(); + if (!choice) return; + + if (choice === "1" || choice === "view" || choice === "detail") { + showMessageRecallDetail(messageIndex, record); + } else if (choice === "2" || choice === "edit") { + const nextText = globalThis.prompt?.( + `编辑楼层 ${messageIndex} 的持久召回注入文本:`, + record.injectionText || "", + ); + if (nextText !== null && nextText !== undefined) { + const edited = editMessageRecallRecord(messageIndex, nextText); + if (edited) { + toastr.success("已保存手动编辑并标记 manuallyEdited=true"); + } else { + toastr.warning("编辑失败:注入文本不能为空"); + } + } + } else if (choice === "3" || choice === "delete") { + const confirmed = globalThis.confirm?.(`确认删除楼层 ${messageIndex} 的持久召回注入?`); + if (confirmed && removeMessageRecallRecord(messageIndex)) { + toastr.success("已删除持久召回注入"); + } + } else if (choice === "4" || choice === "reroll" || choice === "recall") { + const rerunResult = await rerunRecallForMessage(messageIndex); + if (rerunResult?.status === "completed") { + toastr.success("重新召回完成,已覆盖持久召回记录"); + } + } + + schedulePersistedRecallMessageUiRefresh(); +} + function getSendTextareaValue() { return String(document.getElementById("send_textarea")?.value ?? ""); } @@ -3100,9 +3454,32 @@ function buildGenerationAfterCommandsRecallInput(type, params = {}, chat) { return null; } - return generationType === "normal" - ? buildNormalGenerationRecallInput(chat) - : buildHistoryGenerationRecallInput(chat); + const targetUserMessageIndex = resolveGenerationTargetUserMessageIndex(chat, { + generationType, + }); + if (!Number.isFinite(targetUserMessageIndex)) { + return { + generationType, + targetUserMessageIndex: null, + }; + } + + if (generationType !== "normal") { + const historyInput = buildHistoryGenerationRecallInput(chat); + if (!historyInput) { + return { + generationType, + targetUserMessageIndex, + }; + } + return { + ...historyInput, + generationType, + targetUserMessageIndex, + }; + } + + return buildNormalGenerationRecallInput(chat); } function buildNormalGenerationRecallInput(chat) { @@ -3110,6 +3487,7 @@ function buildNormalGenerationRecallInput(chat) { const tailUserText = lastNonSystemMessage?.is_user ? normalizeRecallInputText(lastNonSystemMessage?.mes || "") : ""; + const targetUserMessageIndex = resolveGenerationTargetUserMessageIndex(chat, { generationType: "normal" }); const textareaText = normalizeRecallInputText( pendingRecallSendIntent.text || getSendTextareaValue(), ); @@ -3118,6 +3496,8 @@ function buildNormalGenerationRecallInput(chat) { return { overrideUserMessage: userMessage, + generationType: "normal", + targetUserMessageIndex, overrideSource: tailUserText ? "chat-tail-user" : "send-intent", overrideSourceLabel: tailUserText ? "当前用户楼层" : "发送意图", includeSyntheticUserMessage: !tailUserText, @@ -3129,20 +3509,33 @@ function buildHistoryGenerationRecallInput(chat) { getLatestUserChatMessage(chat)?.mes || lastRecallSentUserMessage.text, ); if (!latestUserText) return null; + const targetUserMessageIndex = resolveGenerationTargetUserMessageIndex(chat, { + generationType: "history", + }); return { overrideUserMessage: latestUserText, - overrideSource: "chat-last-user", - overrideSourceLabel: "历史最后用户楼层", + generationType: "history", + targetUserMessageIndex, + overrideSource: Number.isFinite(targetUserMessageIndex) + ? "chat-last-user" + : "chat-last-user-missing", + overrideSourceLabel: Number.isFinite(targetUserMessageIndex) ? "历史最后用户楼层" : "历史用户楼层缺失", includeSyntheticUserMessage: false, }; } function buildPreGenerationRecallKey(type, options = {}) { + const targetUserMessageIndex = Number.isFinite(options.targetUserMessageIndex) + ? options.targetUserMessageIndex + : "none"; + const seedText = + options.overrideUserMessage || options.userMessage || `@target:${targetUserMessageIndex}`; + return [ getCurrentChatId(), String(type || "normal").trim() || "normal", - hashRecallInput(options.overrideUserMessage || ""), + hashRecallInput(seedText), ].join(":"); } @@ -4230,6 +4623,7 @@ function applyRecallInjection(settings, recallInput, recentMessages, result) { recentMessages, result, { + persistRecallInjectionRecord, applyModuleInjectionPrompt, console, estimateTokens, @@ -4361,6 +4755,7 @@ function onChatChanged() { dismissAllStageNotices, getPendingHistoryRecoveryTimer: () => pendingHistoryRecoveryTimer, installSendIntentHooks, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, setLastPreGenerationRecallAt: (value) => { lastPreGenerationRecallAt = value; }, @@ -4382,6 +4777,7 @@ function onChatChanged() { function onChatLoaded() { return onChatLoadedController({ + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, syncGraphLoadFromLiveContext, }); } @@ -4391,6 +4787,7 @@ function onMessageSent(messageId) { { getContext, recordRecallSentUserMessage, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, }, messageId, ); @@ -4400,6 +4797,7 @@ function onMessageDeleted(chatLengthOrMessageId, meta = null) { return onMessageDeletedController( { invalidateRecallAfterHistoryMutation, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, scheduleHistoryMutationRecheck, }, chatLengthOrMessageId, @@ -4411,6 +4809,7 @@ function onMessageEdited(messageId, meta = null) { return onMessageEditedController( { invalidateRecallAfterHistoryMutation, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, scheduleHistoryMutationRecheck, }, messageId, @@ -4422,6 +4821,7 @@ function onMessageSwiped(messageId, meta = null) { return onMessageSwipedController( { invalidateRecallAfterHistoryMutation, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, scheduleHistoryMutationRecheck, }, messageId, @@ -4432,6 +4832,7 @@ function onMessageSwiped(messageId, meta = null) { async function onGenerationAfterCommands(type, params = {}, dryRun = false) { return await onGenerationAfterCommandsController( { + applyFinalRecallInjectionForGeneration, buildGenerationAfterCommandsRecallInput, createGenerationRecallContext, getContext, @@ -4447,6 +4848,7 @@ async function onGenerationAfterCommands(type, params = {}, dryRun = false) { async function onBeforeCombinePrompts() { return await onBeforeCombinePromptsController({ + applyFinalRecallInjectionForGeneration, buildHistoryGenerationRecallInput, buildNormalGenerationRecallInput, createGenerationRecallContext, @@ -4473,6 +4875,7 @@ function onMessageReceived() { notifyExtractionIssue, queueMicrotask, runExtraction, + refreshPersistedRecallMessageUi: schedulePersistedRecallMessageUiRefresh, setPendingRecallSendIntent: (record) => { pendingRecallSendIntent = record; }, @@ -4829,5 +5232,6 @@ async function onReembedDirect() { updateSettings: updateModuleSettings, }); + schedulePersistedRecallMessageUiRefresh(120); console.log("[ST-BME] 初始化完成"); })(); diff --git a/recall-controller.js b/recall-controller.js index 5775b39..b969840 100644 --- a/recall-controller.js +++ b/recall-controller.js @@ -57,13 +57,17 @@ export function resolveRecallInputController( runtime, ) { const overrideText = runtime.normalizeRecallInputText( - override?.userMessage || "", + override?.userMessage || override?.overrideUserMessage || "", ); if (overrideText) { return { userMessage: overrideText, - source: String(override?.source || "override"), - sourceLabel: String(override?.sourceLabel || "发送前拦截"), + generationType: String(override?.generationType || "normal"), + targetUserMessageIndex: Number.isFinite(override?.targetUserMessageIndex) ? override.targetUserMessageIndex : null, + source: String(override?.source || override?.overrideSource || "override"), + sourceLabel: String( + override?.sourceLabel || override?.overrideSourceLabel || "发送前拦截", + ), recentMessages: runtime.buildRecallRecentMessages( chat, recentContextMessageLimit, @@ -115,6 +119,8 @@ export function resolveRecallInputController( return { userMessage, + generationType: "normal", + targetUserMessageIndex: null, source, sourceLabel: runtime.getRecallUserMessageSourceLabel(source), recentMessages: runtime.buildRecallRecentMessages( @@ -149,6 +155,7 @@ export function applyRecallInjectionController( runtime.console.log( `[ST-BME] 注入 ${tokens} 估算 tokens, Core=${result.stats.coreCount}, Recall=${result.stats.recallCount}`, ); + runtime.persistRecallInjectionRecord?.({ recallInput, result, injectionText, tokenEstimate: tokens }); } const injectionTransport = runtime.applyModuleInjectionPrompt( @@ -372,10 +379,16 @@ export async function runRecallController(runtime, options = {}) { options: runtime.buildRecallRetrieveOptions(settings, context), }); - runtime.applyRecallInjection(settings, recallInput, recentMessages, result); + const applied = runtime.applyRecallInjection( + settings, + recallInput, + recentMessages, + result, + ); return runtime.createRecallRunResult("completed", { reason: "召回完成", selectedNodeIds: result.selectedNodeIds || [], + injectionText: applied?.injectionText || "", }); } catch (e) { if (runtime.isAbortError(e)) { diff --git a/recall-persistence.js b/recall-persistence.js new file mode 100644 index 0000000..1eaf898 --- /dev/null +++ b/recall-persistence.js @@ -0,0 +1,184 @@ +// ST-BME: 持久化召回记录纯函数 + +export const BME_RECALL_EXTRA_KEY = "bme_recall"; +export const BME_RECALL_VERSION = 1; + +function toIsoString(value) { + if (typeof value === "string" && value.trim()) return value; + return new Date().toISOString(); +} + +function cloneStringArray(value) { + return Array.isArray(value) + ? value + .map((item) => String(item || "").trim()) + .filter(Boolean) + : []; +} + +function cloneRecord(value) { + if (!value || typeof value !== "object" || Array.isArray(value)) return null; + return { ...value }; +} + +export function readPersistedRecallFromUserMessage(chat, userMessageIndex) { + if (!Array.isArray(chat) || !Number.isFinite(userMessageIndex)) return null; + const message = chat[userMessageIndex]; + const raw = message?.extra?.[BME_RECALL_EXTRA_KEY]; + const record = cloneRecord(raw); + if (!record) return null; + + const injectionText = String(record.injectionText || "").trim(); + if (!injectionText) return null; + + return { + version: Number.isFinite(Number(record.version)) + ? Number(record.version) + : BME_RECALL_VERSION, + injectionText, + selectedNodeIds: cloneStringArray(record.selectedNodeIds), + recallInput: String(record.recallInput || ""), + recallSource: String(record.recallSource || ""), + hookName: String(record.hookName || ""), + tokenEstimate: Number.isFinite(Number(record.tokenEstimate)) + ? Number(record.tokenEstimate) + : 0, + createdAt: toIsoString(record.createdAt), + updatedAt: toIsoString(record.updatedAt), + generationCount: Math.max(0, Number.parseInt(record.generationCount, 10) || 0), + manuallyEdited: Boolean(record.manuallyEdited), + }; +} + +export function buildPersistedRecallRecord(payload = {}, existingRecord = null) { + const nowIso = toIsoString(payload.nowIso); + const previous = cloneRecord(existingRecord) || {}; + const injectionText = String(payload.injectionText || "").trim(); + + return { + version: BME_RECALL_VERSION, + injectionText, + selectedNodeIds: cloneStringArray(payload.selectedNodeIds), + recallInput: String(payload.recallInput || ""), + recallSource: String(payload.recallSource || ""), + hookName: String(payload.hookName || ""), + tokenEstimate: Number.isFinite(Number(payload.tokenEstimate)) + ? Number(payload.tokenEstimate) + : 0, + createdAt: toIsoString(previous.createdAt || nowIso), + updatedAt: nowIso, + generationCount: 0, + manuallyEdited: Boolean(payload.manuallyEdited), + }; +} + +export function writePersistedRecallToUserMessage(chat, userMessageIndex, record) { + if (!Array.isArray(chat) || !Number.isFinite(userMessageIndex)) return false; + const message = chat[userMessageIndex]; + if (!message || !message.is_user) return false; + + const normalized = cloneRecord(record); + if (!normalized || !String(normalized.injectionText || "").trim()) return false; + + message.extra ||= {}; + message.extra[BME_RECALL_EXTRA_KEY] = normalized; + return true; +} + +export function removePersistedRecallFromUserMessage(chat, userMessageIndex) { + if (!Array.isArray(chat) || !Number.isFinite(userMessageIndex)) return false; + const message = chat[userMessageIndex]; + if (!message?.extra || typeof message.extra !== "object") return false; + if (!(BME_RECALL_EXTRA_KEY in message.extra)) return false; + delete message.extra[BME_RECALL_EXTRA_KEY]; + return true; +} + +export function markPersistedRecallManualEdit( + chat, + userMessageIndex, + manuallyEdited = true, + nowIso = new Date().toISOString(), +) { + const current = readPersistedRecallFromUserMessage(chat, userMessageIndex); + if (!current) return null; + const nextRecord = { + ...current, + manuallyEdited: Boolean(manuallyEdited), + updatedAt: toIsoString(nowIso), + }; + if (!writePersistedRecallToUserMessage(chat, userMessageIndex, nextRecord)) { + return null; + } + return nextRecord; +} + +export function bumpPersistedRecallGenerationCount(chat, userMessageIndex) { + const current = readPersistedRecallFromUserMessage(chat, userMessageIndex); + if (!current) return null; + const nextRecord = { + ...current, + generationCount: Math.max(0, Number(current.generationCount || 0)) + 1, + }; + if (!writePersistedRecallToUserMessage(chat, userMessageIndex, nextRecord)) { + return null; + } + return nextRecord; +} + +export function resolveGenerationTargetUserMessageIndex( + chat, + { generationType = "normal" } = {}, +) { + if (!Array.isArray(chat) || chat.length === 0) return null; + + const normalizedType = String(generationType || "normal").trim() || "normal"; + + if (normalizedType === "normal") { + for (let index = chat.length - 1; index >= 0; index--) { + const message = chat[index]; + if (message?.is_system) continue; + return message?.is_user ? index : null; + } + return null; + } + + for (let index = chat.length - 1; index >= 0; index--) { + if (chat[index]?.is_user) return index; + } + + return null; +} + +export function resolveFinalRecallInjectionSource({ + freshRecallResult = null, + persistedRecord = null, +} = {}) { + const freshInjection = String(freshRecallResult?.injectionText || "").trim(); + if ( + freshRecallResult?.status === "completed" && + freshRecallResult?.didRecall && + freshInjection + ) { + return { + source: "fresh", + injectionText: freshInjection, + record: null, + }; + } + + const persistedInjection = String(persistedRecord?.injectionText || "").trim(); + if (persistedInjection) { + return { + source: "persisted", + injectionText: persistedInjection, + record: persistedRecord, + }; + } + + return { + source: "none", + injectionText: "", + record: null, + }; +} diff --git a/tests/p0-regressions.mjs b/tests/p0-regressions.mjs index 2add352..122844d 100644 --- a/tests/p0-regressions.mjs +++ b/tests/p0-regressions.mjs @@ -58,6 +58,16 @@ import { onGenerationAfterCommandsController, } from "../event-binding.js"; import { onRerollController } from "../extraction-controller.js"; +import { + buildPersistedRecallRecord, + readPersistedRecallFromUserMessage, + removePersistedRecallFromUserMessage, + resolveFinalRecallInjectionSource, + resolveGenerationTargetUserMessageIndex, + writePersistedRecallToUserMessage, + bumpPersistedRecallGenerationCount, + markPersistedRecallManualEdit, +} from "../recall-persistence.js"; const extensionsShimSource = [ "export const extension_settings = globalThis.__p0ExtensionSettings || {};", @@ -290,6 +300,7 @@ function createGenerationRecallHarness() { }), chat: [], runRecallCalls: [], + applyFinalCalls: [], createRecallInputRecord, createRecallRunResult, hashRecallInput, @@ -310,6 +321,24 @@ function createGenerationRecallHarness() { GRAPH_PERSISTENCE_META_KEY, onBeforeCombinePromptsController, onGenerationAfterCommandsController, + readPersistedRecallFromUserMessage: () => null, + resolveFinalRecallInjectionSource: ({ freshRecallResult = null } = {}) => ({ + source: freshRecallResult?.didRecall ? "fresh" : "none", + injectionText: String(freshRecallResult?.injectionText || ""), + record: null, + }), + bumpPersistedRecallGenerationCount: () => null, + applyModuleInjectionPrompt: () => ({}), + getSettings: () => ({}), + triggerChatMetadataSave: () => "debounced", + refreshPanelLiveState: () => {}, + resolveGenerationTargetUserMessageIndex: (chat = [], { generationType } = {}) => { + const normalized = String(generationType || "normal"); + if (!Array.isArray(chat) || chat.length === 0) return null; + if (normalized === "normal") return chat[chat.length - 1]?.is_user ? chat.length - 1 : null; + for (let index = chat.length - 1; index >= 0; index--) if (chat[index]?.is_user) return index; + return null; + }, }; vm.createContext(context); vm.runInContext( @@ -317,6 +346,13 @@ function createGenerationRecallHarness() { context, { filename: indexPath }, ); + context.applyFinalRecallInjectionForGeneration = (payload = {}) => { + context.applyFinalCalls.push({ ...payload }); + return { + source: "fresh", + targetUserMessageIndex: null, + }; + }; context.runRecall = async (options = {}) => { context.runRecallCalls.push({ ...options }); return { status: "completed", didRecall: true, ok: true }; @@ -325,6 +361,138 @@ function createGenerationRecallHarness() { }); } +function createMessageRecallUiHarness() { + return fs.readFile(indexPath, "utf8").then((source) => { + const start = source.indexOf("function getMessageRecallRecord(messageIndex) {"); + const end = source.indexOf("function getSendTextareaValue() {"); + if (start < 0 || end < 0 || end <= start) { + throw new Error("无法从 index.js 提取消息级召回 UI 定义"); + } + + const snippet = source.slice(start, end).replace(/^export\s+/gm, ""); + const chat = [ + { + is_user: true, + mes: "u0", + extra: { + bme_recall: { + version: 1, + injectionText: "persisted-memory", + selectedNodeIds: ["n1"], + recallInput: "u0", + recallSource: "chat-last-user", + hookName: "GENERATION_AFTER_COMMANDS", + tokenEstimate: 16, + createdAt: "2026-01-01T00:00:00.000Z", + updatedAt: "2026-01-01T00:00:00.000Z", + generationCount: 0, + manuallyEdited: false, + }, + }, + }, + ]; + + const badgeHost = { + children: [], + appendChild(child) { + child.parent = this; + this.children.push(child); + }, + }; + const messageEl = { + dataset: {}, + getAttribute(name) { + if (name === "mesid" || name === "data-mesid") return "0"; + return null; + }, + querySelector(selector) { + if (selector === ".mes_buttons") return badgeHost; + return null; + }, + }; + const chatRoot = { + querySelectorAll(selector) { + if (selector === ".mes") return [messageEl]; + if (selector === ".st-bme-recall-badge") { + return badgeHost.children.filter( + (child) => child.className === "st-bme-recall-badge", + ); + } + return []; + }, + }; + + const context = { + console, + Date, + clearTimeout, + setTimeout, + result: null, + persistedRecallUiRefreshTimer: null, + lastInjectionContent: "", + lastRecallSentUserMessage: createRecallInputRecord(), + runtimeStatus: createUiStatus("待命", "", "idle"), + getContext: () => ({ chat }), + document: { + getElementById(id) { + return id === "chat" ? chatRoot : null; + }, + createElement() { + return { + className: "", + textContent: "", + dataset: {}, + style: {}, + listeners: {}, + parent: null, + addEventListener(type, handler) { + this.listeners[type] = handler; + }, + remove() { + if (!this.parent?.children) return; + this.parent.children = this.parent.children.filter((item) => item !== this); + }, + }; + }, + }, + readPersistedRecallFromUserMessage, + writePersistedRecallToUserMessage, + removePersistedRecallFromUserMessage, + markPersistedRecallManualEdit, + bumpPersistedRecallGenerationCount, + buildPersistedRecallRecord, + resolveGenerationTargetUserMessageIndex, + resolveFinalRecallInjectionSource, + normalizeRecallInputText, + estimateTokens: (text = "") => String(text || "").length, + triggerChatMetadataSave: () => "debounced", + getSettings: () => ({}), + applyModuleInjectionPrompt: () => ({}), + createUiStatus, + refreshPanelLiveState: () => {}, + runRecall: async () => ({ status: "completed", didRecall: true, injectionText: "fresh" }), + applyFinalRecallInjectionForGeneration: () => ({ source: "fresh" }), + toastr: { info() {}, success() {}, warning() {}, error() {} }, + promptResponses: [], + prompt(defaultText = "") { + return context.promptResponses.length > 0 ? context.promptResponses.shift() : defaultText; + }, + confirm: () => true, + alertMessages: [], + alert(message) { + context.alertMessages.push(String(message || "")); + }, + }; + vm.createContext(context); + vm.runInContext( + `${snippet}\nresult = { refreshPersistedRecallMessageUi, onMessageRecallBadgeClick };`, + context, + { filename: indexPath }, + ); + return { context, chat, badgeHost }; + }); +} + function createRerollHarness() { return fs.readFile(indexPath, "utf8").then((source) => { const rollbackStart = source.indexOf("async function rollbackGraphForReroll("); @@ -1419,6 +1587,130 @@ async function testGenerationRecallSkippedStateDoesNotLoopToBeforeCombine() { assert.equal(transaction.hookStates.GENERATION_AFTER_COMMANDS, "skipped"); } +async function testGenerationRecallAppliesFinalInjectionOncePerTransaction() { + const harness = await createGenerationRecallHarness(); + harness.chat = [{ is_user: true, mes: "同一轮仅一次最终注入" }]; + + await harness.result.onGenerationAfterCommands("normal", {}, false); + await harness.result.onBeforeCombinePrompts(); + + assert.equal(harness.applyFinalCalls.length, 1); + assert.equal(harness.applyFinalCalls[0].generationType, "normal"); +} + +async function testPersistentRecallDataLayerLifecycleAndCompatibility() { + const chat = [ + { is_user: true, mes: "u0" }, + { is_user: false, mes: "a1" }, + { is_user: true, mes: "u2" }, + ]; + + const record = buildPersistedRecallRecord({ + injectionText: "fresh-memory", + selectedNodeIds: ["n1", "n2"], + recallInput: "u2", + recallSource: "chat-last-user", + hookName: "GENERATION_AFTER_COMMANDS", + tokenEstimate: 24, + manuallyEdited: false, + nowIso: "2026-01-01T00:00:00.000Z", + }); + assert.equal(writePersistedRecallToUserMessage(chat, 2, record), true); + + const loaded = readPersistedRecallFromUserMessage(chat, 2); + assert.ok(loaded); + assert.equal(loaded.injectionText, "fresh-memory"); + assert.equal(loaded.generationCount, 0); + assert.equal(loaded.manuallyEdited, false); + + chat[2].mes = "u2 edited"; + assert.equal(readPersistedRecallFromUserMessage(chat, 2)?.injectionText, "fresh-memory"); + + const bumped = bumpPersistedRecallGenerationCount(chat, 2); + assert.equal(bumped?.generationCount, 1); + + const edited = markPersistedRecallManualEdit( + chat, + 2, + true, + "2026-01-01T00:00:01.000Z", + ); + assert.equal(edited?.manuallyEdited, true); + assert.equal(edited?.updatedAt, "2026-01-01T00:00:01.000Z"); + + const overwrite = buildPersistedRecallRecord( + { + injectionText: "system-rerecall", + selectedNodeIds: ["n3"], + recallInput: "u2 edited", + recallSource: "message-floor-rerecall", + hookName: "MESSAGE_RECALL_BADGE_RERUN", + tokenEstimate: 30, + manuallyEdited: false, + nowIso: "2026-01-01T00:00:02.000Z", + }, + readPersistedRecallFromUserMessage(chat, 2), + ); + assert.equal(writePersistedRecallToUserMessage(chat, 2, overwrite), true); + const overwritten = readPersistedRecallFromUserMessage(chat, 2); + assert.equal(overwritten?.manuallyEdited, false); + assert.equal(overwritten?.injectionText, "system-rerecall"); + + assert.equal(removePersistedRecallFromUserMessage(chat, 2), true); + assert.equal(readPersistedRecallFromUserMessage(chat, 2), null); + assert.equal(readPersistedRecallFromUserMessage([{ is_user: true, mes: "legacy" }], 0), null); +} + +async function testPersistentRecallSourceResolutionAndTargetRouting() { + const chat = [ + { is_user: true, mes: "u0" }, + { is_user: false, mes: "a1" }, + { is_user: true, mes: "u2" }, + { is_user: false, mes: "a3" }, + ]; + + assert.equal(resolveGenerationTargetUserMessageIndex(chat, { generationType: "normal" }), null); + assert.equal(resolveGenerationTargetUserMessageIndex(chat, { generationType: "continue" }), 2); + + const withTailUser = [...chat, { is_user: true, mes: "u4" }]; + assert.equal(resolveGenerationTargetUserMessageIndex(withTailUser, { generationType: "normal" }), 4); + + const freshWins = resolveFinalRecallInjectionSource({ + freshRecallResult: { status: "completed", didRecall: true, injectionText: "fresh" }, + persistedRecord: { injectionText: "persisted" }, + }); + assert.equal(freshWins.source, "fresh"); + assert.equal(freshWins.injectionText, "fresh"); + + const fallback = resolveFinalRecallInjectionSource({ + freshRecallResult: { status: "skipped", didRecall: false, injectionText: "" }, + persistedRecord: { injectionText: "persisted" }, + }); + assert.equal(fallback.source, "persisted"); + assert.equal(fallback.injectionText, "persisted"); +} + +async function testMessageRecallUiBadgeEntryPoints() { + const { context, chat, badgeHost } = await createMessageRecallUiHarness(); + context.result.refreshPersistedRecallMessageUi(); + assert.equal(badgeHost.children.length, 1); + assert.equal(typeof badgeHost.children[0].listeners.click, "function"); + + context.promptResponses = ["1"]; + await context.result.onMessageRecallBadgeClick(0); + assert.equal(context.alertMessages.length, 1); + + context.promptResponses = ["2", "edited-by-user"]; + await context.result.onMessageRecallBadgeClick(0); + const edited = readPersistedRecallFromUserMessage(chat, 0); + assert.equal(edited?.injectionText, "edited-by-user"); + assert.equal(edited?.manuallyEdited, true); + + context.promptResponses = ["3"]; + await context.result.onMessageRecallBadgeClick(0); + assert.equal(readPersistedRecallFromUserMessage(chat, 0), null); +} + async function testRerollUsesBatchBoundaryRollbackAndPersistsState() { const harness = await createRerollHarness(); harness.chat = [ @@ -1768,6 +2060,10 @@ await testGenerationRecallTransactionDedupesDoubleHookBySameKey(); await testGenerationRecallBeforeCombineRunsStandalone(); await testGenerationRecallDifferentKeyCanRunAgain(); await testGenerationRecallSkippedStateDoesNotLoopToBeforeCombine(); +await testGenerationRecallAppliesFinalInjectionOncePerTransaction(); +await testPersistentRecallDataLayerLifecycleAndCompatibility(); +await testPersistentRecallSourceResolutionAndTargetRouting(); +await testMessageRecallUiBadgeEntryPoints(); await testRerollUsesBatchBoundaryRollbackAndPersistsState(); await testRerollRejectsMissingRecoveryPoint(); await testRerollFallsBackToDirectExtractForUnprocessedFloor();