From be38a2c0d20186c6637b568b469e0748cae2eb73 Mon Sep 17 00:00:00 2001 From: opencode Date: Fri, 15 May 2026 09:55:00 +0000 Subject: [PATCH] fix(vector): harden runtime embedding sync --- host/event-binding.js | 1 + index.js | 130 +++++++++--- runtime/vector-sync-coalescer.js | 133 ++++++++++++ tests/embedding-batch.mjs | 80 ++++++++ tests/vector-connection-probe.mjs | 58 ++++++ tests/vector-sync-coalescer.mjs | 44 ++++ vector/authority-vector-primary-adapter.js | 26 ++- vector/embedding.js | 223 ++++++++++++++------- vector/vector-index.js | 100 +++++---- 9 files changed, 652 insertions(+), 143 deletions(-) create mode 100644 runtime/vector-sync-coalescer.js create mode 100644 tests/embedding-batch.mjs create mode 100644 tests/vector-connection-probe.mjs create mode 100644 tests/vector-sync-coalescer.mjs diff --git a/host/event-binding.js b/host/event-binding.js index 64c0dad..71dc555 100644 --- a/host/event-binding.js +++ b/host/event-binding.js @@ -239,6 +239,7 @@ export function onChatChangedController(runtime) { runtime.setPendingHistoryRecoveryTimer(null); runtime.setPendingHistoryRecoveryTrigger(""); runtime.clearPendingAutoExtraction?.(); + runtime.clearPendingBackgroundVectorSync?.(); runtime.clearPendingGraphLoadRetry(); runtime.setSkipBeforeCombineRecallUntil(0); runtime.setLastPreGenerationRecallKey(""); diff --git a/index.js b/index.js index 13a1eab..9139997 100644 --- a/index.js +++ b/index.js @@ -257,6 +257,7 @@ import { writePersistedRecallToUserMessage, } from "./retrieval/recall-persistence.js"; import { resolveConfiguredTimeoutMs } from "./runtime/request-timeout.js"; +import { createVectorSyncCoalescer as createImportedVectorSyncCoalescer } from "./runtime/vector-sync-coalescer.js"; import { defaultSettings, getPersistedSettingsSnapshot, @@ -1308,6 +1309,45 @@ const backgroundMaintenanceQueue = typeof createBackgroundMaintenanceQueue === "function" ? createBackgroundMaintenanceQueue() : null; +const backgroundVectorSyncCoalescer = + typeof createImportedVectorSyncCoalescer === "function" + ? createImportedVectorSyncCoalescer() + : { + clear() {}, + getActive() { + return null; + }, + getPending() { + return null; + }, + enqueue(task = {}) { + return { + scheduled: true, + coalesced: false, + task: { + ...(task || {}), + stale: false, + }, + }; + }, + start(task = null) { + return Boolean(task && !task.stale); + }, + complete() { + return true; + }, + drop(task = null) { + if (task) task.stale = true; + return Boolean(task); + }, + isStale(task = null, chatId = "") { + return Boolean( + !task || + task.stale || + (chatId && task.chatId && String(chatId) !== String(task.chatId)), + ); + }, + }; const lastStatusToastAt = {}; let pendingRecallSendIntent = createRecallInputRecord(); let lastRecallSentUserMessage = createRecallInputRecord(); @@ -16777,50 +16817,81 @@ async function syncVectorState({ function scheduleBackgroundVectorSync(task = null, settings = {}) { const normalizedTask = task && typeof task === "object" && !Array.isArray(task) ? task : {}; - const range = - normalizedTask.range && - Number.isFinite(Number(normalizedTask.range.start)) && - Number.isFinite(Number(normalizedTask.range.end)) - ? { - start: Math.floor(Number(normalizedTask.range.start)), - end: Math.floor(Number(normalizedTask.range.end)), - } - : null; - const reason = - String(normalizedTask.reason || "background-vector-sync").trim() || - "background-vector-sync"; + const config = getEmbeddingConfig(); + const chatId = normalizeChatIdCandidate( + normalizedTask.chatId || getCurrentChatId() || graphPersistenceState.chatId, + ); const mode = String( normalizedTask.mode || resolveMaintenancePostProcessConcurrency(settings).mode || "balanced", ).trim() || "balanced"; - return enqueueBackgroundMaintenanceTask( + const coalesced = backgroundVectorSyncCoalescer.enqueue({ + ...normalizedTask, + chatId, + modelScope: getVectorModelScope(config), + mode, + reason: + String(normalizedTask.reason || "background-vector-sync").trim() || + "background-vector-sync", + }); + const scheduledTask = coalesced.task; + + if (!coalesced.scheduled) { + return { + queued: true, + coalesced: true, + id: scheduledTask.id, + snapshot: updateBackgroundMaintenanceQueueState( + typeof backgroundMaintenanceQueue?.getSnapshot === "function" + ? backgroundMaintenanceQueue.getSnapshot() + : null, + ), + }; + } + + const queuedResult = enqueueBackgroundMaintenanceTask( "vector-sync", async () => { - setLastVectorStatus( - "后台向量同步中", - `${mode} 模式 · 正在同步提取后的向量索引`, - "running", - { syncRuntime: false }, - ); - const result = await syncVectorState({ range }); - if (result?.aborted) { - throw createAbortError(result.error || "后台向量同步已终止"); + backgroundVectorSyncCoalescer.start(scheduledTask); + try { + const activeChatId = normalizeChatIdCandidate(getCurrentChatId()); + if (backgroundVectorSyncCoalescer.isStale(scheduledTask, activeChatId)) { + return { skipped: true, reason: "stale-background-vector-sync" }; + } + setLastVectorStatus( + "后台向量同步中", + `${scheduledTask.mode} 模式 · 正在同步提取后的向量索引`, + "running", + { syncRuntime: false }, + ); + const result = await syncVectorState({ range: scheduledTask.range }); + if (result?.aborted) { + throw createAbortError(result.error || "后台向量同步已终止"); + } + if (result?.error) { + throw new Error(result.error); + } + saveGraphToChat({ reason: scheduledTask.reason }); + return result; + } finally { + backgroundVectorSyncCoalescer.complete(scheduledTask); } - if (result?.error) { - throw new Error(result.error); - } - saveGraphToChat({ reason }); - return result; }, settings, { - id: String(normalizedTask.id || ""), + id: scheduledTask.id, }, ); + if (queuedResult?.queued !== true) { + backgroundVectorSyncCoalescer.drop?.( + scheduledTask, + queuedResult?.reason || "background-vector-sync-queue-rejected", + ); + } + return queuedResult; } - function hasPlanCommitChanges(planCommit = null) { if (!planCommit || typeof planCommit !== "object") return false; return [ @@ -22349,6 +22420,7 @@ function onChatChanged() { clearGenerationRecallTransactionsForChat, clearInjectionState, clearPendingAutoExtraction, + clearPendingBackgroundVectorSync: () => backgroundVectorSyncCoalescer.clear("chat-changed"), clearPendingGraphLoadRetry, clearPendingHistoryMutationChecks, clearCurrentGenerationTrivialSkip, diff --git a/runtime/vector-sync-coalescer.js b/runtime/vector-sync-coalescer.js new file mode 100644 index 0000000..9dd5203 --- /dev/null +++ b/runtime/vector-sync-coalescer.js @@ -0,0 +1,133 @@ +export function normalizeVectorSyncRange(range = null) { + if ( + range && + Number.isFinite(Number(range.start)) && + Number.isFinite(Number(range.end)) + ) { + const start = Math.floor(Number(range.start)); + const end = Math.floor(Number(range.end)); + return { + start: Math.min(start, end), + end: Math.max(start, end), + }; + } + return null; +} + +export function mergeVectorSyncRange(current = null, next = null) { + const currentRange = normalizeVectorSyncRange(current); + const nextRange = normalizeVectorSyncRange(next); + if (!currentRange || !nextRange) return null; + return { + start: Math.min(currentRange.start, nextRange.start), + end: Math.max(currentRange.end, nextRange.end), + }; +} + +function createTaskRecord(task = {}) { + const id = String(task.id || `vector-sync:${Date.now()}`); + return { + id, + chatId: String(task.chatId || "").trim(), + modelScope: String(task.modelScope || "").trim(), + range: normalizeVectorSyncRange(task.range), + reason: + String(task.reason || "background-vector-sync").trim() || + "background-vector-sync", + mode: String(task.mode || "balanced").trim() || "balanced", + stale: false, + requestedAt: Date.now(), + updatedAt: Date.now(), + }; +} + +function canMergeTask(left = null, right = null) { + return Boolean( + left && + right && + !left.stale && + left.chatId === right.chatId && + left.modelScope === right.modelScope, + ); +} + +function mergeTaskInto(target, incoming) { + target.range = mergeVectorSyncRange(target.range, incoming.range); + target.reason = + target.reason === incoming.reason + ? target.reason + : `${target.reason}+${incoming.reason}`; + target.mode = incoming.mode || target.mode; + target.updatedAt = Date.now(); + return target; +} + +function markStale(task = null, reason = "stale") { + if (!task) return; + task.stale = true; + task.clearReason = String(reason || "stale"); +} + +export function createVectorSyncCoalescer() { + let active = null; + let pending = null; + + return { + clear(reason = "clear") { + markStale(active, reason); + markStale(pending, reason); + active = null; + pending = null; + }, + getActive() { + return active; + }, + getPending() { + return pending; + }, + enqueue(task = {}) { + const incoming = createTaskRecord(task); + if (canMergeTask(active, incoming)) { + if (canMergeTask(pending, incoming)) { + mergeTaskInto(pending, incoming); + return { scheduled: false, coalesced: true, task: pending }; + } + markStale(pending, "replaced"); + pending = incoming; + return { scheduled: true, coalesced: false, task: pending }; + } + if (canMergeTask(pending, incoming)) { + mergeTaskInto(pending, incoming); + return { scheduled: false, coalesced: true, task: pending }; + } + markStale(pending, "replaced"); + pending = incoming; + return { scheduled: true, coalesced: false, task: pending }; + }, + start(task = null) { + if (!task || task.stale) return false; + if (pending === task) pending = null; + active = task; + return true; + }, + complete(task = null) { + if (task && active !== task) return false; + active = null; + return true; + }, + drop(task = null, reason = "dropped") { + if (!task) return false; + const target = pending === task ? pending : active === task ? active : null; + if (!target) return false; + markStale(target, reason); + if (pending === task) pending = null; + if (active === task) active = null; + return true; + }, + isStale(task = null, chatId = "") { + if (!task || task.stale) return true; + const currentChatId = String(chatId || "").trim(); + return Boolean(currentChatId && task.chatId && currentChatId !== task.chatId); + }, + }; +} diff --git a/tests/embedding-batch.mjs b/tests/embedding-batch.mjs new file mode 100644 index 0000000..cf62782 --- /dev/null +++ b/tests/embedding-batch.mjs @@ -0,0 +1,80 @@ +import assert from "node:assert/strict"; +import { installResolveHooks, toDataModuleUrl } from "./helpers/register-hooks-compat.mjs"; + +installResolveHooks([ + { specifiers: ["../../../../../script.js"], url: toDataModuleUrl("export function getRequestHeaders() { return {}; }") }, + { specifiers: ["../../../../extensions.js"], url: toDataModuleUrl("export const extension_settings = { st_bme: {} };") }, +]); + +const { embedBatch } = await import("../vector/embedding.js"); + +function jsonResponse(payload) { + return new Response(JSON.stringify(payload), { status: 200, headers: { "Content-Type": "application/json" } }); +} + +async function withFetch(handler, fn) { + const previousFetch = globalThis.fetch; + globalThis.fetch = handler; + try { return await fn(); } finally { globalThis.fetch = previousFetch; } +} + +const plain = (vectors) => vectors.map((vector) => (vector ? Array.from(vector) : null)); + +{ + const calls = []; + await withFetch(async (_url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push(body); + return jsonResponse({ data: body.input.map((text, index) => ({ index, embedding: [String(text).length, index] })) }); + }, async () => { + const vectors = await embedBatch(["alpha", "beta", "gamma"], { mode: "direct", apiUrl: "https://example.com/v1", apiKey: "sk-test", model: "test-embedding", embeddingBatchSize: 2 }); + assert.deepEqual(plain(vectors), [[5, 0], [4, 1], [5, 0]]); + }); + assert.deepEqual(calls.map((call) => call.input), [["alpha", "beta"], ["gamma"]]); +} + +{ + const calls = []; + await withFetch(async (_url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push(body); + if (Array.isArray(body.input)) return new Response("batch schema rejected", { status: 400 }); + return jsonResponse({ data: [{ index: 0, embedding: [String(body.input).length, 9] }] }); + }, async () => { + const vectors = await embedBatch(["first", "second"], { mode: "direct", apiUrl: "https://example.com/v1/embeddings", model: "test-embedding", embeddingBatchSize: 2 }); + assert.deepEqual(plain(vectors), [[5, 9], [6, 9]]); + }); + assert.deepEqual(calls.map((call) => call.input), [["first", "second"], "first", "second"]); +} + +{ + const calls = []; + await withFetch(async (_url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push(body); + if (Array.isArray(body.texts)) return new Response("backend batch rejected", { status: 400 }); + return jsonResponse({ vector: [String(body.text).length, 3] }); + }, async () => { + const vectors = await embedBatch(["uno", "dos"], { mode: "backend", source: "openai", model: "text-embedding-3-small", embeddingBatchSize: 2 }); + assert.deepEqual(plain(vectors), [[3, 3], [3, 3]]); + }); + assert.deepEqual(calls.map((call) => [call.texts, call.text]), [[["uno", "dos"], undefined], [undefined, "uno"], [undefined, "dos"]]); +} + +{ + const calls = []; + await withFetch(async (_url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push(body); + if (Array.isArray(body.input)) { + return jsonResponse({ data: [{ index: 0, embedding: [1, 1] }] }); + } + return jsonResponse({ data: [{ index: 0, embedding: [String(body.input).length, 7] }] }); + }, async () => { + const vectors = await embedBatch(["kept", "fallback"], { mode: "direct", apiUrl: "https://example.com/v1", model: "test-embedding", embeddingBatchSize: 2 }); + assert.deepEqual(plain(vectors), [[1, 1], [8, 7]]); + }); + assert.deepEqual(calls.map((call) => call.input), [["kept", "fallback"], "fallback"]); +} + +console.log("embedding-batch tests passed"); diff --git a/tests/vector-connection-probe.mjs b/tests/vector-connection-probe.mjs new file mode 100644 index 0000000..c3ea3e6 --- /dev/null +++ b/tests/vector-connection-probe.mjs @@ -0,0 +1,58 @@ +import assert from "node:assert/strict"; +import { installResolveHooks, toDataModuleUrl } from "./helpers/register-hooks-compat.mjs"; + +installResolveHooks([ + { specifiers: ["../../../../../script.js"], url: toDataModuleUrl("export function getRequestHeaders() { return {}; }") }, + { specifiers: ["../../../../extensions.js"], url: toDataModuleUrl("export const extension_settings = { st_bme: {} };") }, +]); + +const { testVectorConnection } = await import("../vector/vector-index.js"); + +function jsonResponse(payload) { + return new Response(JSON.stringify(payload), { status: 200, headers: { "Content-Type": "application/json" } }); +} + +async function withFetch(handler, fn) { + const previousFetch = globalThis.fetch; + globalThis.fetch = handler; + try { return await fn(); } finally { globalThis.fetch = previousFetch; } +} + +{ + const calls = []; + const result = await withFetch(async (_url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push(body); + assert.equal(Array.isArray(body.input), true); + return jsonResponse({ data: body.input.map((text, index) => ({ index, embedding: [1, index, String(text).length] })) }); + }, async () => await testVectorConnection({ mode: "direct", apiUrl: "https://example.com/v1", apiKey: "sk-test", model: "test-embedding" })); + assert.equal(result.success, true); + assert.equal(result.dimensions, 3); + assert.equal(result.batchCapable, true); + assert.equal(result.mode, "direct"); + assert.deepEqual(calls[0].input, ["test connection", "runtime batch probe"]); +} + +{ + const calls = []; + const result = await withFetch(async (url, options = {}) => { + const body = JSON.parse(String(options.body || "{}")); + calls.push({ url: String(url), body }); + if (String(url) === "/api/vector/embed") { + assert.equal(Array.isArray(body.texts), true); + return jsonResponse({ vectors: body.texts.map((text, index) => [2, index, String(text).length]) }); + } + assert.equal(String(url), "/api/vector/query"); + return jsonResponse({ hashes: [] }); + }, async () => await testVectorConnection({ mode: "backend", source: "openai", model: "text-embedding-3-small" })); + assert.equal(result.success, true); + assert.equal(result.dimensions, 3); + assert.equal(result.batchCapable, true); + assert.equal(result.vectorStoreCapable, true); + assert.equal(result.mode, "backend"); + assert.deepEqual(calls[0].body.texts, ["test connection", "runtime batch probe"]); + assert.equal(calls[1].url, "/api/vector/query"); + assert.equal(calls[1].body.searchText, "test connection"); +} + +console.log("vector-connection-probe tests passed"); diff --git a/tests/vector-sync-coalescer.mjs b/tests/vector-sync-coalescer.mjs new file mode 100644 index 0000000..93362da --- /dev/null +++ b/tests/vector-sync-coalescer.mjs @@ -0,0 +1,44 @@ +import assert from "node:assert/strict"; +import { + createVectorSyncCoalescer, + mergeVectorSyncRange, + normalizeVectorSyncRange, +} from "../runtime/vector-sync-coalescer.js"; + +assert.deepEqual(normalizeVectorSyncRange({ start: 9, end: 3 }), { start: 3, end: 9 }); +assert.equal(normalizeVectorSyncRange({ start: "x", end: 3 }), null); +assert.deepEqual(mergeVectorSyncRange({ start: 2, end: 4 }, { start: 9, end: 6 }), { start: 2, end: 9 }); +assert.equal(mergeVectorSyncRange(null, { start: 1, end: 2 }), null); + +const coalescer = createVectorSyncCoalescer(); +const first = coalescer.enqueue({ id: "first", chatId: "chat-a", modelScope: "direct:model", range: { start: 4, end: 8 }, mode: "balanced", reason: "after-extraction" }); +assert.equal(first.scheduled, true); +const second = coalescer.enqueue({ id: "second", chatId: "chat-a", modelScope: "direct:model", range: { start: 1, end: 2 }, mode: "fast", reason: "after-edit" }); +assert.equal(second.scheduled, false); +assert.equal(second.coalesced, true); +assert.equal(second.task.id, "first"); +assert.deepEqual(second.task.range, { start: 1, end: 8 }); +assert.equal(second.task.mode, "fast"); + +assert.equal(coalescer.start(first.task), true); +const third = coalescer.enqueue({ id: "third", chatId: "chat-a", modelScope: "direct:model", range: { start: 10, end: 12 } }); +assert.equal(third.scheduled, true); +assert.equal(third.task.id, "third"); +const fourth = coalescer.enqueue({ id: "fourth", chatId: "chat-a", modelScope: "direct:model", range: { start: 20, end: 21 } }); +assert.equal(fourth.scheduled, false); +assert.deepEqual(third.task.range, { start: 10, end: 21 }); + +coalescer.clear("chat-changed"); +assert.equal(coalescer.isStale(first.task, "chat-a"), true); +assert.equal(coalescer.isStale(third.task, "chat-a"), true); + +const rejected = createVectorSyncCoalescer(); +const rejectedFirst = rejected.enqueue({ id: "rejected-first", chatId: "chat-a", modelScope: "direct:model" }); +assert.equal(rejected.drop(rejectedFirst.task, "queue-full"), true); +assert.equal(rejected.getPending(), null, "drop returns pending state to empty after queue rejection"); +assert.equal(rejected.isStale(rejectedFirst.task, "chat-a"), true); +const rejectedSecond = rejected.enqueue({ id: "rejected-second", chatId: "chat-a", modelScope: "direct:model" }); +assert.equal(rejectedSecond.scheduled, true, "new task should schedule after rejected pending is dropped"); +assert.equal(rejectedSecond.task.id, "rejected-second"); + +console.log("vector-sync-coalescer tests passed"); diff --git a/vector/authority-vector-primary-adapter.js b/vector/authority-vector-primary-adapter.js index 6beb906..ade3bde 100644 --- a/vector/authority-vector-primary-adapter.js +++ b/vector/authority-vector-primary-adapter.js @@ -4,7 +4,7 @@ import { AuthorityHttpClient, AuthorityHttpError, } from "../runtime/authority-http-client.js"; -import { embedText } from "./embedding.js"; +import { embedBatch } from "./embedding.js"; export const AUTHORITY_VECTOR_MODE = "authority"; export const AUTHORITY_VECTOR_SOURCE = "authority-trivium"; @@ -15,6 +15,7 @@ const MAX_AUTHORITY_VECTOR_CHUNK_SIZE = 2000; const DEFAULT_AUTHORITY_PURGE_PAGE_SIZE = 200; const DEFAULT_AUTHORITY_PURGE_MAX_PAGES = 1000; const DEFAULT_AUTHORITY_EMBEDDING_BACKEND_SOURCE = "openai"; +const AUTHORITY_CONNECTION_PROBE_TEXTS = ["test connection", "runtime batch probe"]; function clampInteger(value, fallback, min, max) { const parsed = Number(value); @@ -942,9 +943,19 @@ export async function searchAuthorityTriviumNodes(graph, text, config = {}, opti } export async function testAuthorityTriviumConnection(config = {}, options = {}) { - const probeVector = await embedText("test connection", config, { isQuery: true }); + const probeVectors = await embedBatch(AUTHORITY_CONNECTION_PROBE_TEXTS, config, { + isQuery: true, + }); + const probeVector = probeVectors.find((vector) => vector && vector.length > 0); if (!probeVector || probeVector.length === 0) { - return { success: false, dimensions: 0, error: "Embedding API 返回空结果" }; + return { + success: false, + dimensions: 0, + error: "Embedding API 批量返回空结果", + batchCapable: false, + mode: AUTHORITY_VECTOR_MODE, + source: AUTHORITY_VECTOR_SOURCE, + }; } const client = createAuthorityTriviumClient(config, options); await callClient(client, ["stat"], "stat", { @@ -952,5 +963,12 @@ export async function testAuthorityTriviumConnection(config = {}, options = {}) collectionId: options.collectionId, chatId: options.chatId, }); - return { success: true, dimensions: probeVector.length, error: "" }; + return { + success: true, + dimensions: probeVector.length, + error: "", + batchCapable: true, + mode: AUTHORITY_VECTOR_MODE, + source: AUTHORITY_VECTOR_SOURCE, + }; } diff --git a/vector/embedding.js b/vector/embedding.js index d90f2d7..f52ddfc 100644 --- a/vector/embedding.js +++ b/vector/embedding.js @@ -12,6 +12,8 @@ import { resolveConfiguredTimeoutMs } from "../runtime/request-timeout.js"; const MODULE_NAME = "st_bme"; const EMBEDDING_REQUEST_TIMEOUT_MS = 300000; +const DEFAULT_EMBEDDING_BATCH_SIZE = 10; +const MAX_EMBEDDING_BATCH_SIZE = 100; const BACKEND_SOURCES_REQUIRING_API_URL = new Set([ "ollama", "llamacpp", @@ -110,6 +112,94 @@ async function requestBackendEmbeddings(config = {}, payload = {}, { signal } = return await response.json().catch(() => ({})); } +function getEmbeddingBatchSize(config = {}) { + const parsed = Number(config?.embeddingBatchSize ?? config?.batchSize); + if (!Number.isFinite(parsed) || parsed <= 0) { + return DEFAULT_EMBEDDING_BATCH_SIZE; + } + return Math.min(MAX_EMBEDDING_BATCH_SIZE, Math.max(1, Math.trunc(parsed))); +} + +function chunkTexts(texts = [], size = DEFAULT_EMBEDDING_BATCH_SIZE) { + const chunks = []; + for (let start = 0; start < texts.length; start += size) { + chunks.push({ start, texts: texts.slice(start, start + size) }); + } + return chunks; +} + +async function requestDirectEmbeddingBatch(texts, config = {}, { signal } = {}) { + const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl); + const response = await fetchWithTimeout( + apiUrl + "/embeddings", + { + method: "POST", + headers: { + "Content-Type": "application/json", + ...(config.apiKey ? { Authorization: "Bearer " + config.apiKey } : {}), + }, + signal, + body: JSON.stringify({ + model: config.model, + input: texts, + }), + }, + getConfiguredTimeoutMs(config), + ); + + if (!response.ok) { + const errorText = await response.text().catch(() => response.statusText); + const error = new Error(errorText || response.statusText || "HTTP " + response.status); + error.status = response.status; + throw error; + } + + const data = await response.json().catch(() => ({})); + const embeddings = Array.isArray(data?.data) ? data.data : null; + if (!embeddings) { + throw new Error("Embedding API 返回格式异常"); + } + + const results = new Array(texts.length).fill(null); + embeddings.forEach((item, order) => { + const rawIndex = Number(item?.index); + const index = Number.isInteger(rawIndex) ? rawIndex : order; + if (index >= 0 && index < results.length) { + results[index] = normalizeVector(item?.embedding); + } + }); + return results; +} + +async function requestBackendEmbeddingBatch(texts, config = {}, { signal, isQuery = false } = {}) { + const payload = await requestBackendEmbeddings( + config, + { texts, isQuery }, + { signal }, + ); + const vectors = Array.isArray(payload?.vectors) ? payload.vectors : null; + if (!vectors) { + throw new Error("Backend Embedding API 返回格式异常"); + } + return texts.map((_, index) => normalizeVector(vectors[index])); +} + +async function fallbackEmbedChunkTexts(texts, config = {}, { signal, isQuery = false } = {}) { + const vectors = []; + for (const text of texts) { + try { + vectors.push(await embedText(text, config, { signal, isQuery })); + } catch (error) { + if (isAbortError(error)) { + throw error; + } + console.error("[ST-BME] Embedding 单条回退失败:", error); + vectors.push(null); + } + } + return vectors; +} + function createCombinedAbortSignal(...signals) { const validSignals = signals.filter(Boolean); if (validSignals.length <= 1) { @@ -264,91 +354,78 @@ export async function embedText(text, config, { signal, isQuery = false } = {}) * @returns {Promise<(Float64Array|null)[]>} */ export async function embedBatch(texts, config, { signal, isQuery = false } = {}) { + const normalizedTexts = Array.isArray(texts) + ? texts.map((item) => String(item ?? "")) + : []; const override = getEmbeddingTestOverride("embedBatch"); if (override) { - return await override(texts, config, { signal, isQuery }); + return await override(normalizedTexts, config, { signal, isQuery }); } - if (readEmbeddingMode(config) === "backend") { - if (!texts.length || !config?.model) { - return texts.map(() => null); - } - try { - const payload = await requestBackendEmbeddings( - config, - { texts, isQuery }, - { signal }, - ); - const vectors = Array.isArray(payload?.vectors) ? payload.vectors : []; - return texts.map((_, index) => normalizeVector(vectors[index])); - } catch (e) { - if (isAbortError(e)) { - throw e; - } - console.error("[ST-BME] Backend Embedding 批量调用失败:", e); - return texts.map(() => null); - } + if (!normalizedTexts.length) { + return []; } + const isBackend = readEmbeddingMode(config) === "backend"; const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl); - if (!texts.length || !apiUrl || !config?.model) { - return texts.map(() => null); + if (!config?.model || (!isBackend && !apiUrl)) { + return normalizedTexts.map(() => null); } - try { - const response = await fetchWithTimeout( - `${apiUrl}/embeddings`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(config.apiKey - ? { Authorization: `Bearer ${config.apiKey}` } - : {}), - }, - signal, - body: JSON.stringify({ - model: config.model, - input: texts, - }), - }, - getConfiguredTimeoutMs(config), - ); - - if (!response.ok) { - const errorText = await response.text(); - console.error( - `[ST-BME] Embedding API 批量错误 (${response.status}):`, - errorText, - ); - return texts.map(() => null); - } - - const data = await response.json(); - const embeddings = data?.data; - - if (!Array.isArray(embeddings)) { - return texts.map(() => null); - } - - // 按 index 排序(API 可能不保证顺序) - embeddings.sort((a, b) => a.index - b.index); - - return embeddings.map((item) => { - if (item?.embedding && Array.isArray(item.embedding)) { - return new Float64Array(item.embedding); + const results = new Array(normalizedTexts.length).fill(null); + const batchSize = getEmbeddingBatchSize(config); + for (const chunk of chunkTexts(normalizedTexts, batchSize)) { + let vectors = null; + try { + vectors = isBackend + ? await requestBackendEmbeddingBatch(chunk.texts, config, { signal, isQuery }) + : await requestDirectEmbeddingBatch(chunk.texts, config, { signal }); + } catch (error) { + if (isAbortError(error)) { + throw error; } - return null; - }); - } catch (e) { - if (isAbortError(e)) { - throw e; + console.error( + isBackend + ? "[ST-BME] Backend Embedding 批量调用失败:" + : "[ST-BME] Embedding API 批量调用失败:", + error, + ); } - console.error("[ST-BME] Embedding API 批量调用失败:", e); - return texts.map(() => null); - } -} + if (!vectors || vectors.length < chunk.texts.length) { + vectors = await fallbackEmbedChunkTexts(chunk.texts, config, { + signal, + isQuery, + }); + } else { + const missingIndexes = []; + for (let index = 0; index < chunk.texts.length; index++) { + if (!vectors[index]) { + missingIndexes.push(index); + } + } + if (missingIndexes.length > 0) { + const fallbackVectors = await fallbackEmbedChunkTexts( + missingIndexes.map((index) => chunk.texts[index]), + config, + { + signal, + isQuery, + }, + ); + missingIndexes.forEach((missingIndex, fallbackIndex) => { + vectors[missingIndex] = fallbackVectors[fallbackIndex] || null; + }); + } + } + + for (let index = 0; index < chunk.texts.length; index++) { + results[chunk.start + index] = vectors[index] || null; + } + } + + return results; +} /** * 计算两个向量的 cosine 相似度 * diff --git a/vector/vector-index.js b/vector/vector-index.js index 0697d70..dcd6d14 100644 --- a/vector/vector-index.js +++ b/vector/vector-index.js @@ -64,6 +64,8 @@ function getConfiguredTimeoutMs(config = {}) { })(); } +const VECTOR_CONNECTION_PROBE_TEXTS = ["test connection", "runtime batch probe"]; + const BACKEND_STATUS_MODEL_SOURCES = { openai: "openai", cohere: "cohere", @@ -1457,15 +1459,69 @@ export async function testVectorConnection(config, chatId = "connection-test") { return { success: false, dimensions: 0, error: validation.error }; } - if (isDirectVectorConfig(config)) { + if (isDirectVectorConfig(config) || isBackendVectorConfig(config)) { try { - const vec = await embedText("test connection", config); - if (vec) { - return { success: true, dimensions: vec.length, error: "" }; + const vectors = await embedBatch(VECTOR_CONNECTION_PROBE_TEXTS, config, { + isQuery: true, + }); + const firstVector = vectors.find((vector) => vector && vector.length > 0); + if (firstVector) { + if (isBackendVectorConfig(config)) { + const response = await fetchWithTimeout( + "/api/vector/query", + { + method: "POST", + headers: getRequestHeaders(), + body: JSON.stringify({ + collectionId: buildVectorCollectionId(chatId), + searchText: VECTOR_CONNECTION_PROBE_TEXTS[0], + topK: 1, + threshold: 0, + ...buildBackendSourceRequest(config), + }), + }, + getConfiguredTimeoutMs(config), + ); + const payload = await response.text().catch(() => ""); + if (!response.ok) { + return { + success: false, + dimensions: firstVector.length, + error: payload || response.statusText, + batchCapable: true, + vectorStoreCapable: false, + mode: config.mode, + source: config.source || "backend", + }; + } + } + return { + success: true, + dimensions: firstVector.length, + error: "", + batchCapable: true, + vectorStoreCapable: isBackendVectorConfig(config) ? true : undefined, + mode: config.mode, + source: config.source || "direct", + }; } - return { success: false, dimensions: 0, error: "API 返回空结果" }; + return { + success: false, + dimensions: 0, + error: "批量 Embedding API 返回空结果", + batchCapable: false, + mode: config.mode, + source: config.source || "direct", + }; } catch (error) { - return { success: false, dimensions: 0, error: String(error) }; + return { + success: false, + dimensions: 0, + error: String(error), + batchCapable: false, + mode: config.mode, + source: config.source || "direct", + }; } } @@ -1480,38 +1536,8 @@ export async function testVectorConnection(config, chatId = "connection-test") { } } - try { - const response = await fetchWithTimeout( - "/api/vector/query", - { - method: "POST", - headers: getRequestHeaders(), - body: JSON.stringify({ - collectionId: buildVectorCollectionId(chatId), - searchText: "test connection", - topK: 1, - threshold: 0, - ...buildBackendSourceRequest(config), - }), - }, - getConfiguredTimeoutMs(config), - ); - - const payload = await response.text().catch(() => ""); - if (!response.ok) { - return { - success: false, - dimensions: 0, - error: payload || response.statusText, - }; - } - - return { success: true, dimensions: 0, error: "" }; - } catch (error) { - return { success: false, dimensions: 0, error: String(error) }; - } + return { success: false, dimensions: 0, error: "未知向量配置" }; } - export function getVectorIndexStats(graph) { const state = graph?.vectorIndexState; if (!state) {