diff --git a/index.js b/index.js index 6a83d38..6f17a0b 100644 --- a/index.js +++ b/index.js @@ -23490,6 +23490,7 @@ async function onTestEmbedding() { return await onTestEmbeddingController({ getCurrentChatId, getEmbeddingConfig, + getSettings, testVectorConnection, toastr, validateVectorConfig, diff --git a/tests/embedding-batch.mjs b/tests/embedding-batch.mjs index cf62782..8081412 100644 --- a/tests/embedding-batch.mjs +++ b/tests/embedding-batch.mjs @@ -31,6 +31,7 @@ const plain = (vectors) => vectors.map((vector) => (vector ? Array.from(vector) assert.deepEqual(plain(vectors), [[5, 0], [4, 1], [5, 0]]); }); assert.deepEqual(calls.map((call) => call.input), [["alpha", "beta"], ["gamma"]]); + assert.deepEqual(calls.map((call) => call.encoding_format), ["float", "float"]); } { @@ -77,4 +78,24 @@ const plain = (vectors) => vectors.map((vector) => (vector ? Array.from(vector) assert.deepEqual(calls.map((call) => call.input), [["kept", "fallback"], "fallback"]); } +{ + const result = await withFetch(async () => jsonResponse({ data: [{ index: 0, embedding: [] }] }), async () => { + await assert.rejects( + () => embedBatch(["empty"], { mode: "direct", apiUrl: "https://example.com/v1", model: "test-embedding", throwOnEmptyBatch: true }), + /Embedding API 批量返回空结果/, + ); + return true; + }); + assert.equal(result, true); +} + +{ + await withFetch(async () => new Response(JSON.stringify({ code: 20012, message: "Model does not exist", data: null }), { status: 400 }), async () => { + await assert.rejects( + () => embedBatch(["bad model"], { mode: "direct", apiUrl: "https://example.com/v1", model: "missing", throwOnEmptyBatch: true }), + /Embedding API 错误 \(400\): Model does not exist/, + ); + }); +} + console.log("embedding-batch tests passed"); diff --git a/tests/ui-actions-embedding.mjs b/tests/ui-actions-embedding.mjs new file mode 100644 index 0000000..7aab577 --- /dev/null +++ b/tests/ui-actions-embedding.mjs @@ -0,0 +1,61 @@ +import assert from "node:assert/strict"; + +const { onTestEmbeddingController } = await import("../ui/ui-actions-controller.js"); + +{ + const calls = []; + const toasts = []; + await onTestEmbeddingController({ + getSettings: () => ({ embeddingTransportMode: "direct" }), + getEmbeddingConfig: (mode) => { + calls.push(["getEmbeddingConfig", mode]); + return { mode, apiUrl: "https://example.com/v1", model: "embedding" }; + }, + validateVectorConfig: () => ({ valid: true, error: "" }), + getCurrentChatId: () => "chat-a", + testVectorConnection: async (config, chatId) => { + calls.push(["testVectorConnection", config.mode, chatId]); + return { success: true, dimensions: 3 }; + }, + toastr: { + info: (message) => toasts.push(["info", message]), + success: (message) => toasts.push(["success", message]), + error: (message) => toasts.push(["error", message]), + warning: (message) => toasts.push(["warning", message]), + }, + }); + assert.deepEqual(calls, [ + ["getEmbeddingConfig", "direct"], + ["testVectorConnection", "direct", "chat-a"], + ]); + assert.equal(toasts.some(([kind, message]) => kind === "info" && message.includes("直连")), true); +} + +{ + const calls = []; + await onTestEmbeddingController({ + getSettings: () => ({ embeddingTransportMode: "backend" }), + getEmbeddingConfig: (mode) => { + calls.push(["getEmbeddingConfig", mode]); + return { mode, source: "openai", model: "embedding" }; + }, + validateVectorConfig: () => ({ valid: true, error: "" }), + getCurrentChatId: () => "chat-b", + testVectorConnection: async (config, chatId) => { + calls.push(["testVectorConnection", config.mode, chatId]); + return { success: true, dimensions: 3 }; + }, + toastr: { + info: () => {}, + success: () => {}, + error: () => {}, + warning: () => {}, + }, + }); + assert.deepEqual(calls, [ + ["getEmbeddingConfig", "backend"], + ["testVectorConnection", "backend", "chat-b"], + ]); +} + +console.log("ui-actions-embedding tests passed"); diff --git a/tests/vector-connection-probe.mjs b/tests/vector-connection-probe.mjs index c3ea3e6..44c0ca5 100644 --- a/tests/vector-connection-probe.mjs +++ b/tests/vector-connection-probe.mjs @@ -24,6 +24,7 @@ async function withFetch(handler, fn) { const body = JSON.parse(String(options.body || "{}")); calls.push(body); assert.equal(Array.isArray(body.input), true); + assert.equal(body.encoding_format, "float"); return jsonResponse({ data: body.input.map((text, index) => ({ index, embedding: [1, index, String(text).length] })) }); }, async () => await testVectorConnection({ mode: "direct", apiUrl: "https://example.com/v1", apiKey: "sk-test", model: "test-embedding" })); assert.equal(result.success, true); @@ -55,4 +56,24 @@ async function withFetch(handler, fn) { assert.equal(calls[1].body.searchText, "test connection"); } +{ + const result = await withFetch(async () => new Response( + JSON.stringify({ code: 20012, message: "Model does not exist. Please check it carefully.", data: null }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ), async () => await testVectorConnection({ mode: "direct", apiUrl: "https://example.com/v1", apiKey: "sk-test", model: "missing-model" })); + assert.equal(result.success, false); + assert.match(result.error, /Model does not exist/); + assert.equal(result.batchCapable, false); +} + +{ + const result = await withFetch(async () => new Response( + JSON.stringify({ error: { message: "Backend provider refused embedding model" } }), + { status: 502, headers: { "Content-Type": "application/json" } }, + ), async () => await testVectorConnection({ mode: "backend", source: "openai", model: "bad-backend-model" })); + assert.equal(result.success, false); + assert.match(result.error, /Backend provider refused embedding model/); + assert.equal(result.batchCapable, false); +} + console.log("vector-connection-probe tests passed"); diff --git a/ui/ui-actions-controller.js b/ui/ui-actions-controller.js index 194a64b..f4cd672 100644 --- a/ui/ui-actions-controller.js +++ b/ui/ui-actions-controller.js @@ -178,14 +178,20 @@ export async function onViewGraphController(runtime) { } export async function onTestEmbeddingController(runtime) { - const config = runtime.getEmbeddingConfig(); + const settings = runtime.getSettings?.() || {}; + const selectedMode = settings.embeddingTransportMode === "backend" ? "backend" : "direct"; + const config = runtime.getEmbeddingConfig(selectedMode); const validation = runtime.validateVectorConfig(config); if (!validation.valid) { runtime.toastr.warning(validation.error); return; } - runtime.toastr.info("正在测试 Embedding API 连通性..."); + runtime.toastr.info( + selectedMode === "backend" + ? "正在测试后端 Embedding API 连通性..." + : "正在测试直连 Embedding API 连通性...", + ); const result = await runtime.testVectorConnection(config, runtime.getCurrentChatId()); if (result.success) { diff --git a/vector/authority-vector-primary-adapter.js b/vector/authority-vector-primary-adapter.js index a997e5e..3ff18b9 100644 --- a/vector/authority-vector-primary-adapter.js +++ b/vector/authority-vector-primary-adapter.js @@ -1049,7 +1049,10 @@ export async function searchAuthorityTriviumNodes(graph, text, config = {}, opti } export async function testAuthorityTriviumConnection(config = {}, options = {}) { - const probeVectors = await embedBatch(AUTHORITY_CONNECTION_PROBE_TEXTS, config, { + const probeVectors = await embedBatch(AUTHORITY_CONNECTION_PROBE_TEXTS, { + ...config, + throwOnEmptyBatch: true, + }, { isQuery: true, }); const probeVector = probeVectors.find((vector) => vector && vector.length > 0); diff --git a/vector/embedding.js b/vector/embedding.js index f52ddfc..0e1f37e 100644 --- a/vector/embedding.js +++ b/vector/embedding.js @@ -55,6 +55,46 @@ function normalizeVector(value) { return vector.length ? new Float64Array(vector) : null; } +function summarizePayload(value, maxLength = 360) { + let text = ""; + try { + text = typeof value === "string" ? value : JSON.stringify(value); + } catch { + text = String(value ?? ""); + } + text = String(text || "").replace(/\s+/g, " ").trim(); + return text.length > maxLength ? `${text.slice(0, maxLength)}…` : text; +} + +function readPayloadMessage(value, fallback = "") { + if (value && typeof value === "object") { + const message = value?.error?.message || value?.message || value?.error || value?.detail; + if (message) return String(message); + } + return summarizePayload(value) || fallback; +} + +function parseJsonText(value = "") { + try { + return JSON.parse(String(value || "")); + } catch { + return value; + } +} + +function buildDirectEmbeddingBody(config = {}, input) { + const body = { + model: config.model, + input, + encoding_format: String(config.encodingFormat || config.encoding_format || "float"), + }; + const dimensions = Number(config.dimensions ?? config.embeddingDimensions); + if (Number.isFinite(dimensions) && dimensions > 0) { + body.dimensions = Math.floor(dimensions); + } + return body; +} + function readEmbeddingMode(config = {}) { return String(config?.embeddingMode || config?.mode || "direct").trim().toLowerCase(); } @@ -102,14 +142,20 @@ async function requestBackendEmbeddings(config = {}, payload = {}, { signal } = if (!response.ok) { const errorText = await response.text().catch(() => response.statusText); - console.error( - `[ST-BME] Backend Embedding API 错误 (${response.status}):`, - errorText, - ); - return null; + const payload = parseJsonText(errorText); + const message = `Backend Embedding API 错误 (${response.status}): ${readPayloadMessage(payload, response.statusText)}`; + console.error(`[ST-BME] ${message}`, payload); + const error = new Error(message); + error.status = response.status; + error.payload = payload; + throw error; } - return await response.json().catch(() => ({})); + return await response.json().catch((error) => { + throw new Error( + `Backend Embedding API JSON 解析失败: ${error?.message || error}`, + ); + }); } function getEmbeddingBatchSize(config = {}) { @@ -139,25 +185,26 @@ async function requestDirectEmbeddingBatch(texts, config = {}, { signal } = {}) ...(config.apiKey ? { Authorization: "Bearer " + config.apiKey } : {}), }, signal, - body: JSON.stringify({ - model: config.model, - input: texts, - }), + body: JSON.stringify(buildDirectEmbeddingBody(config, texts)), }, getConfiguredTimeoutMs(config), ); if (!response.ok) { const errorText = await response.text().catch(() => response.statusText); - const error = new Error(errorText || response.statusText || "HTTP " + response.status); + const payload = parseJsonText(errorText); + const error = new Error( + `Embedding API 错误 (${response.status}): ${readPayloadMessage(payload, response.statusText)}`, + ); error.status = response.status; + error.payload = payload; throw error; } const data = await response.json().catch(() => ({})); const embeddings = Array.isArray(data?.data) ? data.data : null; if (!embeddings) { - throw new Error("Embedding API 返回格式异常"); + throw new Error(`Embedding API 返回格式异常: ${summarizePayload(data)}`); } const results = new Array(texts.length).fill(null); @@ -179,20 +226,27 @@ async function requestBackendEmbeddingBatch(texts, config = {}, { signal, isQuer ); const vectors = Array.isArray(payload?.vectors) ? payload.vectors : null; if (!vectors) { - throw new Error("Backend Embedding API 返回格式异常"); + throw new Error(`Backend Embedding API 返回格式异常: ${summarizePayload(payload)}`); } return texts.map((_, index) => normalizeVector(vectors[index])); } -async function fallbackEmbedChunkTexts(texts, config = {}, { signal, isQuery = false } = {}) { +async function fallbackEmbedChunkTexts( + texts, + config = {}, + { signal, isQuery = false, collectErrors = null, throwOnFailure = false } = {}, +) { const vectors = []; for (const text of texts) { try { - vectors.push(await embedText(text, config, { signal, isQuery })); + vectors.push(await embedText(text, { ...config, throwOnFailure }, { signal, isQuery })); } catch (error) { if (isAbortError(error)) { throw error; } + if (Array.isArray(collectErrors)) { + collectErrors.push(error?.message || String(error)); + } console.error("[ST-BME] Embedding 单条回退失败:", error); vectors.push(null); } @@ -288,6 +342,9 @@ export async function embedText(text, config, { signal, isQuery = false } = {}) if (isAbortError(e)) { throw e; } + if (config?.throwOnFailure) { + throw e; + } console.error("[ST-BME] Backend Embedding 调用失败:", e); return null; } @@ -311,28 +368,33 @@ export async function embedText(text, config, { signal, isQuery = false } = {}) : {}), }, signal, - body: JSON.stringify({ - model: config.model, - input: text, - }), + body: JSON.stringify(buildDirectEmbeddingBody(config, text)), }, getConfiguredTimeoutMs(config), ); if (!response.ok) { const errorText = await response.text(); - console.error( - `[ST-BME] Embedding API 错误 (${response.status}):`, - errorText, - ); + const payload = parseJsonText(errorText); + const message = `Embedding API 错误 (${response.status}): ${readPayloadMessage(payload, response.statusText)}`; + console.error(`[ST-BME] ${message}`, payload); + if (config?.throwOnFailure) throw new Error(message); return null; } - const data = await response.json(); + const data = await response.json().catch((error) => { + if (config?.throwOnFailure) { + throw new Error(`Embedding API JSON 解析失败: ${error?.message || error}`); + } + return {}; + }); const vector = data?.data?.[0]?.embedding; if (!vector || !Array.isArray(vector)) { console.error("[ST-BME] Embedding API 返回格式异常:", data); + if (config?.throwOnFailure) { + throw new Error(`Embedding API 返回格式异常: ${summarizePayload(data)}`); + } return null; } @@ -341,6 +403,9 @@ export async function embedText(text, config, { signal, isQuery = false } = {}) if (isAbortError(e)) { throw e; } + if (config?.throwOnFailure) { + throw e; + } console.error("[ST-BME] Embedding API 调用失败:", e); return null; } @@ -373,6 +438,7 @@ export async function embedBatch(texts, config, { signal, isQuery = false } = {} } const results = new Array(normalizedTexts.length).fill(null); + const diagnostics = []; const batchSize = getEmbeddingBatchSize(config); for (const chunk of chunkTexts(normalizedTexts, batchSize)) { let vectors = null; @@ -390,12 +456,15 @@ export async function embedBatch(texts, config, { signal, isQuery = false } = {} : "[ST-BME] Embedding API 批量调用失败:", error, ); + diagnostics.push(error?.message || String(error)); } if (!vectors || vectors.length < chunk.texts.length) { vectors = await fallbackEmbedChunkTexts(chunk.texts, config, { signal, isQuery, + collectErrors: diagnostics, + throwOnFailure: Boolean(config?.throwOnEmptyBatch), }); } else { const missingIndexes = []; @@ -411,6 +480,8 @@ export async function embedBatch(texts, config, { signal, isQuery = false } = {} { signal, isQuery, + collectErrors: diagnostics, + throwOnFailure: Boolean(config?.throwOnEmptyBatch), }, ); missingIndexes.forEach((missingIndex, fallbackIndex) => { @@ -424,6 +495,10 @@ export async function embedBatch(texts, config, { signal, isQuery = false } = {} } } + if (config?.throwOnEmptyBatch && !results.some((vector) => vector && vector.length > 0)) { + throw new Error(diagnostics.find(Boolean) || "Embedding API 批量返回空结果"); + } + return results; } /** diff --git a/vector/vector-index.js b/vector/vector-index.js index e26b6bb..79a2c8a 100644 --- a/vector/vector-index.js +++ b/vector/vector-index.js @@ -1746,7 +1746,10 @@ export async function testVectorConnection(config, chatId = "connection-test") { if (isDirectVectorConfig(config) || isBackendVectorConfig(config)) { try { - const vectors = await embedBatch(VECTOR_CONNECTION_PROBE_TEXTS, config, { + const vectors = await embedBatch(VECTOR_CONNECTION_PROBE_TEXTS, { + ...config, + throwOnEmptyBatch: true, + }, { isQuery: true, }); const firstVector = vectors.find((vector) => vector && vector.length > 0);