fix(vector): harden runtime embedding sync

This commit is contained in:
opencode
2026-05-15 09:55:00 +00:00
parent e22f9e4e37
commit be38a2c0d2
9 changed files with 652 additions and 143 deletions

View File

@@ -4,7 +4,7 @@ import {
AuthorityHttpClient,
AuthorityHttpError,
} from "../runtime/authority-http-client.js";
import { embedText } from "./embedding.js";
import { embedBatch } from "./embedding.js";
export const AUTHORITY_VECTOR_MODE = "authority";
export const AUTHORITY_VECTOR_SOURCE = "authority-trivium";
@@ -15,6 +15,7 @@ const MAX_AUTHORITY_VECTOR_CHUNK_SIZE = 2000;
const DEFAULT_AUTHORITY_PURGE_PAGE_SIZE = 200;
const DEFAULT_AUTHORITY_PURGE_MAX_PAGES = 1000;
const DEFAULT_AUTHORITY_EMBEDDING_BACKEND_SOURCE = "openai";
const AUTHORITY_CONNECTION_PROBE_TEXTS = ["test connection", "runtime batch probe"];
function clampInteger(value, fallback, min, max) {
const parsed = Number(value);
@@ -942,9 +943,19 @@ export async function searchAuthorityTriviumNodes(graph, text, config = {}, opti
}
export async function testAuthorityTriviumConnection(config = {}, options = {}) {
const probeVector = await embedText("test connection", config, { isQuery: true });
const probeVectors = await embedBatch(AUTHORITY_CONNECTION_PROBE_TEXTS, config, {
isQuery: true,
});
const probeVector = probeVectors.find((vector) => vector && vector.length > 0);
if (!probeVector || probeVector.length === 0) {
return { success: false, dimensions: 0, error: "Embedding API 返回空结果" };
return {
success: false,
dimensions: 0,
error: "Embedding API 批量返回空结果",
batchCapable: false,
mode: AUTHORITY_VECTOR_MODE,
source: AUTHORITY_VECTOR_SOURCE,
};
}
const client = createAuthorityTriviumClient(config, options);
await callClient(client, ["stat"], "stat", {
@@ -952,5 +963,12 @@ export async function testAuthorityTriviumConnection(config = {}, options = {})
collectionId: options.collectionId,
chatId: options.chatId,
});
return { success: true, dimensions: probeVector.length, error: "" };
return {
success: true,
dimensions: probeVector.length,
error: "",
batchCapable: true,
mode: AUTHORITY_VECTOR_MODE,
source: AUTHORITY_VECTOR_SOURCE,
};
}

View File

@@ -12,6 +12,8 @@ import { resolveConfiguredTimeoutMs } from "../runtime/request-timeout.js";
const MODULE_NAME = "st_bme";
const EMBEDDING_REQUEST_TIMEOUT_MS = 300000;
const DEFAULT_EMBEDDING_BATCH_SIZE = 10;
const MAX_EMBEDDING_BATCH_SIZE = 100;
const BACKEND_SOURCES_REQUIRING_API_URL = new Set([
"ollama",
"llamacpp",
@@ -110,6 +112,94 @@ async function requestBackendEmbeddings(config = {}, payload = {}, { signal } =
return await response.json().catch(() => ({}));
}
function getEmbeddingBatchSize(config = {}) {
const parsed = Number(config?.embeddingBatchSize ?? config?.batchSize);
if (!Number.isFinite(parsed) || parsed <= 0) {
return DEFAULT_EMBEDDING_BATCH_SIZE;
}
return Math.min(MAX_EMBEDDING_BATCH_SIZE, Math.max(1, Math.trunc(parsed)));
}
function chunkTexts(texts = [], size = DEFAULT_EMBEDDING_BATCH_SIZE) {
const chunks = [];
for (let start = 0; start < texts.length; start += size) {
chunks.push({ start, texts: texts.slice(start, start + size) });
}
return chunks;
}
async function requestDirectEmbeddingBatch(texts, config = {}, { signal } = {}) {
const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl);
const response = await fetchWithTimeout(
apiUrl + "/embeddings",
{
method: "POST",
headers: {
"Content-Type": "application/json",
...(config.apiKey ? { Authorization: "Bearer " + config.apiKey } : {}),
},
signal,
body: JSON.stringify({
model: config.model,
input: texts,
}),
},
getConfiguredTimeoutMs(config),
);
if (!response.ok) {
const errorText = await response.text().catch(() => response.statusText);
const error = new Error(errorText || response.statusText || "HTTP " + response.status);
error.status = response.status;
throw error;
}
const data = await response.json().catch(() => ({}));
const embeddings = Array.isArray(data?.data) ? data.data : null;
if (!embeddings) {
throw new Error("Embedding API 返回格式异常");
}
const results = new Array(texts.length).fill(null);
embeddings.forEach((item, order) => {
const rawIndex = Number(item?.index);
const index = Number.isInteger(rawIndex) ? rawIndex : order;
if (index >= 0 && index < results.length) {
results[index] = normalizeVector(item?.embedding);
}
});
return results;
}
async function requestBackendEmbeddingBatch(texts, config = {}, { signal, isQuery = false } = {}) {
const payload = await requestBackendEmbeddings(
config,
{ texts, isQuery },
{ signal },
);
const vectors = Array.isArray(payload?.vectors) ? payload.vectors : null;
if (!vectors) {
throw new Error("Backend Embedding API 返回格式异常");
}
return texts.map((_, index) => normalizeVector(vectors[index]));
}
async function fallbackEmbedChunkTexts(texts, config = {}, { signal, isQuery = false } = {}) {
const vectors = [];
for (const text of texts) {
try {
vectors.push(await embedText(text, config, { signal, isQuery }));
} catch (error) {
if (isAbortError(error)) {
throw error;
}
console.error("[ST-BME] Embedding 单条回退失败:", error);
vectors.push(null);
}
}
return vectors;
}
function createCombinedAbortSignal(...signals) {
const validSignals = signals.filter(Boolean);
if (validSignals.length <= 1) {
@@ -264,91 +354,78 @@ export async function embedText(text, config, { signal, isQuery = false } = {})
* @returns {Promise<(Float64Array|null)[]>}
*/
export async function embedBatch(texts, config, { signal, isQuery = false } = {}) {
const normalizedTexts = Array.isArray(texts)
? texts.map((item) => String(item ?? ""))
: [];
const override = getEmbeddingTestOverride("embedBatch");
if (override) {
return await override(texts, config, { signal, isQuery });
return await override(normalizedTexts, config, { signal, isQuery });
}
if (readEmbeddingMode(config) === "backend") {
if (!texts.length || !config?.model) {
return texts.map(() => null);
}
try {
const payload = await requestBackendEmbeddings(
config,
{ texts, isQuery },
{ signal },
);
const vectors = Array.isArray(payload?.vectors) ? payload.vectors : [];
return texts.map((_, index) => normalizeVector(vectors[index]));
} catch (e) {
if (isAbortError(e)) {
throw e;
}
console.error("[ST-BME] Backend Embedding 批量调用失败:", e);
return texts.map(() => null);
}
if (!normalizedTexts.length) {
return [];
}
const isBackend = readEmbeddingMode(config) === "backend";
const apiUrl = normalizeOpenAICompatibleBaseUrl(config?.apiUrl);
if (!texts.length || !apiUrl || !config?.model) {
return texts.map(() => null);
if (!config?.model || (!isBackend && !apiUrl)) {
return normalizedTexts.map(() => null);
}
try {
const response = await fetchWithTimeout(
`${apiUrl}/embeddings`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
...(config.apiKey
? { Authorization: `Bearer ${config.apiKey}` }
: {}),
},
signal,
body: JSON.stringify({
model: config.model,
input: texts,
}),
},
getConfiguredTimeoutMs(config),
);
if (!response.ok) {
const errorText = await response.text();
console.error(
`[ST-BME] Embedding API 批量错误 (${response.status}):`,
errorText,
);
return texts.map(() => null);
}
const data = await response.json();
const embeddings = data?.data;
if (!Array.isArray(embeddings)) {
return texts.map(() => null);
}
// 按 index 排序API 可能不保证顺序)
embeddings.sort((a, b) => a.index - b.index);
return embeddings.map((item) => {
if (item?.embedding && Array.isArray(item.embedding)) {
return new Float64Array(item.embedding);
const results = new Array(normalizedTexts.length).fill(null);
const batchSize = getEmbeddingBatchSize(config);
for (const chunk of chunkTexts(normalizedTexts, batchSize)) {
let vectors = null;
try {
vectors = isBackend
? await requestBackendEmbeddingBatch(chunk.texts, config, { signal, isQuery })
: await requestDirectEmbeddingBatch(chunk.texts, config, { signal });
} catch (error) {
if (isAbortError(error)) {
throw error;
}
return null;
});
} catch (e) {
if (isAbortError(e)) {
throw e;
console.error(
isBackend
? "[ST-BME] Backend Embedding 批量调用失败:"
: "[ST-BME] Embedding API 批量调用失败:",
error,
);
}
console.error("[ST-BME] Embedding API 批量调用失败:", e);
return texts.map(() => null);
}
}
if (!vectors || vectors.length < chunk.texts.length) {
vectors = await fallbackEmbedChunkTexts(chunk.texts, config, {
signal,
isQuery,
});
} else {
const missingIndexes = [];
for (let index = 0; index < chunk.texts.length; index++) {
if (!vectors[index]) {
missingIndexes.push(index);
}
}
if (missingIndexes.length > 0) {
const fallbackVectors = await fallbackEmbedChunkTexts(
missingIndexes.map((index) => chunk.texts[index]),
config,
{
signal,
isQuery,
},
);
missingIndexes.forEach((missingIndex, fallbackIndex) => {
vectors[missingIndex] = fallbackVectors[fallbackIndex] || null;
});
}
}
for (let index = 0; index < chunk.texts.length; index++) {
results[chunk.start + index] = vectors[index] || null;
}
}
return results;
}
/**
* 计算两个向量的 cosine 相似度
*

View File

@@ -64,6 +64,8 @@ function getConfiguredTimeoutMs(config = {}) {
})();
}
const VECTOR_CONNECTION_PROBE_TEXTS = ["test connection", "runtime batch probe"];
const BACKEND_STATUS_MODEL_SOURCES = {
openai: "openai",
cohere: "cohere",
@@ -1457,15 +1459,69 @@ export async function testVectorConnection(config, chatId = "connection-test") {
return { success: false, dimensions: 0, error: validation.error };
}
if (isDirectVectorConfig(config)) {
if (isDirectVectorConfig(config) || isBackendVectorConfig(config)) {
try {
const vec = await embedText("test connection", config);
if (vec) {
return { success: true, dimensions: vec.length, error: "" };
const vectors = await embedBatch(VECTOR_CONNECTION_PROBE_TEXTS, config, {
isQuery: true,
});
const firstVector = vectors.find((vector) => vector && vector.length > 0);
if (firstVector) {
if (isBackendVectorConfig(config)) {
const response = await fetchWithTimeout(
"/api/vector/query",
{
method: "POST",
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: buildVectorCollectionId(chatId),
searchText: VECTOR_CONNECTION_PROBE_TEXTS[0],
topK: 1,
threshold: 0,
...buildBackendSourceRequest(config),
}),
},
getConfiguredTimeoutMs(config),
);
const payload = await response.text().catch(() => "");
if (!response.ok) {
return {
success: false,
dimensions: firstVector.length,
error: payload || response.statusText,
batchCapable: true,
vectorStoreCapable: false,
mode: config.mode,
source: config.source || "backend",
};
}
}
return {
success: true,
dimensions: firstVector.length,
error: "",
batchCapable: true,
vectorStoreCapable: isBackendVectorConfig(config) ? true : undefined,
mode: config.mode,
source: config.source || "direct",
};
}
return { success: false, dimensions: 0, error: "API 返回空结果" };
return {
success: false,
dimensions: 0,
error: "批量 Embedding API 返回空结果",
batchCapable: false,
mode: config.mode,
source: config.source || "direct",
};
} catch (error) {
return { success: false, dimensions: 0, error: String(error) };
return {
success: false,
dimensions: 0,
error: String(error),
batchCapable: false,
mode: config.mode,
source: config.source || "direct",
};
}
}
@@ -1480,38 +1536,8 @@ export async function testVectorConnection(config, chatId = "connection-test") {
}
}
try {
const response = await fetchWithTimeout(
"/api/vector/query",
{
method: "POST",
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: buildVectorCollectionId(chatId),
searchText: "test connection",
topK: 1,
threshold: 0,
...buildBackendSourceRequest(config),
}),
},
getConfiguredTimeoutMs(config),
);
const payload = await response.text().catch(() => "");
if (!response.ok) {
return {
success: false,
dimensions: 0,
error: payload || response.statusText,
};
}
return { success: true, dimensions: 0, error: "" };
} catch (error) {
return { success: false, dimensions: 0, error: String(error) };
}
return { success: false, dimensions: 0, error: "未知向量配置" };
}
export function getVectorIndexStats(graph) {
const state = graph?.vectorIndexState;
if (!state) {