diff --git a/compressor.js b/compressor.js index 9fd923f..351ae14 100644 --- a/compressor.js +++ b/compressor.js @@ -6,6 +6,18 @@ import { callLLMForJSON } from './llm.js'; import { embedText } from './embedding.js'; import { isDirectVectorConfig } from './vector-index.js'; +function createAbortError(message = '操作已终止') { + const error = new Error(message); + error.name = 'AbortError'; + return error; +} + +function throwIfAborted(signal) { + if (signal?.aborted) { + throw signal.reason instanceof Error ? signal.reason : createAbortError(); + } +} + /** * 对指定类型执行层级压缩 * @@ -16,7 +28,7 @@ import { isDirectVectorConfig } from './vector-index.js'; * @param {boolean} [params.force=false] - 忽略阈值强制压缩 * @returns {Promise<{created: number, archived: number}>} */ -export async function compressType({ graph, typeDef, embeddingConfig, force = false, customPrompt }) { +export async function compressType({ graph, typeDef, embeddingConfig, force = false, customPrompt, signal }) { const compression = typeDef.compression; if (!compression || compression.mode !== 'hierarchical') { return { created: 0, archived: 0 }; @@ -27,6 +39,7 @@ export async function compressType({ graph, typeDef, embeddingConfig, force = fa // 从最低层级开始逐层压缩 for (let level = 0; level < compression.maxDepth; level++) { + throwIfAborted(signal); const result = await compressLevel({ graph, typeDef, @@ -34,6 +47,7 @@ export async function compressType({ graph, typeDef, embeddingConfig, force = fa embeddingConfig, force, customPrompt, + signal, }); totalCreated += result.created; @@ -49,8 +63,9 @@ export async function compressType({ graph, typeDef, embeddingConfig, force = fa /** * 压缩特定层级的节点 */ -async function compressLevel({ graph, typeDef, level, embeddingConfig, force, customPrompt }) { +async function compressLevel({ graph, typeDef, level, embeddingConfig, force, customPrompt, signal }) { const compression = typeDef.compression; + throwIfAborted(signal); // 获取该层级的活跃叶子节点 const levelNodes = getActiveNodes(graph, typeDef.id) @@ -80,7 +95,7 @@ async function compressLevel({ graph, typeDef, level, embeddingConfig, force, cu if (batch.length < 2) break; // 至少 2 个才压缩 // 调用 LLM 总结 - const summaryResult = await summarizeBatch(batch, typeDef, customPrompt); + const summaryResult = await summarizeBatch(batch, typeDef, customPrompt, signal); if (!summaryResult) continue; // 创建压缩节点 @@ -97,7 +112,7 @@ async function compressLevel({ graph, typeDef, level, embeddingConfig, force, cu // 生成 embedding if (isDirectVectorConfig(embeddingConfig) && summaryResult.fields.summary) { - const vec = await embedText(summaryResult.fields.summary, embeddingConfig); + const vec = await embedText(summaryResult.fields.summary, embeddingConfig, { signal }); if (vec) compressedNode.embedding = Array.from(vec); } @@ -153,7 +168,7 @@ function migrateBatchEdges(graph, batch, compressedNode) { /** * 调用 LLM 总结一批节点 */ -async function summarizeBatch(nodes, typeDef, customPrompt) { +async function summarizeBatch(nodes, typeDef, customPrompt, signal) { const nodeDescriptions = nodes.map((n, i) => { const fieldsStr = Object.entries(n.fields) .filter(([_, v]) => v) @@ -179,7 +194,7 @@ async function summarizeBatch(nodes, typeDef, customPrompt) { const userPrompt = `请压缩以下 ${nodes.length} 个 "${typeDef.label}" 节点:\n\n${nodeDescriptions}`; - return await callLLMForJSON({ systemPrompt, userPrompt, maxRetries: 1 }); + return await callLLMForJSON({ systemPrompt, userPrompt, maxRetries: 1, signal }); } /** @@ -191,13 +206,14 @@ async function summarizeBatch(nodes, typeDef, customPrompt) { * @param {boolean} [force=false] * @returns {Promise<{created: number, archived: number}>} */ -export async function compressAll(graph, schema, embeddingConfig, force = false, customPrompt) { +export async function compressAll(graph, schema, embeddingConfig, force = false, customPrompt, signal) { let totalCreated = 0; let totalArchived = 0; for (const typeDef of schema) { + throwIfAborted(signal); if (typeDef.compression?.mode === 'hierarchical') { - const result = await compressType({ graph, typeDef, embeddingConfig, force, customPrompt }); + const result = await compressType({ graph, typeDef, embeddingConfig, force, customPrompt, signal }); totalCreated += result.created; totalArchived += result.archived; } diff --git a/embedding.js b/embedding.js index 644744e..1fcaf32 100644 --- a/embedding.js +++ b/embedding.js @@ -8,6 +8,10 @@ const EMBEDDING_REQUEST_TIMEOUT_MS = 45000; +function isAbortError(error) { + return error?.name === 'AbortError'; +} + function normalizeOpenAICompatibleBaseUrl(value) { return String(value || '') .trim() @@ -63,7 +67,7 @@ async function fetchWithTimeout(url, options = {}, timeoutMs = EMBEDDING_REQUEST * @param {string} config.model - 模型名(如 text-embedding-3-small) * @returns {Promise} 向量或 null */ -export async function embedText(text, config) { +export async function embedText(text, config, { signal } = {}) { const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl); if (!text || !apiUrl || !config?.model) { console.warn('[ST-BME] Embedding 配置不完整,跳过'); @@ -77,6 +81,7 @@ export async function embedText(text, config) { 'Content-Type': 'application/json', ...(config.apiKey ? { Authorization: `Bearer ${config.apiKey}` } : {}), }, + signal, body: JSON.stringify({ model: config.model, input: text, @@ -99,6 +104,9 @@ export async function embedText(text, config) { return new Float64Array(vector); } catch (e) { + if (isAbortError(e)) { + throw e; + } console.error('[ST-BME] Embedding API 调用失败:', e); return null; } @@ -111,7 +119,7 @@ export async function embedText(text, config) { * @param {object} config * @returns {Promise<(Float64Array|null)[]>} */ -export async function embedBatch(texts, config) { +export async function embedBatch(texts, config, { signal } = {}) { const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl); if (!texts.length || !apiUrl || !config?.model) { return texts.map(() => null); @@ -124,6 +132,7 @@ export async function embedBatch(texts, config) { 'Content-Type': 'application/json', ...(config.apiKey ? { Authorization: `Bearer ${config.apiKey}` } : {}), }, + signal, body: JSON.stringify({ model: config.model, input: texts, @@ -153,6 +162,9 @@ export async function embedBatch(texts, config) { return null; }); } catch (e) { + if (isAbortError(e)) { + throw e; + } console.error('[ST-BME] Embedding API 批量调用失败:', e); return texts.map(() => null); } diff --git a/evolution.js b/evolution.js index 2f81a73..fd4d479 100644 --- a/evolution.js +++ b/evolution.js @@ -9,6 +9,22 @@ import { validateVectorConfig, } from './vector-index.js'; +function createAbortError(message = '操作已终止') { + const error = new Error(message); + error.name = 'AbortError'; + return error; +} + +function isAbortError(error) { + return error?.name === 'AbortError'; +} + +function throwIfAborted(signal) { + if (signal?.aborted) { + throw signal.reason instanceof Error ? signal.reason : createAbortError(); + } +} + /** * 进化系统提示词 * 参考 A-MEM process_memory() 的进化决策 Prompt @@ -57,6 +73,7 @@ export async function evolveMemories({ embeddingConfig, options = {}, customPrompt, + signal, }) { const neighborCount = options.neighborCount ?? 5; const stats = { evolved: 0, connections: 0, updates: 0 }; @@ -71,6 +88,7 @@ export async function evolveMemories({ if (activeNodes.length < 2) return stats; // 至少需要 2 个节点才有进化意义 for (const newId of newNodeIds) { + throwIfAborted(signal); const newNode = getNode(graph, newId); if (!newNode) continue; @@ -86,6 +104,7 @@ export async function evolveMemories({ embeddingConfig, neighborCount, candidates, + signal, ); if (neighbors.length === 0) continue; @@ -118,6 +137,7 @@ export async function evolveMemories({ systemPrompt: customPrompt || EVOLUTION_SYSTEM_PROMPT, userPrompt, maxRetries: 1, + signal, }); if (!decision || !decision.should_evolve) continue; @@ -188,6 +208,9 @@ export async function evolveMemories({ } } catch (e) { + if (isAbortError(e)) { + throw e; + } console.error(`[ST-BME] 记忆进化失败 (${newId}):`, e); } } diff --git a/extractor.js b/extractor.js index c359529..af3f13e 100644 --- a/extractor.js +++ b/extractor.js @@ -27,6 +27,24 @@ import { validateVectorConfig, } from "./vector-index.js"; +function createAbortError(message = "操作已终止") { + const error = new Error(message); + error.name = "AbortError"; + return error; +} + +function isAbortError(error) { + return error?.name === "AbortError"; +} + +function throwIfAborted(signal) { + if (signal?.aborted) { + throw signal.reason instanceof Error + ? signal.reason + : createAbortError(); + } +} + /** * 对未处理的对话楼层执行记忆提取 * @@ -52,7 +70,9 @@ export async function extractMemories({ embeddingConfig, extractPrompt, v2Options = {}, + signal = undefined, }) { + throwIfAborted(signal); if (!messages || messages.length === 0) { return { success: true, @@ -117,7 +137,9 @@ export async function extractMemories({ systemPrompt, userPrompt, maxRetries: 2, + signal, }); + throwIfAborted(signal); if (!result || !Array.isArray(result.operations)) { console.warn("[ST-BME] 提取 LLM 未返回有效操作"); @@ -140,6 +162,7 @@ export async function extractMemories({ embeddingConfig, conflictThreshold, effectiveEndSeq, + signal, ); } @@ -182,8 +205,11 @@ export async function extractMemories({ // 为新建节点生成 embedding。失败不应回滚整批图谱写入。 try { - await generateNodeEmbeddings(graph, embeddingConfig); + await generateNodeEmbeddings(graph, embeddingConfig, signal); } catch (error) { + if (isAbortError(error)) { + throw error; + } console.error("[ST-BME] 节点 embedding 生成失败,保留图谱写入:", error); } @@ -432,8 +458,9 @@ function handleLinks(graph, sourceId, links, refMap, stats) { /** * 为缺少 embedding 的节点生成向量 */ -async function generateNodeEmbeddings(graph, embeddingConfig) { +async function generateNodeEmbeddings(graph, embeddingConfig, signal) { if (!isDirectVectorConfig(embeddingConfig)) return; + throwIfAborted(signal); const needsEmbedding = graph.nodes.filter( (n) => @@ -446,7 +473,7 @@ async function generateNodeEmbeddings(graph, embeddingConfig) { console.log(`[ST-BME] 为 ${texts.length} 个节点生成 embedding`); - const embeddings = await embedBatch(texts, embeddingConfig); + const embeddings = await embedBatch(texts, embeddingConfig, { signal }); for (let i = 0; i < needsEmbedding.length; i++) { if (embeddings[i]) { @@ -553,6 +580,7 @@ async function mem0ConflictCheck( embeddingConfig, threshold, fallbackSeq, + signal, ) { const activeNodes = getActiveNodes(graph).filter((node) => { const text = buildNodeVectorText(node); @@ -568,12 +596,14 @@ async function mem0ConflictCheck( if (!factText) continue; try { + throwIfAborted(signal); const similar = await findSimilarNodesByText( graph, factText, embeddingConfig, 3, activeNodes, + signal, ); if (similar.length > 0 && similar[0].score > threshold) { @@ -598,6 +628,7 @@ async function mem0ConflictCheck( `相似度: ${similar[0].score.toFixed(3)}`, ].join("\n"), maxRetries: 1, + signal, }); if (decision?.action === "update" && decision.targetId) { @@ -617,6 +648,9 @@ async function mem0ConflictCheck( } } } catch (e) { + if (isAbortError(e)) { + throw e; + } console.warn("[ST-BME] Mem0对照失败,保持原操作:", e.message); } } @@ -632,7 +666,7 @@ async function mem0ConflictCheck( * @param {number} params.currentSeq * @returns {Promise} */ -export async function generateSynopsis({ graph, schema, currentSeq, customPrompt }) { +export async function generateSynopsis({ graph, schema, currentSeq, customPrompt, signal }) { const eventNodes = getActiveNodes(graph, "event").sort( (a, b) => a.seq - b.seq, ); @@ -670,6 +704,7 @@ export async function generateSynopsis({ graph, schema, currentSeq, customPrompt threadSummary || "(无)", ].join("\n"), maxRetries: 1, + signal, }); if (!result?.summary) return; @@ -701,7 +736,7 @@ export async function generateSynopsis({ graph, schema, currentSeq, customPrompt } } -export async function generateReflection({ graph, currentSeq, customPrompt }) { +export async function generateReflection({ graph, currentSeq, customPrompt, signal }) { const recentEvents = getActiveNodes(graph, "event") .sort((a, b) => b.seq - a.seq) .slice(0, 6) @@ -763,6 +798,7 @@ export async function generateReflection({ graph, currentSeq, customPrompt }) { contradictionSummary || "(无)", ].join("\n"), maxRetries: 1, + signal, }); if (!result?.insight) return null; diff --git a/index.js b/index.js index d6fc5b4..4379e15 100644 --- a/index.js +++ b/index.js @@ -196,6 +196,12 @@ const stageNoticeHandles = { recall: null, history: null, }; +const stageAbortControllers = { + extraction: null, + vector: null, + recall: null, + history: null, +}; function createUiStatus(text = "待命", meta = "", level = "idle") { return { @@ -214,6 +220,87 @@ function normalizeStageNoticeLevel(level = "info") { return "info"; } +function createAbortError(message = "操作已终止") { + const error = new Error(message); + error.name = "AbortError"; + return error; +} + +function isAbortError(error) { + return error?.name === "AbortError"; +} + +function throwIfAborted(signal, message = "操作已终止") { + if (signal?.aborted) { + throw signal.reason instanceof Error + ? signal.reason + : createAbortError(message); + } +} + +function getStageAbortLabel(stage) { + switch (stage) { + case "extraction": + return "提取"; + case "vector": + return "向量"; + case "recall": + return "召回"; + case "history": + return "历史恢复"; + default: + return "当前流程"; + } +} + +function beginStageAbortController(stage) { + const controller = new AbortController(); + stageAbortControllers[stage] = controller; + return controller; +} + +function finishStageAbortController(stage, controller = null) { + if (!controller || stageAbortControllers[stage] === controller) { + stageAbortControllers[stage] = null; + } +} + +function findAbortableStageForNotice(stage) { + const preferred = [stage]; + if (stage === "vector") { + preferred.push("history", "extraction", "recall"); + } + + for (const candidate of preferred) { + const controller = stageAbortControllers[candidate]; + if (controller && !controller.signal.aborted) { + return candidate; + } + } + + return null; +} + +function abortStage(stage) { + const controller = stageAbortControllers[stage]; + if (!controller || controller.signal.aborted) return false; + controller.abort(createAbortError(`${getStageAbortLabel(stage)}已终止`)); + return true; +} + +function buildAbortStageAction(stage) { + const abortStageName = findAbortableStageForNotice(stage); + if (!abortStageName) return undefined; + + return { + label: `终止${getStageAbortLabel(abortStageName)}`, + kind: "danger", + onClick: () => { + abortStage(abortStageName); + }, + }; +} + function getStageNoticeTitle(stage) { switch (stage) { case "extraction": @@ -264,6 +351,12 @@ function dismissAllStageNotices() { } } +function abortAllRunningStages() { + for (const stage of Object.keys(stageAbortControllers)) { + abortStage(stage); + } +} + function updateStageNotice( stage, text, @@ -284,9 +377,12 @@ function updateStageNotice( persist, duration_ms: options.duration_ms ?? getStageNoticeDuration(noticeLevel), action: - options.action === undefined && - (noticeLevel === "warning" || noticeLevel === "error") - ? createNoticePanelAction() + options.action === undefined + ? (busy + ? buildAbortStageAction(stage) + : (noticeLevel === "warning" || noticeLevel === "error") + ? createNoticePanelAction() + : undefined) : options.action, }; @@ -707,12 +803,14 @@ async function recordGraphMutation({ processedRange = null, artifactTags = [], syncRange = null, + signal = undefined, } = {}) { ensureCurrentGraphRuntimeState(); const vectorSync = await syncVectorState({ force: true, purge: isBackendVectorConfig(getEmbeddingConfig()) && !syncRange, range: syncRange, + signal, }); const afterSnapshot = cloneGraphSnapshot(currentGraph); const effectiveRange = Array.isArray(processedRange) @@ -784,6 +882,7 @@ async function syncVectorState({ force = false, purge = false, range = null, + signal = undefined, } = {}) { ensureCurrentGraphRuntimeState(); const scopeLabel = @@ -818,6 +917,7 @@ async function syncVectorState({ force, purge, range, + signal, }); setLastVectorStatus( "向量完成", @@ -827,6 +927,17 @@ async function syncVectorState({ ); return result; } catch (error) { + if (isAbortError(error)) { + setLastVectorStatus("向量已终止", scopeLabel, "warning", { + syncRuntime: false, + }); + return { + insertedHashes: [], + stats: getVectorIndexStats(currentGraph), + error: error?.message || "向量任务已终止", + aborted: true, + }; + } const message = error?.message || String(error) || "向量同步失败"; markVectorStateDirty(message); console.error("[ST-BME] 向量同步失败:", error); @@ -842,7 +953,7 @@ async function syncVectorState({ } } -async function ensureVectorReadyIfNeeded(reason = "vector-ready-check") { +async function ensureVectorReadyIfNeeded(reason = "vector-ready-check", signal = undefined) { if (!currentGraph) return; ensureCurrentGraphRuntimeState(); @@ -855,6 +966,7 @@ async function ensureVectorReadyIfNeeded(reason = "vector-ready-check") { const result = await syncVectorState({ force: true, purge: isBackendVectorConfig(config), + signal, }); if (result?.error) { @@ -993,6 +1105,7 @@ function updateModuleSettings(patch = {}) { Object.prototype.hasOwnProperty.call(patch, "enabled") && patch.enabled === false ) { + abortAllRunningStages(); dismissAllStageNotices(); try { const context = getContext(); @@ -1309,9 +1422,10 @@ function getCurrentChatSeq(context = getContext()) { return currentGraph?.lastProcessedSeq ?? 0; } -async function handleExtractionSuccess(result, endIdx, settings) { +async function handleExtractionSuccess(result, endIdx, settings, signal = undefined) { const postProcessArtifacts = []; const warnings = []; + throwIfAborted(signal, "提取已终止"); extractionCount++; updateLastExtractedItems(result.newNodeIds || []); @@ -1323,9 +1437,11 @@ async function handleExtractionSuccess(result, endIdx, settings) { embeddingConfig: getEmbeddingConfig(), options: { neighborCount: settings.evoNeighborCount }, customPrompt: settings.evolutionPrompt || undefined, + signal, }); postProcessArtifacts.push("evolution"); } catch (e) { + if (isAbortError(e)) throw e; console.error("[ST-BME] 记忆进化失败:", e); } } @@ -1337,9 +1453,11 @@ async function handleExtractionSuccess(result, endIdx, settings) { schema: getSchema(), currentSeq: endIdx, customPrompt: settings.synopsisPrompt || undefined, + signal, }); postProcessArtifacts.push("synopsis"); } catch (e) { + if (isAbortError(e)) throw e; console.error("[ST-BME] 概要生成失败:", e); } } @@ -1353,9 +1471,11 @@ async function handleExtractionSuccess(result, endIdx, settings) { graph: currentGraph, currentSeq: endIdx, customPrompt: settings.reflectionPrompt || undefined, + signal, }); postProcessArtifacts.push("reflection"); } catch (e) { + if (isAbortError(e)) throw e; console.error("[ST-BME] 反思生成失败:", e); } } @@ -1370,23 +1490,29 @@ async function handleExtractionSuccess(result, endIdx, settings) { } try { + throwIfAborted(signal, "提取已终止"); const compressionResult = await compressAll( currentGraph, getSchema(), getEmbeddingConfig(), false, settings.compressPrompt || undefined, + signal, ); if (compressionResult.created > 0 || compressionResult.archived > 0) { postProcessArtifacts.push("compression"); } } catch (error) { + if (isAbortError(error)) throw error; const message = error?.message || String(error) || "压缩阶段失败"; warnings.push(`压缩阶段失败: ${message}`); console.error("[ST-BME] 记忆压缩失败:", error); } - const vectorSync = await syncVectorState(); + const vectorSync = await syncVectorState({ signal }); + if (vectorSync?.aborted) { + throw createAbortError(vectorSync.error || "提取已终止"); + } if (vectorSync?.error) { warnings.push(`向量同步失败: ${vectorSync.error}`); } @@ -1520,12 +1646,13 @@ function inspectHistoryMutation(trigger = "history-change") { return detection; } -async function purgeCurrentVectorCollection() { +async function purgeCurrentVectorCollection(signal = undefined) { if (!currentGraph?.vectorIndexState?.collectionId) return; const response = await fetchLocalWithTimeout("/api/vector/purge", { method: "POST", headers: getRequestHeaders(), + signal, body: JSON.stringify({ collectionId: currentGraph.vectorIndexState.collectionId, }), @@ -1537,14 +1664,17 @@ async function purgeCurrentVectorCollection() { } } -async function prepareVectorStateForReplay(fullReset = false) { +async function prepareVectorStateForReplay(fullReset = false, signal = undefined) { ensureCurrentGraphRuntimeState(); const config = getEmbeddingConfig(); if (isBackendVectorConfig(config)) { try { - await purgeCurrentVectorCollection(); + await purgeCurrentVectorCollection(signal); } catch (error) { + if (isAbortError(error)) { + throw error; + } console.warn("[ST-BME] 清理后端向量索引失败,继续本地恢复:", error); } currentGraph.vectorIndexState.hashToNodeId = {}; @@ -1568,8 +1698,10 @@ async function executeExtractionBatch({ endIdx, settings, smartTriggerDecision = null, + signal = undefined, } = {}) { ensureCurrentGraphRuntimeState(); + throwIfAborted(signal, "提取已终止"); const lastProcessed = getLastProcessedAssistantFloor(); const beforeSnapshot = cloneGraphSnapshot(currentGraph); const messages = buildExtractionMessages(chat, startIdx, endIdx, settings); @@ -1594,6 +1726,7 @@ async function executeExtractionBatch({ enablePreciseConflict: settings.enablePreciseConflict, conflictThreshold: settings.conflictThreshold, }, + signal, }); if (!result.success) { @@ -1605,7 +1738,7 @@ async function executeExtractionBatch({ }; } - const effects = await handleExtractionSuccess(result, endIdx, settings); + const effects = await handleExtractionSuccess(result, endIdx, settings, signal); updateProcessedHistorySnapshot(chat, endIdx); const afterSnapshot = cloneGraphSnapshot(currentGraph); @@ -1632,10 +1765,11 @@ async function executeExtractionBatch({ }; } -async function replayExtractionFromHistory(chat, settings) { +async function replayExtractionFromHistory(chat, settings, signal = undefined) { let replayedBatches = 0; while (true) { + throwIfAborted(signal, "历史恢复已终止"); const pendingAssistantTurns = getAssistantTurns(chat).filter( (index) => index > getLastProcessedAssistantFloor(), ); @@ -1651,6 +1785,7 @@ async function replayExtractionFromHistory(chat, settings) { startIdx, endIdx, settings, + signal, }); if (!batchResult.success) { @@ -1693,6 +1828,8 @@ async function recoverHistoryIfNeeded(trigger = "history-recovery") { : detection.earliestAffectedFloor; let replayedBatches = 0; let usedFullRebuild = false; + const historyController = beginStageAbortController("history"); + const historySignal = historyController.signal; updateStageNotice( "history", @@ -1708,6 +1845,7 @@ async function recoverHistoryIfNeeded(trigger = "history-recovery") { ); try { + throwIfAborted(historySignal, "历史恢复已终止"); const recoveryPoint = findJournalRecoveryPoint(currentGraph, initialDirtyFrom); if (recoveryPoint) { currentGraph = normalizeGraphRuntimeState( @@ -1719,8 +1857,8 @@ async function recoverHistoryIfNeeded(trigger = "history-recovery") { usedFullRebuild = true; } - await prepareVectorStateForReplay(usedFullRebuild); - replayedBatches = await replayExtractionFromHistory(chat, settings); + await prepareVectorStateForReplay(usedFullRebuild, historySignal); + replayedBatches = await replayExtractionFromHistory(chat, settings, historySignal); clearHistoryDirty( currentGraph, @@ -1750,12 +1888,26 @@ async function recoverHistoryIfNeeded(trigger = "history-recovery") { ); return true; } catch (error) { + if (isAbortError(error)) { + updateStageNotice( + "history", + "历史恢复已终止", + error?.message || "已手动终止当前恢复流程", + "warning", + { + busy: false, + persist: false, + }, + ); + saveGraphToChat(); + return false; + } console.error("[ST-BME] 历史恢复失败,尝试全量重建:", error); try { currentGraph = normalizeGraphRuntimeState(createEmptyGraph(), chatId); - await prepareVectorStateForReplay(true); - replayedBatches = await replayExtractionFromHistory(chat, settings); + await prepareVectorStateForReplay(true, historySignal); + replayedBatches = await replayExtractionFromHistory(chat, settings, historySignal); clearHistoryDirty( currentGraph, buildRecoveryResult("full-rebuild", { @@ -1799,6 +1951,7 @@ async function recoverHistoryIfNeeded(trigger = "history-recovery") { return false; } } finally { + finishStageAbortController("history", historyController); isRecoveringHistory = false; } } @@ -1850,6 +2003,8 @@ async function runExtraction() { ); isExtracting = true; + const extractionController = beginStageAbortController("extraction"); + const extractionSignal = extractionController.signal; try { const batchResult = await executeExtractionBatch({ @@ -1858,6 +2013,7 @@ async function runExtraction() { endIdx, settings, smartTriggerDecision, + signal: extractionSignal, }); if (!batchResult.success) { @@ -1877,9 +2033,16 @@ async function runExtraction() { { syncRuntime: true }, ); } catch (e) { + if (isAbortError(e)) { + setLastExtractionStatus("提取已终止", e?.message || "已手动终止当前提取", "warning", { + syncRuntime: true, + }); + return; + } console.error("[ST-BME] 提取失败:", e); notifyExtractionIssue(e?.message || String(e) || "自动提取失败"); } finally { + finishStageAbortController("extraction", extractionController); isExtracting = false; } } @@ -1894,15 +2057,16 @@ async function runRecall() { if (!settings.enabled || !settings.recallEnabled) return; if (!(await recoverHistoryIfNeeded("pre-recall"))) return; - await ensureVectorReadyIfNeeded("pre-recall"); - const context = getContext(); const chat = context.chat; if (!chat || chat.length === 0) return; isRecalling = true; + const recallController = beginStageAbortController("recall"); + const recallSignal = recallController.signal; try { + await ensureVectorReadyIfNeeded("pre-recall", recallSignal); const recentContextMessageLimit = clampInt( settings.recallLlmContextMessages, 4, @@ -1937,6 +2101,7 @@ async function runRecall() { recentMessages, embeddingConfig: getEmbeddingConfig(), schema: getSchema(), + signal: recallSignal, options: { topK: settings.recallTopK, maxRecallNodes: settings.recallMaxNodes, @@ -2020,6 +2185,12 @@ async function runRecall() { } } } catch (e) { + if (isAbortError(e)) { + setLastRecallStatus("召回已终止", e?.message || "已手动终止当前召回", "warning", { + syncRuntime: true, + }); + return; + } console.error("[ST-BME] 召回失败:", e); const message = e?.message || String(e); setLastRecallStatus("召回失败", message, "error", { @@ -2028,6 +2199,7 @@ async function runRecall() { }); toastr.error(`召回失败: ${message}`); } finally { + finishStageAbortController("recall", recallController); isRecalling = false; refreshPanelLiveState(); } @@ -2039,6 +2211,7 @@ function onChatChanged() { clearTimeout(pendingHistoryRecoveryTimer); pendingHistoryRecoveryTimer = null; pendingHistoryRecoveryTrigger = ""; + abortAllRunningStages(); dismissAllStageNotices(); loadGraphFromChat(); clearInjectionState(); @@ -2375,6 +2548,8 @@ async function onManualExtract() { const warnings = []; isExtracting = true; + const extractionController = beginStageAbortController("extraction"); + const extractionSignal = extractionController.signal; setLastExtractionStatus( "手动提取中", `待处理 assistant 楼层 ${pendingAssistantTurns.length} 条`, @@ -2396,6 +2571,7 @@ async function onManualExtract() { startIdx, endIdx, settings, + signal: extractionSignal, }); if (!batchResult.success) { @@ -2438,6 +2614,12 @@ async function onManualExtract() { ); } } catch (e) { + if (isAbortError(e)) { + setLastExtractionStatus("手动提取已终止", e?.message || "已手动终止当前提取", "warning", { + syncRuntime: true, + }); + return; + } console.error("[ST-BME] 手动提取失败:", e); setLastExtractionStatus("手动提取失败", e?.message || String(e), "error", { syncRuntime: true, @@ -2446,6 +2628,7 @@ async function onManualExtract() { }); toastr.error(`手动提取失败: ${e.message || e}`); } finally { + finishStageAbortController("extraction", extractionController); isExtracting = false; } } @@ -2510,21 +2693,30 @@ async function onRebuildVectorIndex(range = null) { return; } - const result = await syncVectorState({ - force: true, - purge: isBackendVectorConfig(config) && !range, - range, - }); + const vectorController = beginStageAbortController("vector"); + try { + const result = await syncVectorState({ + force: true, + purge: isBackendVectorConfig(config) && !range, + range, + signal: vectorController.signal, + }); - saveGraphToChat(); - if (result?.error) { - throw new Error(result.error); + saveGraphToChat(); + if (result?.aborted) { + return; + } + if (result?.error) { + throw new Error(result.error); + } + toastr.success( + range + ? `范围向量重建完成:indexed=${result.stats.indexed}, pending=${result.stats.pending}` + : `当前聊天向量重建完成:indexed=${result.stats.indexed}, pending=${result.stats.pending}`, + ); + } finally { + finishStageAbortController("vector", vectorController); } - toastr.success( - range - ? `范围向量重建完成:indexed=${result.stats.indexed}, pending=${result.stats.pending}` - : `当前聊天向量重建完成:indexed=${result.stats.indexed}, pending=${result.stats.pending}`, - ); } async function onReembedDirect() { diff --git a/llm.js b/llm.js index a90f1f6..1f4b5c6 100644 --- a/llm.js +++ b/llm.js @@ -254,6 +254,10 @@ function createCombinedAbortSignal(...signals) { // 自动检测:如果 API 不支持 response_format,记住并跳过 let _jsonModeSupported = true; +function isAbortError(error) { + return error?.name === 'AbortError'; +} + async function callDedicatedOpenAICompatible( messages, { signal, jsonMode = false, maxCompletionTokens = null } = {}, @@ -364,7 +368,7 @@ async function _parseResponse(response) { * @param {string} [params.model] - 指定模型(留空使用当前配置) * @returns {Promise} 解析后的 JSON 对象,或 null */ -export async function callLLMForJSON({ systemPrompt, userPrompt, maxRetries = 2 }) { +export async function callLLMForJSON({ systemPrompt, userPrompt, maxRetries = 2, signal } = {}) { let lastFailureReason = ''; for (let attempt = 0; attempt <= maxRetries; attempt++) { @@ -376,6 +380,7 @@ export async function callLLMForJSON({ systemPrompt, userPrompt, maxRetries = 2 lastFailureReason, ); const response = await callDedicatedOpenAICompatible(messages, { + signal, jsonMode: true, maxCompletionTokens: attempt === 0 ? DEFAULT_JSON_COMPLETION_TOKENS @@ -404,6 +409,9 @@ export async function callLLMForJSON({ systemPrompt, userPrompt, maxRetries = 2 responseText.slice(0, 200), ); } catch (e) { + if (isAbortError(e)) { + throw e; + } console.error(`[ST-BME] LLM 调用失败 (尝试 ${attempt + 1}):`, e); lastFailureReason = e?.message || String(e) || 'LLM 调用失败'; } diff --git a/retriever.js b/retriever.js index 3ada8e7..aab48e3 100644 --- a/retriever.js +++ b/retriever.js @@ -13,6 +13,24 @@ import { import { callLLMForJSON } from "./llm.js"; import { findSimilarNodesByText, validateVectorConfig } from "./vector-index.js"; +function createAbortError(message = "操作已终止") { + const error = new Error(message); + error.name = "AbortError"; + return error; +} + +function isAbortError(error) { + return error?.name === "AbortError"; +} + +function throwIfAborted(signal) { + if (signal?.aborted) { + throw signal.reason instanceof Error + ? signal.reason + : createAbortError(); + } +} + /** * 三层混合检索管线 * @@ -31,8 +49,10 @@ export async function retrieve({ recentMessages = [], embeddingConfig, schema, + signal = undefined, options = {}, }) { + throwIfAborted(signal); const topK = options.topK ?? 20; const maxRecallNodes = options.maxRecallNodes ?? 8; const enableLLMRecall = options.enableLLMRecall ?? true; @@ -111,6 +131,7 @@ export async function retrieve({ activeNodes, embeddingConfig, normalizedTopK, + signal, ); } @@ -234,6 +255,7 @@ export async function retrieve({ schema, normalizedMaxRecallNodes, options.recallPrompt, + signal, ); selectedNodeIds = llmResult.selectedNodeIds; llmMeta = { @@ -317,6 +339,7 @@ async function vectorPreFilter( activeNodes, embeddingConfig, topK, + signal, ) { try { return await findSimilarNodesByText( @@ -325,8 +348,12 @@ async function vectorPreFilter( embeddingConfig, topK, activeNodes, + signal, ); } catch (e) { + if (isAbortError(e)) { + throw e; + } console.error("[ST-BME] 向量预筛失败:", e); return []; } @@ -370,7 +397,9 @@ async function llmRecall( schema, maxNodes, customPrompt, + signal, ) { + throwIfAborted(signal); const contextStr = recentMessages.join("\n---\n"); const candidateDescriptions = candidates .map((c) => { @@ -410,6 +439,7 @@ async function llmRecall( systemPrompt, userPrompt, maxRetries: 1, + signal, }); if (result?.selected_ids && Array.isArray(result.selected_ids)) { diff --git a/vector-index.js b/vector-index.js index a6aa1e4..1f4830b 100644 --- a/vector-index.js +++ b/vector-index.js @@ -41,6 +41,18 @@ const BACKEND_STATUS_MODEL_SOURCES = { mistral: "mistralai", }; +function isAbortError(error) { + return error?.name === "AbortError"; +} + +function throwIfAborted(signal) { + if (signal?.aborted) { + throw signal.reason instanceof Error + ? signal.reason + : Object.assign(new Error("操作已终止"), { name: "AbortError" }); + } +} + export const BACKEND_DEFAULT_MODELS = { openai: "text-embedding-3-small", openrouter: "openai/text-embedding-3-small", @@ -338,10 +350,12 @@ function computeVectorStats(graph, desiredEntries) { }; } -async function purgeVectorCollection(collectionId) { +async function purgeVectorCollection(collectionId, signal) { + throwIfAborted(signal); const response = await fetchWithTimeout("/api/vector/purge", { method: "POST", headers: getRequestHeaders(), + signal, body: JSON.stringify({ collectionId }), }); @@ -351,12 +365,14 @@ async function purgeVectorCollection(collectionId) { } } -async function deleteVectorHashes(collectionId, config, hashes) { +async function deleteVectorHashes(collectionId, config, hashes, signal) { if (!Array.isArray(hashes) || hashes.length === 0) return; + throwIfAborted(signal); const response = await fetchWithTimeout("/api/vector/delete", { method: "POST", headers: getRequestHeaders(), + signal, body: JSON.stringify({ collectionId, hashes, @@ -370,12 +386,14 @@ async function deleteVectorHashes(collectionId, config, hashes) { } } -async function insertVectorEntries(collectionId, config, entries) { +async function insertVectorEntries(collectionId, config, entries, signal) { if (!Array.isArray(entries) || entries.length === 0) return; + throwIfAborted(signal); const response = await fetchWithTimeout("/api/vector/insert", { method: "POST", headers: getRequestHeaders(), + signal, body: JSON.stringify({ collectionId, items: entries.map((entry) => ({ @@ -410,11 +428,13 @@ export async function syncGraphVectorIndex( purge = false, force = false, range = null, + signal = undefined, } = {}, ) { if (!graph || !config) { return { insertedHashes: [], stats: { total: 0, indexed: 0, stale: 0, pending: 0 } }; } + throwIfAborted(signal); const validation = validateVectorConfig(config); if (!validation.valid) { @@ -443,9 +463,9 @@ export async function syncGraphVectorIndex( const fullReset = purge || state.dirty || scopeChanged || (force && !hasConcreteRange); if (fullReset) { - await purgeVectorCollection(collectionId); + await purgeVectorCollection(collectionId, signal); resetVectorMappings(graph, config, chatId); - await insertVectorEntries(collectionId, config, desiredEntries); + await insertVectorEntries(collectionId, config, desiredEntries, signal); for (const entry of desiredEntries) { state.hashToNodeId[entry.hash] = entry.nodeId; state.nodeToHash[entry.nodeId] = entry.hash; @@ -485,8 +505,8 @@ export async function syncGraphVectorIndex( entriesToInsert.push(entry); } - await deleteVectorHashes(collectionId, config, hashesToDelete); - await insertVectorEntries(collectionId, config, entriesToInsert); + await deleteVectorHashes(collectionId, config, hashesToDelete, signal); + await insertVectorEntries(collectionId, config, entriesToInsert, signal); for (const entry of entriesToInsert) { state.hashToNodeId[entry.hash] = entry.nodeId; @@ -536,9 +556,11 @@ export async function syncGraphVectorIndex( } if (entriesToEmbed.length > 0) { + throwIfAborted(signal); const embeddings = await embedBatch( entriesToEmbed.map((entry) => entry.text), config, + { signal }, ); for (let index = 0; index < entriesToEmbed.length; index++) { @@ -578,8 +600,10 @@ export async function findSimilarNodesByText( config, topK = 10, candidates = null, + signal = undefined, ) { if (!text || !graph || !config) return []; + throwIfAborted(signal); const candidateNodes = Array.isArray(candidates) ? candidates @@ -588,7 +612,7 @@ export async function findSimilarNodesByText( if (candidateNodes.length === 0) return []; if (isDirectVectorConfig(config)) { - const queryVec = await embedText(text, config); + const queryVec = await embedText(text, config, { signal }); if (!queryVec) return []; return searchSimilar( @@ -609,6 +633,7 @@ export async function findSimilarNodesByText( const response = await fetchWithTimeout("/api/vector/query", { method: "POST", headers: getRequestHeaders(), + signal, body: JSON.stringify({ collectionId: graph.vectorIndexState.collectionId, searchText: text,