diff --git a/index.js b/index.js index 9acade6..6d5b29e 100644 --- a/index.js +++ b/index.js @@ -1269,6 +1269,101 @@ function debugPersistedRecallPersistence( ); } +function buildRecallTargetCandidateHashes(candidateTexts = []) { + const hashes = new Set(); + for (const text of candidateTexts) { + const normalized = normalizeRecallInputText(text); + if (!normalized) continue; + const hash = hashRecallInput(normalized); + if (hash) hashes.add(hash); + } + return hashes; +} + +function doesChatUserMessageMatchRecallCandidates(message, candidateHashes) { + if (!message?.is_user || !(candidateHashes instanceof Set) || !candidateHashes.size) { + return false; + } + const normalizedMessage = normalizeRecallInputText(message?.mes || ""); + if (!normalizedMessage) return false; + return candidateHashes.has(hashRecallInput(normalizedMessage)); +} + +function resolveRecallPersistenceTargetUserMessageIndex( + chat, + { + generationType = "normal", + explicitTargetUserMessageIndex = null, + candidateTexts = [], + preferredRecord = null, + } = {}, +) { + if (!Array.isArray(chat) || chat.length === 0) return null; + + const explicitIndex = Number.isFinite(explicitTargetUserMessageIndex) + ? Math.floor(Number(explicitTargetUserMessageIndex)) + : null; + if (Number.isFinite(explicitIndex) && chat[explicitIndex]?.is_user) { + return explicitIndex; + } + + const candidateHashes = buildRecallTargetCandidateHashes(candidateTexts); + const latestUserIndex = resolveGenerationTargetUserMessageIndex(chat, { + generationType: "history", + }); + + const hasFreshPreferredRecord = isFreshRecallInputRecord(preferredRecord); + const preferredMessageId = + hasFreshPreferredRecord && Number.isFinite(preferredRecord?.messageId) + ? Math.floor(Number(preferredRecord.messageId)) + : null; + + if ( + Number.isFinite(preferredMessageId) && + chat[preferredMessageId]?.is_user && + (!candidateHashes.size || + doesChatUserMessageMatchRecallCandidates( + chat[preferredMessageId], + candidateHashes, + )) + ) { + return preferredMessageId; + } + + if ( + candidateHashes.size && + Number.isFinite(latestUserIndex) && + chat[latestUserIndex]?.is_user && + doesChatUserMessageMatchRecallCandidates( + chat[latestUserIndex], + candidateHashes, + ) + ) { + return latestUserIndex; + } + + if (hasFreshPreferredRecord && candidateHashes.size) { + for (let index = chat.length - 1; index >= 0; index--) { + const message = chat[index]; + if ( + doesChatUserMessageMatchRecallCandidates(message, candidateHashes) + ) { + return index; + } + } + } + + if ( + String(generationType || "normal").trim() !== "normal" && + Number.isFinite(latestUserIndex) && + chat[latestUserIndex]?.is_user + ) { + return latestUserIndex; + } + + return null; +} + function persistRecallInjectionRecord({ recallInput = {}, result = {}, @@ -1280,23 +1375,26 @@ function persistRecallInjectionRecord({ const generationType = String(recallInput?.generationType || "normal").trim() || "normal"; - let resolvedTargetIndex = Number.isFinite(recallInput?.targetUserMessageIndex) - ? recallInput.targetUserMessageIndex - : resolveGenerationTargetUserMessageIndex(chat, { generationType }); - - if ( - !Number.isFinite(resolvedTargetIndex) && - Number.isFinite(lastRecallSentUserMessage?.messageId) && - chat[lastRecallSentUserMessage.messageId]?.is_user - ) { - resolvedTargetIndex = lastRecallSentUserMessage.messageId; - } + let resolvedTargetIndex = resolveRecallPersistenceTargetUserMessageIndex( + chat, + { + generationType, + explicitTargetUserMessageIndex: recallInput?.targetUserMessageIndex, + candidateTexts: [ + recallInput?.userMessage, + recallInput?.overrideUserMessage, + lastRecallSentUserMessage?.text, + ], + preferredRecord: lastRecallSentUserMessage, + }, + ); if (!Number.isFinite(resolvedTargetIndex)) { debugPersistedRecallPersistence("目标 user 楼层解析失败", { generationType, explicitTargetUserMessageIndex: recallInput?.targetUserMessageIndex, lastSentUserMessageId: lastRecallSentUserMessage?.messageId, + recallInputSource: String(recallInput?.source || ""), }); return null; } @@ -1338,6 +1436,7 @@ function persistRecallInjectionRecord({ } triggerChatMetadataSave(getContext(), { immediate: false }); + schedulePersistedRecallMessageUiRefresh(); debugPersistedRecallPersistence( "召回记录已写入 user 楼层", { @@ -1567,16 +1666,19 @@ function applyFinalRecallInjectionForGeneration({ return emptyResolution; } - targetUserMessageIndex = resolveGenerationTargetUserMessageIndex(chat, { + targetUserMessageIndex = resolveRecallPersistenceTargetUserMessageIndex(chat, { generationType, + explicitTargetUserMessageIndex: + transaction?.frozenRecallOptions?.targetUserMessageIndex, + candidateTexts: [ + transaction?.frozenRecallOptions?.overrideUserMessage, + recallResult?.recallInput, + recallResult?.userMessage, + recallResult?.sourceCandidates?.[0]?.text, + lastRecallSentUserMessage?.text, + ], + preferredRecord: lastRecallSentUserMessage, }); - if ( - !Number.isFinite(targetUserMessageIndex) && - Number.isFinite(lastRecallSentUserMessage?.messageId) && - chat[lastRecallSentUserMessage.messageId]?.is_user - ) { - targetUserMessageIndex = lastRecallSentUserMessage.messageId; - } const persistedRecord = Number.isFinite(targetUserMessageIndex) ? readPersistedRecallFromUserMessage(chat, targetUserMessageIndex) diff --git a/llm.js b/llm.js index 6aee979..2931f8d 100644 --- a/llm.js +++ b/llm.js @@ -1446,6 +1446,7 @@ export async function callLLMForJSON({ promptMessages = [], debugContext = null, onStreamProgress = null, + maxCompletionTokens = null, returnFailureDetails = false, } = {}) { const override = getLlmTestOverride("callLLMForJSON"); @@ -1461,6 +1462,7 @@ export async function callLLMForJSON({ promptMessages, debugContext, onStreamProgress, + maxCompletionTokens, returnFailureDetails, }); } @@ -1489,7 +1491,9 @@ export async function callLLMForJSON({ taskType, requestSource: privateRequestSource, onStreamProgress, - maxCompletionTokens: DEFAULT_JSON_COMPLETION_TOKENS, + maxCompletionTokens: Number.isFinite(maxCompletionTokens) + ? maxCompletionTokens + : DEFAULT_JSON_COMPLETION_TOKENS, }); const responseText = response?.content || ""; const outputCleanup = applyTaskOutputRegexStages(taskType, responseText); diff --git a/retriever.js b/retriever.js index e669b49..68dc818 100644 --- a/retriever.js +++ b/retriever.js @@ -1327,7 +1327,7 @@ async function llmRecall( const llmResult = await callLLMForJSON({ systemPrompt: resolveTaskLlmSystemPrompt(promptPayload, systemPrompt), userPrompt: promptPayload.userPrompt, - maxRetries: 1, + maxRetries: 2, signal, taskType: "recall", debugContext: createTaskLlmDebugContext( @@ -1337,6 +1337,7 @@ async function llmRecall( promptMessages: promptPayload.promptMessages, additionalMessages: promptPayload.additionalMessages, onStreamProgress, + maxCompletionTokens: Math.max(512, maxNodes * 160), returnFailureDetails: true, }); const result = llmResult?.ok ? llmResult.data : null; diff --git a/tests/p0-regressions.mjs b/tests/p0-regressions.mjs index c3aa759..4b1027a 100644 --- a/tests/p0-regressions.mjs +++ b/tests/p0-regressions.mjs @@ -3532,6 +3532,70 @@ async function testPersistentRecallSourceResolutionAndTargetRouting() { assert.equal(fallback.injectionText, "persisted"); } +async function testGenerationRecallFinalInjectionRebindsLatestMatchingUserFloor() { + { + const harness = await createGenerationRecallHarness({ realApplyFinal: true }); + harness.chat = [ + { is_user: true, mes: "当前输入" }, + { is_user: false, mes: "assistant-tail" }, + ]; + harness.result.recordRecallSentUserMessage(0, "当前输入", "message-sent"); + + const resolution = + harness.result.applyFinalRecallInjectionForGeneration({ + generationType: "normal", + hookName: "GENERATION_AFTER_COMMANDS", + freshRecallResult: { + status: "completed", + didRecall: true, + injectionText: "fresh-memory", + }, + transaction: { + frozenRecallOptions: { + generationType: "normal", + targetUserMessageIndex: null, + overrideUserMessage: "当前输入", + }, + }, + }); + + assert.equal(resolution.targetUserMessageIndex, 0); + } + + { + const harness = await createGenerationRecallHarness({ realApplyFinal: true }); + harness.chat = [ + { is_user: true, mes: "尾部 user 仍可匹配" }, + { is_user: false, mes: "assistant-tail" }, + ]; + + const resolution = + harness.result.applyFinalRecallInjectionForGeneration({ + generationType: "normal", + hookName: "GENERATION_AFTER_COMMANDS", + freshRecallResult: { + status: "completed", + didRecall: true, + injectionText: "fresh-memory", + sourceCandidates: [ + { + text: "尾部 user 仍可匹配", + }, + ], + }, + transaction: { + frozenRecallOptions: { + generationType: "normal", + targetUserMessageIndex: null, + overrideUserMessage: "尾部 user 仍可匹配", + }, + }, + }); + + assert.equal(resolution.targetUserMessageIndex, 0); + } +} + async function testRecallSubGraphAndDataLayerEntryPoints() { // Sub-graph build test (pure function, no DOM needed) const { buildRecallSubGraph } = await import("../recall-message-ui.js"); @@ -4340,6 +4404,7 @@ await testGenerationRecallAppliesFinalInjectionOncePerTransaction(); await testGenerationRecallDeferredRewriteMutatesFinalMesSendPayload(); await testPersistentRecallDataLayerLifecycleAndCompatibility(); await testPersistentRecallSourceResolutionAndTargetRouting(); +await testGenerationRecallFinalInjectionRebindsLatestMatchingUserFloor(); await testRecallCardMountsOnStandardUserMessageDom(); await testRecallCardSkipsMountWithoutStableMessageIndex(); await testRecallCardDelayedDomInsertionEventuallyRenders(); diff --git a/tests/retrieval-config.mjs b/tests/retrieval-config.mjs index ebcd7da..052f724 100644 --- a/tests/retrieval-config.mjs +++ b/tests/retrieval-config.mjs @@ -279,6 +279,8 @@ assert.equal(llmPoolResult.meta.retrieval.diversityApplied, true); assert.equal(llmPoolResult.meta.retrieval.candidatePoolBeforeDpp, 3); assert.equal(llmPoolResult.meta.retrieval.candidatePoolAfterDpp, 2); assert.equal(state.llmOptions[0].returnFailureDetails, true); +assert.equal(state.llmOptions[0].maxRetries, 2); +assert.equal(state.llmOptions[0].maxCompletionTokens, 512); state.vectorCalls.length = 0; state.diffusionCalls.length = 0;