Compare commits
6 Commits
v2026.2.21
...
pr-18304
| Author | SHA1 | Date |
|---|---|---|
| | a8b29d5df7 | |
| | ccbd315ab9 | |
| | de98782c37 | |
| | 29e6cf24e6 | |
| | 162d4ea7ad | |
| | 4a11caec4d | |
10
.env.example
10
.env.example
@@ -37,6 +37,16 @@ OPENCLAW_GATEWAY_TOKEN=change-me-to-a-long-random-token
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
# GEMINI_API_KEY=...
|
||||
# OPENROUTER_API_KEY=sk-or-...
|
||||
# OPENCLAW_LIVE_OPENAI_KEY=sk-...
|
||||
# OPENCLAW_LIVE_ANTHROPIC_KEY=sk-ant-...
|
||||
# OPENCLAW_LIVE_GEMINI_KEY=...
|
||||
# OPENAI_API_KEY_1=...
|
||||
# ANTHROPIC_API_KEY_1=...
|
||||
# GEMINI_API_KEY_1=...
|
||||
# GOOGLE_API_KEY=...
|
||||
# OPENAI_API_KEYS=sk-1,sk-2
|
||||
# ANTHROPIC_API_KEYS=sk-ant-1,sk-ant-2
|
||||
# GEMINI_API_KEYS=key-1,key-2
|
||||
|
||||
# Optional additional providers
|
||||
# ZAI_API_KEY=...
|
||||
|
||||
@@ -56,6 +56,7 @@ Docs: https://docs.openclaw.ai
|
||||
- Memory/QMD: scope managed collection names per agent and precreate glob-backed collection directories before registration, preventing cross-agent collection clobbering and startup ENOENT failures in fresh workspaces. (#17194) Thanks @jonathanadams96.
|
||||
- Auto-reply/WhatsApp/TUI/Web: when a final assistant message is `NO_REPLY` and a messaging tool send succeeded, mirror the delivered messaging-tool text into session-visible assistant output so TUI/Web no longer show `NO_REPLY` placeholders. (#7010) Thanks @Morrowind-Xie.
|
||||
- Cron: infer `payload.kind="agentTurn"` for model-only `cron.update` payload patches, so partial agent-turn updates do not fail validation when `kind` is omitted. (#15664) Thanks @rodrigouroz.
|
||||
- Memory/FTS: in embedding-provider fallback mode, ensure file/session indexing still writes and refreshes FTS rows so first-run memory search works and stale removed entries are removed from FTS too. Thanks @irchelper.
|
||||
- TUI: make searchable-select filtering and highlight rendering ANSI-aware so queries ignore hidden escape codes and no longer corrupt ANSI styling sequences during match highlighting. (#4519) Thanks @bee4come.
|
||||
- TUI/Windows: coalesce rapid single-line submit bursts in Git Bash into one multiline message as a fallback when bracketed paste is unavailable, preventing pasted multiline text from being split into multiple sends. (#4986) Thanks @adamkane.
|
||||
- TUI: suppress false `(no output)` placeholders for non-local empty final events during concurrent runs, preventing external-channel replies from showing empty assistant bubbles while a local run is still streaming. (#5782) Thanks @LagWizard and @vignesh07.
|
||||
|
||||
@@ -17,6 +17,20 @@ For model selection rules, see [/concepts/models](/concepts/models).
|
||||
- If you set `agents.defaults.models`, it becomes the allowlist.
|
||||
- CLI helpers: `openclaw onboard`, `openclaw models list`, `openclaw models set <provider/model>`.
|
||||
|
||||
## API key rotation
|
||||
|
||||
- Supports generic provider rotation for selected providers.
|
||||
- Configure multiple keys via:
|
||||
- `OPENCLAW_LIVE_<PROVIDER>_KEY` (single live override, highest priority)
|
||||
- `<PROVIDER>_API_KEYS` (comma or semicolon list)
|
||||
- `<PROVIDER>_API_KEY` (primary key)
|
||||
- `<PROVIDER>_API_KEY_*` (numbered list, e.g. `<PROVIDER>_API_KEY_1`)
|
||||
- For Google providers, `GOOGLE_API_KEY` is also included as fallback.
|
||||
- Key selection order preserves priority and deduplicates values.
|
||||
- Requests are retried with the next key only on rate-limit responses (for example `429`, `rate_limit`, `quota`, `resource exhausted`).
|
||||
- Non-rate-limit failures fail immediately; no key rotation is attempted.
|
||||
- When all candidate keys fail, the final error is returned from the last attempt.
|
||||
|
||||
## Built-in providers (pi-ai catalog)
|
||||
|
||||
OpenClaw ships with the pi‑ai catalog. These providers require **no**
|
||||
@@ -26,6 +40,7 @@ OpenClaw ships with the pi‑ai catalog. These providers require **no**
|
||||
|
||||
- Provider: `openai`
|
||||
- Auth: `OPENAI_API_KEY`
|
||||
- Optional rotation: `OPENAI_API_KEYS`, `OPENAI_API_KEY_1`, `OPENAI_API_KEY_2`, plus `OPENCLAW_LIVE_OPENAI_KEY` (single override)
|
||||
- Example model: `openai/gpt-5.1-codex`
|
||||
- CLI: `openclaw onboard --auth-choice openai-api-key`
|
||||
|
||||
@@ -39,6 +54,7 @@ OpenClaw ships with the pi‑ai catalog. These providers require **no**
|
||||
|
||||
- Provider: `anthropic`
|
||||
- Auth: `ANTHROPIC_API_KEY` or `claude setup-token`
|
||||
- Optional rotation: `ANTHROPIC_API_KEYS`, `ANTHROPIC_API_KEY_1`, `ANTHROPIC_API_KEY_2`, plus `OPENCLAW_LIVE_ANTHROPIC_KEY` (single override)
|
||||
- Example model: `anthropic/claude-opus-4-6`
|
||||
- CLI: `openclaw onboard --auth-choice token` (paste setup-token) or `openclaw models auth paste-token --provider anthropic`
|
||||
|
||||
@@ -78,6 +94,7 @@ OpenClaw ships with the pi‑ai catalog. These providers require **no**
|
||||
|
||||
- Provider: `google`
|
||||
- Auth: `GEMINI_API_KEY`
|
||||
- Optional rotation: `GEMINI_API_KEYS`, `GEMINI_API_KEY_1`, `GEMINI_API_KEY_2`, `GOOGLE_API_KEY` fallback, and `OPENCLAW_LIVE_GEMINI_KEY` (single override)
|
||||
- Example model: `google/gemini-3-pro-preview`
|
||||
- CLI: `openclaw onboard --auth-choice gemini-api-key`
|
||||
|
||||
|
||||
@@ -103,6 +103,23 @@ openclaw models status
|
||||
openclaw doctor
|
||||
```
|
||||
|
||||
## API key rotation behavior (gateway)
|
||||
|
||||
Some providers support retrying a request with alternative keys when an API call
|
||||
hits a provider rate limit.
|
||||
|
||||
- Priority order:
|
||||
- `OPENCLAW_LIVE_<PROVIDER>_KEY` (single override)
|
||||
- `<PROVIDER>_API_KEYS`
|
||||
- `<PROVIDER>_API_KEY`
|
||||
- `<PROVIDER>_API_KEY_*`
|
||||
- Google providers also include `GOOGLE_API_KEY` as an additional fallback.
|
||||
- The same key list is deduplicated before use.
|
||||
- OpenClaw retries with the next key only for rate-limit errors (for example
|
||||
`429`, `rate_limit`, `quota`, `resource exhausted`).
|
||||
- Non-rate-limit errors are not retried with alternate keys.
|
||||
- If all keys fail, the final error from the last attempt is returned.
|
||||
|
||||
## Controlling which credential is used
|
||||
|
||||
### Per-session (chat command)
|
||||
|
||||
@@ -91,7 +91,7 @@ Think of the suites as “increasing realism” (and increasing flakiness/cost):
|
||||
- Costs money / uses rate limits
|
||||
- Prefer running narrowed subsets instead of “everything”
|
||||
- Live runs will source `~/.profile` to pick up missing API keys
|
||||
- Anthropic key rotation: set `OPENCLAW_LIVE_ANTHROPIC_KEYS="sk-...,sk-..."` (or `OPENCLAW_LIVE_ANTHROPIC_KEY=sk-...`) or multiple `ANTHROPIC_API_KEY*` vars; tests will retry on rate limits
|
||||
- API key rotation (provider-specific): set `*_API_KEYS` with comma/semicolon format or `*_API_KEY_1`, `*_API_KEY_2` (for example `OPENAI_API_KEYS`, `ANTHROPIC_API_KEYS`, `GEMINI_API_KEYS`) or per-live override via `OPENCLAW_LIVE_*_KEY`; tests retry on rate limit responses.
|
||||
|
||||
## Which suite should I run?
|
||||
|
||||
|
||||
72
src/agents/api-key-rotation.ts
Normal file
72
src/agents/api-key-rotation.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import { formatErrorMessage } from "../infra/errors.js";
|
||||
import { collectProviderApiKeys, isApiKeyRateLimitError } from "./live-auth-keys.js";
|
||||
|
||||
/** Context handed to retry callbacks after a single failed attempt. */
type ApiKeyRetryParams = {
  apiKey: string; // the key used for this attempt
  error: unknown; // the error thrown by execute()
  attempt: number; // zero-based index into the key list
};

/** Options for executeWithApiKeyRotation. */
type ExecuteWithApiKeyRotationOptions<T> = {
  provider: string; // provider id, used in error messages
  apiKeys: string[]; // candidate keys, tried in order (deduplicated internally)
  execute: (apiKey: string) => Promise<T>; // the request to run with each key
  // Custom retry predicate; when omitted, rate-limit detection on the
  // formatted error message decides whether to advance to the next key.
  shouldRetry?: (params: ApiKeyRetryParams & { message: string }) => boolean;
  // Invoked just before retrying with the next key (e.g. for logging).
  onRetry?: (params: ApiKeyRetryParams & { message: string }) => void;
};
|
||||
|
||||
function dedupeApiKeys(raw: string[]): string[] {
|
||||
const seen = new Set<string>();
|
||||
const keys: string[] = [];
|
||||
for (const value of raw) {
|
||||
const apiKey = value.trim();
|
||||
if (!apiKey || seen.has(apiKey)) {
|
||||
continue;
|
||||
}
|
||||
seen.add(apiKey);
|
||||
keys.push(apiKey);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
export function collectProviderApiKeysForExecution(params: {
|
||||
provider: string;
|
||||
primaryApiKey?: string;
|
||||
}): string[] {
|
||||
const { primaryApiKey, provider } = params;
|
||||
return dedupeApiKeys([primaryApiKey?.trim() ?? "", ...collectProviderApiKeys(provider)]);
|
||||
}
|
||||
|
||||
export async function executeWithApiKeyRotation<T>(
|
||||
params: ExecuteWithApiKeyRotationOptions<T>,
|
||||
): Promise<T> {
|
||||
const keys = dedupeApiKeys(params.apiKeys);
|
||||
if (keys.length === 0) {
|
||||
throw new Error(`No API keys configured for provider "${params.provider}".`);
|
||||
}
|
||||
|
||||
let lastError: unknown;
|
||||
for (let attempt = 0; attempt < keys.length; attempt += 1) {
|
||||
const apiKey = keys[attempt];
|
||||
try {
|
||||
return await params.execute(apiKey);
|
||||
} catch (error) {
|
||||
lastError = error;
|
||||
const message = formatErrorMessage(error);
|
||||
const retryable = params.shouldRetry
|
||||
? params.shouldRetry({ apiKey, error, attempt, message })
|
||||
: isApiKeyRateLimitError(message);
|
||||
|
||||
if (!retryable || attempt + 1 >= keys.length) {
|
||||
break;
|
||||
}
|
||||
|
||||
params.onRetry?.({ apiKey, error, attempt, message });
|
||||
}
|
||||
}
|
||||
|
||||
if (lastError === undefined) {
|
||||
throw new Error(`Failed to run API request for ${params.provider}.`);
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
@@ -1,4 +1,47 @@
|
||||
import { normalizeProviderId } from "./model-selection.js";
|
||||
|
||||
const KEY_SPLIT_RE = /[\s,;]+/g;
|
||||
const GOOGLE_LIVE_SINGLE_KEY = "OPENCLAW_LIVE_GEMINI_KEY";
|
||||
|
||||
const PROVIDER_PREFIX_OVERRIDES: Record<string, string> = {
|
||||
google: "GEMINI",
|
||||
"google-vertex": "GEMINI",
|
||||
};
|
||||
|
||||
type ProviderApiKeyConfig = {
|
||||
liveSingle?: string;
|
||||
listVar?: string;
|
||||
primaryVar?: string;
|
||||
prefixedVar?: string;
|
||||
fallbackVars: string[];
|
||||
};
|
||||
|
||||
const PROVIDER_API_KEY_CONFIG: Record<string, Omit<ProviderApiKeyConfig, "fallbackVars">> = {
|
||||
anthropic: {
|
||||
liveSingle: "OPENCLAW_LIVE_ANTHROPIC_KEY",
|
||||
listVar: "OPENCLAW_LIVE_ANTHROPIC_KEYS",
|
||||
primaryVar: "ANTHROPIC_API_KEY",
|
||||
prefixedVar: "ANTHROPIC_API_KEY_",
|
||||
},
|
||||
google: {
|
||||
liveSingle: GOOGLE_LIVE_SINGLE_KEY,
|
||||
listVar: "GEMINI_API_KEYS",
|
||||
primaryVar: "GEMINI_API_KEY",
|
||||
prefixedVar: "GEMINI_API_KEY_",
|
||||
},
|
||||
"google-vertex": {
|
||||
liveSingle: GOOGLE_LIVE_SINGLE_KEY,
|
||||
listVar: "GEMINI_API_KEYS",
|
||||
primaryVar: "GEMINI_API_KEY",
|
||||
prefixedVar: "GEMINI_API_KEY_",
|
||||
},
|
||||
openai: {
|
||||
liveSingle: "OPENCLAW_LIVE_OPENAI_KEY",
|
||||
listVar: "OPENAI_API_KEYS",
|
||||
primaryVar: "OPENAI_API_KEY",
|
||||
prefixedVar: "OPENAI_API_KEY_",
|
||||
},
|
||||
};
|
||||
|
||||
function parseKeyList(raw?: string | null): string[] {
|
||||
if (!raw) {
|
||||
@@ -25,17 +68,53 @@ function collectEnvPrefixedKeys(prefix: string): string[] {
|
||||
return keys;
|
||||
}
|
||||
|
||||
export function collectAnthropicApiKeys(): string[] {
|
||||
const forcedSingle = process.env.OPENCLAW_LIVE_ANTHROPIC_KEY?.trim();
|
||||
function resolveProviderApiKeyConfig(provider: string): ProviderApiKeyConfig {
|
||||
const normalized = normalizeProviderId(provider);
|
||||
const custom = PROVIDER_API_KEY_CONFIG[normalized];
|
||||
const base = PROVIDER_PREFIX_OVERRIDES[normalized] ?? normalized.toUpperCase().replace(/-/g, "_");
|
||||
|
||||
const liveSingle = custom?.liveSingle ?? `OPENCLAW_LIVE_${base}_KEY`;
|
||||
const listVar = custom?.listVar ?? `${base}_API_KEYS`;
|
||||
const primaryVar = custom?.primaryVar ?? `${base}_API_KEY`;
|
||||
const prefixedVar = custom?.prefixedVar ?? `${base}_API_KEY_`;
|
||||
|
||||
if (normalized === "google" || normalized === "google-vertex") {
|
||||
return {
|
||||
liveSingle,
|
||||
listVar,
|
||||
primaryVar,
|
||||
prefixedVar,
|
||||
fallbackVars: ["GOOGLE_API_KEY"],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
liveSingle,
|
||||
listVar,
|
||||
primaryVar,
|
||||
prefixedVar,
|
||||
fallbackVars: [],
|
||||
};
|
||||
}
|
||||
|
||||
export function collectProviderApiKeys(provider: string): string[] {
|
||||
const config = resolveProviderApiKeyConfig(provider);
|
||||
|
||||
const forcedSingle = config.liveSingle ? process.env[config.liveSingle]?.trim() : undefined;
|
||||
if (forcedSingle) {
|
||||
return [forcedSingle];
|
||||
}
|
||||
|
||||
const fromList = parseKeyList(process.env.OPENCLAW_LIVE_ANTHROPIC_KEYS);
|
||||
const fromEnv = collectEnvPrefixedKeys("ANTHROPIC_API_KEY");
|
||||
const primary = process.env.ANTHROPIC_API_KEY?.trim();
|
||||
const fromList = parseKeyList(config.listVar ? process.env[config.listVar] : undefined);
|
||||
const primary = config.primaryVar ? process.env[config.primaryVar]?.trim() : undefined;
|
||||
const fromPrefixed = config.prefixedVar ? collectEnvPrefixedKeys(config.prefixedVar) : [];
|
||||
|
||||
const fallback = config.fallbackVars
|
||||
.map((envVar) => process.env[envVar]?.trim())
|
||||
.filter(Boolean) as string[];
|
||||
|
||||
const seen = new Set<string>();
|
||||
|
||||
const add = (value?: string) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -49,17 +128,26 @@ export function collectAnthropicApiKeys(): string[] {
|
||||
for (const value of fromList) {
|
||||
add(value);
|
||||
}
|
||||
if (primary) {
|
||||
add(primary);
|
||||
add(primary);
|
||||
for (const value of fromPrefixed) {
|
||||
add(value);
|
||||
}
|
||||
for (const value of fromEnv) {
|
||||
for (const value of fallback) {
|
||||
add(value);
|
||||
}
|
||||
|
||||
return Array.from(seen);
|
||||
}
|
||||
|
||||
export function isAnthropicRateLimitError(message: string): boolean {
|
||||
export function collectAnthropicApiKeys(): string[] {
|
||||
return collectProviderApiKeys("anthropic");
|
||||
}
|
||||
|
||||
export function collectGeminiApiKeys(): string[] {
|
||||
return collectProviderApiKeys("google");
|
||||
}
|
||||
|
||||
export function isApiKeyRateLimitError(message: string): boolean {
|
||||
const lower = message.toLowerCase();
|
||||
if (lower.includes("rate_limit")) {
|
||||
return true;
|
||||
@@ -70,9 +158,22 @@ export function isAnthropicRateLimitError(message: string): boolean {
|
||||
if (lower.includes("429")) {
|
||||
return true;
|
||||
}
|
||||
if (lower.includes("quota exceeded") || lower.includes("quota_exceeded")) {
|
||||
return true;
|
||||
}
|
||||
if (lower.includes("resource exhausted") || lower.includes("resource_exhausted")) {
|
||||
return true;
|
||||
}
|
||||
if (lower.includes("too many requests")) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export function isAnthropicRateLimitError(message: string): boolean {
|
||||
return isApiKeyRateLimitError(message);
|
||||
}
|
||||
|
||||
export function isAnthropicBillingError(message: string): boolean {
|
||||
const lower = message.toLowerCase();
|
||||
if (lower.includes("credit balance")) {
|
||||
@@ -91,7 +192,7 @@ export function isAnthropicBillingError(message: string): boolean {
|
||||
return true;
|
||||
}
|
||||
if (
|
||||
/["']?(?:status|code)["']?\s*[:=]\s*402\b|\bhttp\s*402\b|\berror(?:\s+code)?\s*[:=]?\s*402\b|\b(?:got|returned|received)\s+(?:a\s+)?402\b|^\s*402\s+payment/i.test(
|
||||
/["']?(?:status|code)["']?\s*[:=]\s*402\b|\bhttp\s*402\b|\berror(?:\s+code)?\s*[:=]?\s*402\b|\b(?:got|returned|received)\s+(?:a\s+)?402\b|^\s*402\spayment/i.test(
|
||||
lower,
|
||||
)
|
||||
) {
|
||||
|
||||
@@ -81,12 +81,14 @@ export function createMemorySearchTool(options: {
|
||||
status.backend === "qmd"
|
||||
? clampResultsByInjectedChars(decorated, resolved.qmd?.limits.maxInjectedChars)
|
||||
: decorated;
|
||||
const searchMode = (status.custom as { searchMode?: string } | undefined)?.searchMode;
|
||||
return jsonResult({
|
||||
results,
|
||||
provider: status.provider,
|
||||
model: status.model,
|
||||
fallback: status.fallback,
|
||||
citations: citationsMode,
|
||||
mode: searchMode,
|
||||
});
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
|
||||
40
src/infra/gemini-auth.ts
Normal file
40
src/infra/gemini-auth.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Shared Gemini authentication utilities.
|
||||
*
|
||||
* Supports both traditional API keys and OAuth JSON format.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Parse Gemini API key and return appropriate auth headers.
|
||||
*
|
||||
* OAuth format: `{"token": "...", "projectId": "..."}`
|
||||
*
|
||||
* @param apiKey - Either a traditional API key string or OAuth JSON
|
||||
* @returns Headers object with appropriate authentication
|
||||
*/
|
||||
export function parseGeminiAuth(apiKey: string): { headers: Record<string, string> } {
|
||||
// Try parsing as OAuth JSON format
|
||||
if (apiKey.startsWith("{")) {
|
||||
try {
|
||||
const parsed = JSON.parse(apiKey) as { token?: string; projectId?: string };
|
||||
if (typeof parsed.token === "string" && parsed.token) {
|
||||
return {
|
||||
headers: {
|
||||
Authorization: `Bearer ${parsed.token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
};
|
||||
}
|
||||
} catch {
|
||||
// Parse failed, fallback to API key mode
|
||||
}
|
||||
}
|
||||
|
||||
// Default: traditional API key
|
||||
return {
|
||||
headers: {
|
||||
"x-goog-api-key": apiKey,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js";
|
||||
import { parseGeminiAuth } from "../../../infra/gemini-auth.js";
|
||||
import { assertOkOrThrowHttpError, fetchWithTimeoutGuarded, normalizeBaseUrl } from "../shared.js";
|
||||
|
||||
export async function generateGeminiInlineDataText(params: {
|
||||
@@ -30,12 +31,12 @@ export async function generateGeminiInlineDataText(params: {
|
||||
})();
|
||||
const url = `${baseUrl}/models/${model}:generateContent`;
|
||||
|
||||
const authHeaders = parseGeminiAuth(params.apiKey);
|
||||
const headers = new Headers(params.headers);
|
||||
if (!headers.has("content-type")) {
|
||||
headers.set("content-type", "application/json");
|
||||
}
|
||||
if (!headers.has("x-goog-api-key")) {
|
||||
headers.set("x-goog-api-key", params.apiKey);
|
||||
for (const [key, value] of Object.entries(authHeaders.headers)) {
|
||||
if (!headers.has(key)) {
|
||||
headers.set(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
const prompt = (() => {
|
||||
|
||||
@@ -14,6 +14,10 @@ import type {
|
||||
MediaUnderstandingOutput,
|
||||
MediaUnderstandingProvider,
|
||||
} from "./types.js";
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { applyTemplate } from "../auto-reply/templating.js";
|
||||
import { logVerbose, shouldLogVerbose } from "../globals.js";
|
||||
@@ -392,7 +396,10 @@ export async function runProviderEntry(params: {
|
||||
preferredProfile: entry.preferredProfile,
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
const apiKey = requireApiKey(auth, providerId);
|
||||
const apiKeys = collectProviderApiKeysForExecution({
|
||||
provider: providerId,
|
||||
primaryApiKey: requireApiKey(auth, providerId),
|
||||
});
|
||||
const providerConfig = cfg.models?.providers?.[providerId];
|
||||
const baseUrl = entry.baseUrl ?? params.config?.baseUrl ?? providerConfig?.baseUrl;
|
||||
const mergedHeaders = {
|
||||
@@ -407,18 +414,23 @@ export async function runProviderEntry(params: {
|
||||
entry,
|
||||
});
|
||||
const model = entry.model?.trim() || DEFAULT_AUDIO_MODELS[providerId] || entry.model;
|
||||
const result = await provider.transcribeAudio({
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey,
|
||||
baseUrl,
|
||||
headers,
|
||||
model,
|
||||
language: entry.language ?? params.config?.language ?? cfg.tools?.media?.audio?.language,
|
||||
prompt,
|
||||
query: providerQuery,
|
||||
timeoutMs,
|
||||
const result = await executeWithApiKeyRotation({
|
||||
provider: providerId,
|
||||
apiKeys,
|
||||
execute: async (apiKey) =>
|
||||
provider.transcribeAudio({
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey,
|
||||
baseUrl,
|
||||
headers,
|
||||
model,
|
||||
language: entry.language ?? params.config?.language ?? cfg.tools?.media?.audio?.language,
|
||||
prompt,
|
||||
query: providerQuery,
|
||||
timeoutMs,
|
||||
}),
|
||||
});
|
||||
return {
|
||||
kind: "audio.transcription",
|
||||
@@ -452,18 +464,26 @@ export async function runProviderEntry(params: {
|
||||
preferredProfile: entry.preferredProfile,
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
const apiKey = requireApiKey(auth, providerId);
|
||||
const apiKeys = collectProviderApiKeysForExecution({
|
||||
provider: providerId,
|
||||
primaryApiKey: requireApiKey(auth, providerId),
|
||||
});
|
||||
const providerConfig = cfg.models?.providers?.[providerId];
|
||||
const result = await provider.describeVideo({
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey,
|
||||
baseUrl: providerConfig?.baseUrl,
|
||||
headers: providerConfig?.headers,
|
||||
model: entry.model,
|
||||
prompt,
|
||||
timeoutMs,
|
||||
const result = await executeWithApiKeyRotation({
|
||||
provider: providerId,
|
||||
apiKeys,
|
||||
execute: (apiKey) =>
|
||||
provider.describeVideo({
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey,
|
||||
baseUrl: providerConfig?.baseUrl,
|
||||
headers: providerConfig?.headers,
|
||||
model: entry.model,
|
||||
prompt,
|
||||
timeoutMs,
|
||||
}),
|
||||
});
|
||||
return {
|
||||
kind: "video.description",
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { isTruthyEnvValue } from "../infra/env.js";
|
||||
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
||||
import { createSubsystemLogger } from "../logging/subsystem.js";
|
||||
|
||||
export type GeminiEmbeddingClient = {
|
||||
@@ -8,6 +13,7 @@ export type GeminiEmbeddingClient = {
|
||||
headers: Record<string, string>;
|
||||
model: string;
|
||||
modelPath: string;
|
||||
apiKeys: string[];
|
||||
};
|
||||
|
||||
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
|
||||
@@ -73,23 +79,40 @@ export async function createGeminiEmbeddingProvider(
|
||||
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
|
||||
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
const res = await fetch(embedUrl, {
|
||||
const fetchWithGeminiAuth = async (apiKey: string, endpoint: string, body: unknown) => {
|
||||
const authHeaders = parseGeminiAuth(apiKey);
|
||||
const headers = {
|
||||
...authHeaders.headers,
|
||||
...client.headers,
|
||||
};
|
||||
const res = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers: client.headers,
|
||||
body: JSON.stringify({
|
||||
content: { parts: [{ text }] },
|
||||
taskType: "RETRIEVAL_QUERY",
|
||||
}),
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const payload = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
|
||||
}
|
||||
const payload = (await res.json()) as { embedding?: { values?: number[] } };
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
const payload = await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: client.apiKeys,
|
||||
execute: (apiKey) =>
|
||||
fetchWithGeminiAuth(apiKey, embedUrl, {
|
||||
content: { parts: [{ text }] },
|
||||
taskType: "RETRIEVAL_QUERY",
|
||||
}),
|
||||
});
|
||||
return payload.embedding?.values ?? [];
|
||||
};
|
||||
|
||||
@@ -102,16 +125,14 @@ export async function createGeminiEmbeddingProvider(
|
||||
content: { parts: [{ text }] },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
}));
|
||||
const res = await fetch(batchUrl, {
|
||||
method: "POST",
|
||||
headers: client.headers,
|
||||
body: JSON.stringify({ requests }),
|
||||
const payload = await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: client.apiKeys,
|
||||
execute: (apiKey) =>
|
||||
fetchWithGeminiAuth(apiKey, batchUrl, {
|
||||
requests,
|
||||
}),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const payload = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
|
||||
}
|
||||
const payload = (await res.json()) as { embeddings?: Array<{ values?: number[] }> };
|
||||
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
|
||||
return texts.map((_, index) => embeddings[index]?.values ?? []);
|
||||
};
|
||||
@@ -151,10 +172,12 @@ export async function resolveGeminiEmbeddingClient(
|
||||
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": apiKey,
|
||||
...headerOverrides,
|
||||
};
|
||||
const apiKeys = collectProviderApiKeysForExecution({
|
||||
provider: "google",
|
||||
primaryApiKey: apiKey,
|
||||
});
|
||||
const model = normalizeGeminiModel(options.model);
|
||||
const modelPath = buildGeminiModelPath(model);
|
||||
debugLog("memory embeddings: gemini client", {
|
||||
@@ -165,5 +188,5 @@ export async function resolveGeminiEmbeddingClient(
|
||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||
});
|
||||
return { baseUrl, headers, model, modelPath };
|
||||
return { baseUrl, headers, model, modelPath, apiKeys };
|
||||
}
|
||||
|
||||
@@ -459,3 +459,63 @@ describe("local embedding normalization", () => {
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("FTS-only fallback when no provider available", () => {
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("returns null provider with reason when auto mode finds no providers", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error('No API key found for provider "openai"'),
|
||||
);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
expect(result.provider).toBeNull();
|
||||
expect(result.requestedProvider).toBe("auto");
|
||||
expect(result.providerUnavailableReason).toBeDefined();
|
||||
expect(result.providerUnavailableReason).toContain("No API key");
|
||||
});
|
||||
|
||||
it("returns null provider when explicit provider fails with missing API key", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error('No API key found for provider "openai"'),
|
||||
);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
expect(result.provider).toBeNull();
|
||||
expect(result.requestedProvider).toBe("openai");
|
||||
expect(result.providerUnavailableReason).toBeDefined();
|
||||
});
|
||||
|
||||
it("returns null provider when both primary and fallback fail with missing API keys", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error("No API key found for provider"),
|
||||
);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini",
|
||||
});
|
||||
|
||||
expect(result.provider).toBeNull();
|
||||
expect(result.requestedProvider).toBe("openai");
|
||||
expect(result.fallbackFrom).toBe("openai");
|
||||
expect(result.providerUnavailableReason).toContain("Fallback to gemini failed");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -36,10 +36,11 @@ export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
|
||||
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage"] as const;
|
||||
|
||||
export type EmbeddingProviderResult = {
|
||||
provider: EmbeddingProvider;
|
||||
provider: EmbeddingProvider | null;
|
||||
requestedProvider: EmbeddingProviderRequest;
|
||||
fallbackFrom?: EmbeddingProviderId;
|
||||
fallbackReason?: string;
|
||||
providerUnavailableReason?: string;
|
||||
openAi?: OpenAiEmbeddingClient;
|
||||
gemini?: GeminiEmbeddingClient;
|
||||
voyage?: VoyageEmbeddingClient;
|
||||
@@ -183,15 +184,19 @@ export async function createEmbeddingProvider(
|
||||
missingKeyErrors.push(message);
|
||||
continue;
|
||||
}
|
||||
// Non-auth errors (e.g., network) are still fatal
|
||||
throw new Error(message, { cause: err });
|
||||
}
|
||||
}
|
||||
|
||||
// All providers failed due to missing API keys - return null provider for FTS-only mode
|
||||
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
|
||||
if (details.length > 0) {
|
||||
throw new Error(details.join("\n\n"));
|
||||
}
|
||||
throw new Error("No embeddings provider available.");
|
||||
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -209,13 +214,31 @@ export async function createEmbeddingProvider(
|
||||
fallbackReason: reason,
|
||||
};
|
||||
} catch (fallbackErr) {
|
||||
// oxlint-disable-next-line preserve-caught-error
|
||||
throw new Error(
|
||||
`${reason}\n\nFallback to ${fallback} failed: ${formatErrorMessage(fallbackErr)}`,
|
||||
{ cause: fallbackErr },
|
||||
);
|
||||
// Both primary and fallback failed - check if it's auth-related
|
||||
const fallbackReason = formatErrorMessage(fallbackErr);
|
||||
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
|
||||
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
|
||||
// Both failed due to missing API keys - return null for FTS-only mode
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
fallbackFrom: requestedProvider,
|
||||
fallbackReason: reason,
|
||||
providerUnavailableReason: combinedReason,
|
||||
};
|
||||
}
|
||||
// Non-auth errors are still fatal
|
||||
throw new Error(combinedReason, { cause: fallbackErr });
|
||||
}
|
||||
}
|
||||
// No fallback configured - check if we should degrade to FTS-only
|
||||
if (isMissingApiKeyError(primaryErr)) {
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
throw new Error(reason, { cause: primaryErr });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ class MemoryManagerEmbeddingOps {
|
||||
}
|
||||
|
||||
private loadEmbeddingCache(hashes: string[]): Map<string, number[]> {
|
||||
if (!this.cache.enabled) {
|
||||
if (!this.cache.enabled || !this.provider) {
|
||||
return new Map();
|
||||
}
|
||||
if (hashes.length === 0) {
|
||||
@@ -114,7 +114,7 @@ class MemoryManagerEmbeddingOps {
|
||||
}
|
||||
|
||||
private upsertEmbeddingCache(entries: Array<{ hash: string; embedding: number[] }>): void {
|
||||
if (!this.cache.enabled) {
|
||||
if (!this.cache.enabled || !this.provider) {
|
||||
return;
|
||||
}
|
||||
if (entries.length === 0) {
|
||||
@@ -202,6 +202,10 @@ class MemoryManagerEmbeddingOps {
|
||||
}
|
||||
|
||||
private computeProviderKey(): string {
|
||||
// FTS-only mode: no provider, use a constant key
|
||||
if (!this.provider) {
|
||||
return hashText(JSON.stringify({ provider: "none", model: "fts-only" }));
|
||||
}
|
||||
if (this.provider.id === "openai" && this.openAi) {
|
||||
const entries = Object.entries(this.openAi.headers)
|
||||
.filter(([key]) => key.toLowerCase() !== "authorization")
|
||||
@@ -241,6 +245,9 @@ class MemoryManagerEmbeddingOps {
|
||||
entry: MemoryFileEntry | SessionFileEntry,
|
||||
source: MemorySource,
|
||||
): Promise<number[][]> {
|
||||
if (!this.provider) {
|
||||
return this.embedChunksInBatches(chunks);
|
||||
}
|
||||
if (this.provider.id === "openai" && this.openAi) {
|
||||
return this.embedChunksWithOpenAiBatch(chunks, entry, source);
|
||||
}
|
||||
@@ -419,7 +426,7 @@ class MemoryManagerEmbeddingOps {
|
||||
method: "POST",
|
||||
url: OPENAI_BATCH_ENDPOINT,
|
||||
body: {
|
||||
model: this.openAi?.model ?? this.provider.model,
|
||||
model: this.openAi?.model ?? this.provider?.model ?? "text-embedding-3-small",
|
||||
input: chunk.text,
|
||||
},
|
||||
}),
|
||||
@@ -489,6 +496,9 @@ class MemoryManagerEmbeddingOps {
|
||||
if (texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
if (!this.provider) {
|
||||
throw new Error("Cannot embed batch in FTS-only mode (no embedding provider)");
|
||||
}
|
||||
let attempt = 0;
|
||||
let delayMs = EMBEDDING_RETRY_BASE_DELAY_MS;
|
||||
while (true) {
|
||||
@@ -528,7 +538,7 @@ class MemoryManagerEmbeddingOps {
|
||||
}
|
||||
|
||||
private resolveEmbeddingTimeout(kind: "query" | "batch"): number {
|
||||
const isLocal = this.provider.id === "local";
|
||||
const isLocal = this.provider?.id === "local";
|
||||
if (kind === "query") {
|
||||
return isLocal ? EMBEDDING_QUERY_TIMEOUT_LOCAL_MS : EMBEDDING_QUERY_TIMEOUT_REMOTE_MS;
|
||||
}
|
||||
@@ -536,6 +546,9 @@ class MemoryManagerEmbeddingOps {
|
||||
}
|
||||
|
||||
private async embedQueryWithTimeout(text: string): Promise<number[]> {
|
||||
if (!this.provider) {
|
||||
throw new Error("Cannot embed query in FTS-only mode (no embedding provider)");
|
||||
}
|
||||
const timeoutMs = this.resolveEmbeddingTimeout("query");
|
||||
log.debug("memory embeddings: query start", { provider: this.provider.id, timeoutMs });
|
||||
return await this.withTimeout(
|
||||
@@ -682,20 +695,30 @@ class MemoryManagerEmbeddingOps {
|
||||
options: { source: MemorySource; content?: string },
|
||||
) {
|
||||
const content = options.content ?? (await fs.readFile(entry.absPath, "utf-8"));
|
||||
const chunks = enforceEmbeddingMaxInputTokens(
|
||||
this.provider,
|
||||
chunkMarkdown(content, this.settings.chunking).filter(
|
||||
(chunk) => chunk.text.trim().length > 0,
|
||||
),
|
||||
const parsedChunks = chunkMarkdown(content, this.settings.chunking).filter(
|
||||
(chunk) => chunk.text.trim().length > 0,
|
||||
);
|
||||
const chunks =
|
||||
this.provider === null
|
||||
? parsedChunks
|
||||
: enforceEmbeddingMaxInputTokens(this.provider, parsedChunks);
|
||||
if (options.source === "sessions" && "lineMap" in entry) {
|
||||
remapChunkLines(chunks, entry.lineMap);
|
||||
}
|
||||
const embeddings = this.batch.enabled
|
||||
? await this.embedChunksWithBatch(chunks, entry, options.source)
|
||||
: await this.embedChunksInBatches(chunks);
|
||||
|
||||
const embeddings = this.provider
|
||||
? this.batch.enabled
|
||||
? await this.embedChunksWithBatch(chunks, entry, options.source)
|
||||
: await this.embedChunksInBatches(chunks)
|
||||
: [];
|
||||
|
||||
const sample = embeddings.find((embedding) => embedding.length > 0);
|
||||
const vectorReady = sample ? await this.ensureVectorReady(sample.length) : false;
|
||||
let vectorReady = false;
|
||||
if (this.provider && sample) {
|
||||
vectorReady = await this.ensureVectorReady(sample.length);
|
||||
}
|
||||
const model = this.provider?.model ?? "fts-only";
|
||||
|
||||
const now = Date.now();
|
||||
if (vectorReady) {
|
||||
try {
|
||||
@@ -707,10 +730,16 @@ class MemoryManagerEmbeddingOps {
|
||||
} catch {}
|
||||
}
|
||||
if (this.fts.enabled && this.fts.available) {
|
||||
const deleteFtsSql =
|
||||
this.provider === null
|
||||
? `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ?`
|
||||
: `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`;
|
||||
const deleteFtsParams =
|
||||
this.provider === null ? [entry.path, options.source] : [entry.path, options.source, model];
|
||||
try {
|
||||
this.db
|
||||
.prepare(`DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`)
|
||||
.run(entry.path, options.source, this.provider.model);
|
||||
.prepare(deleteFtsSql)
|
||||
.run(...deleteFtsParams);
|
||||
} catch {}
|
||||
}
|
||||
this.db
|
||||
@@ -720,7 +749,7 @@ class MemoryManagerEmbeddingOps {
|
||||
const chunk = chunks[i];
|
||||
const embedding = embeddings[i] ?? [];
|
||||
const id = hashText(
|
||||
`${options.source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${this.provider.model}`,
|
||||
`${options.source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${model}`,
|
||||
);
|
||||
this.db
|
||||
.prepare(
|
||||
@@ -740,7 +769,7 @@ class MemoryManagerEmbeddingOps {
|
||||
chunk.startLine,
|
||||
chunk.endLine,
|
||||
chunk.hash,
|
||||
this.provider.model,
|
||||
model,
|
||||
chunk.text,
|
||||
JSON.stringify(embedding),
|
||||
now,
|
||||
@@ -764,7 +793,7 @@ class MemoryManagerEmbeddingOps {
|
||||
id,
|
||||
entry.path,
|
||||
options.source,
|
||||
this.provider.model,
|
||||
model,
|
||||
chunk.startLine,
|
||||
chunk.endLine,
|
||||
);
|
||||
|
||||
@@ -136,7 +136,7 @@ export function listChunks(params: {
|
||||
export async function searchKeyword(params: {
|
||||
db: DatabaseSync;
|
||||
ftsTable: string;
|
||||
providerModel: string;
|
||||
providerModel: string | undefined;
|
||||
query: string;
|
||||
limit: number;
|
||||
snippetMaxChars: number;
|
||||
@@ -152,16 +152,20 @@ export async function searchKeyword(params: {
|
||||
return [];
|
||||
}
|
||||
|
||||
// When providerModel is undefined (FTS-only mode), search all models
|
||||
const modelClause = params.providerModel ? " AND model = ?" : "";
|
||||
const modelParams = params.providerModel ? [params.providerModel] : [];
|
||||
|
||||
const rows = params.db
|
||||
.prepare(
|
||||
`SELECT id, path, source, start_line, end_line, text,\n` +
|
||||
` bm25(${params.ftsTable}) AS rank\n` +
|
||||
` FROM ${params.ftsTable}\n` +
|
||||
` WHERE ${params.ftsTable} MATCH ? AND model = ?${params.sourceFilter.sql}\n` +
|
||||
` WHERE ${params.ftsTable} MATCH ?${modelClause}${params.sourceFilter.sql}\n` +
|
||||
` ORDER BY rank ASC\n` +
|
||||
` LIMIT ?`,
|
||||
)
|
||||
.all(ftsQuery, params.providerModel, ...params.sourceFilter.params, params.limit) as Array<{
|
||||
.all(ftsQuery, ...modelParams, ...params.sourceFilter.params, params.limit) as Array<{
|
||||
id: string;
|
||||
path: string;
|
||||
source: SearchSource;
|
||||
|
||||
@@ -606,10 +606,16 @@ class MemoryManagerSyncOps {
|
||||
} catch {}
|
||||
this.db.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`).run(stale.path, "memory");
|
||||
if (this.fts.enabled && this.fts.available) {
|
||||
const deleteFtsSql =
|
||||
this.provider === null
|
||||
? `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ?`
|
||||
: `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`;
|
||||
try {
|
||||
const deleteFtsParams =
|
||||
this.provider === null ? [stale.path, "memory"] : [stale.path, "memory", this.provider.model];
|
||||
this.db
|
||||
.prepare(`DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`)
|
||||
.run(stale.path, "memory", this.provider.model);
|
||||
.prepare(deleteFtsSql)
|
||||
.run(...deleteFtsParams);
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
@@ -707,10 +713,18 @@ class MemoryManagerSyncOps {
|
||||
.prepare(`DELETE FROM chunks WHERE path = ? AND source = ?`)
|
||||
.run(stale.path, "sessions");
|
||||
if (this.fts.enabled && this.fts.available) {
|
||||
const deleteFtsSql =
|
||||
this.provider === null
|
||||
? `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ?`
|
||||
: `DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`;
|
||||
try {
|
||||
const deleteFtsParams =
|
||||
this.provider === null
|
||||
? [stale.path, "sessions"]
|
||||
: [stale.path, "sessions", this.provider.model];
|
||||
this.db
|
||||
.prepare(`DELETE FROM ${FTS_TABLE} WHERE path = ? AND source = ? AND model = ?`)
|
||||
.run(stale.path, "sessions", this.provider.model);
|
||||
.prepare(deleteFtsSql)
|
||||
.run(...deleteFtsParams);
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
@@ -759,8 +773,8 @@ class MemoryManagerSyncOps {
|
||||
const needsFullReindex =
|
||||
params?.force ||
|
||||
!meta ||
|
||||
meta.model !== this.provider.model ||
|
||||
meta.provider !== this.provider.id ||
|
||||
(this.provider && meta.model !== this.provider.model) ||
|
||||
(this.provider && meta.provider !== this.provider.id) ||
|
||||
meta.providerKey !== this.providerKey ||
|
||||
meta.chunkTokens !== this.settings.chunking.tokens ||
|
||||
meta.chunkOverlap !== this.settings.chunking.overlap ||
|
||||
@@ -834,6 +848,7 @@ class MemoryManagerSyncOps {
|
||||
const batch = this.settings.remote?.batch;
|
||||
const enabled = Boolean(
|
||||
batch?.enabled &&
|
||||
this.provider &&
|
||||
((this.openAi && this.provider.id === "openai") ||
|
||||
(this.gemini && this.provider.id === "gemini") ||
|
||||
(this.voyage && this.provider.id === "voyage")),
|
||||
@@ -849,7 +864,7 @@ class MemoryManagerSyncOps {
|
||||
|
||||
private async activateFallbackProvider(reason: string): Promise<boolean> {
|
||||
const fallback = this.settings.fallback;
|
||||
if (!fallback || fallback === "none" || fallback === this.provider.id) {
|
||||
if (!fallback || fallback === "none" || !this.provider || fallback === this.provider.id) {
|
||||
return false;
|
||||
}
|
||||
if (this.fallbackFrom) {
|
||||
@@ -957,8 +972,8 @@ class MemoryManagerSyncOps {
|
||||
}
|
||||
|
||||
nextMeta = {
|
||||
model: this.provider.model,
|
||||
provider: this.provider.id,
|
||||
model: this.provider?.model ?? "fts-only",
|
||||
provider: this.provider?.id ?? "none",
|
||||
providerKey: this.providerKey,
|
||||
chunkTokens: this.settings.chunking.tokens,
|
||||
chunkOverlap: this.settings.chunking.overlap,
|
||||
@@ -1023,8 +1038,8 @@ class MemoryManagerSyncOps {
|
||||
}
|
||||
|
||||
const nextMeta: MemoryIndexMeta = {
|
||||
model: this.provider.model,
|
||||
provider: this.provider.id,
|
||||
model: this.provider?.model ?? "fts-only",
|
||||
provider: this.provider?.id ?? "none",
|
||||
providerKey: this.providerKey,
|
||||
chunkTokens: this.settings.chunking.tokens,
|
||||
chunkOverlap: this.settings.chunking.overlap,
|
||||
|
||||
@@ -28,6 +28,7 @@ import { isMemoryPath, normalizeExtraMemoryPaths } from "./internal.js";
|
||||
import { memoryManagerEmbeddingOps } from "./manager-embedding-ops.js";
|
||||
import { searchKeyword, searchVector } from "./manager-search.js";
|
||||
import { memoryManagerSyncOps } from "./manager-sync-ops.js";
|
||||
import { extractKeywords } from "./query-expansion.js";
|
||||
const SNIPPET_MAX_CHARS = 700;
|
||||
const VECTOR_TABLE = "chunks_vec";
|
||||
const FTS_TABLE = "chunks_fts";
|
||||
@@ -46,10 +47,11 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
private readonly agentId: string;
|
||||
private readonly workspaceDir: string;
|
||||
private readonly settings: ResolvedMemorySearchConfig;
|
||||
private provider: EmbeddingProvider;
|
||||
private provider: EmbeddingProvider | null;
|
||||
private readonly requestedProvider: "openai" | "local" | "gemini" | "voyage" | "auto";
|
||||
private fallbackFrom?: "openai" | "local" | "gemini" | "voyage";
|
||||
private fallbackReason?: string;
|
||||
private readonly providerUnavailableReason?: string;
|
||||
private openAi?: OpenAiEmbeddingClient;
|
||||
private gemini?: GeminiEmbeddingClient;
|
||||
private voyage?: VoyageEmbeddingClient;
|
||||
@@ -154,6 +156,7 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
this.requestedProvider = params.providerResult.requestedProvider;
|
||||
this.fallbackFrom = params.providerResult.fallbackFrom;
|
||||
this.fallbackReason = params.providerResult.fallbackReason;
|
||||
this.providerUnavailableReason = params.providerResult.providerUnavailableReason;
|
||||
this.openAi = params.providerResult.openAi;
|
||||
this.gemini = params.providerResult.gemini;
|
||||
this.voyage = params.providerResult.voyage;
|
||||
@@ -225,6 +228,42 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
Math.max(1, Math.floor(maxResults * hybrid.candidateMultiplier)),
|
||||
);
|
||||
|
||||
// FTS-only mode: no embedding provider available
|
||||
if (!this.provider) {
|
||||
if (!this.fts.enabled || !this.fts.available) {
|
||||
log.warn("memory search: no provider and FTS unavailable");
|
||||
return [];
|
||||
}
|
||||
|
||||
// Extract keywords for better FTS matching on conversational queries
|
||||
// e.g., "that thing we discussed about the API" → ["discussed", "API"]
|
||||
const keywords = extractKeywords(cleaned);
|
||||
const searchTerms = keywords.length > 0 ? keywords : [cleaned];
|
||||
|
||||
// Search with each keyword and merge results
|
||||
const resultSets = await Promise.all(
|
||||
searchTerms.map((term) => this.searchKeyword(term, candidates).catch(() => [])),
|
||||
);
|
||||
|
||||
// Merge and deduplicate results, keeping highest score for each chunk
|
||||
const seenIds = new Map<string, (typeof resultSets)[0][0]>();
|
||||
for (const results of resultSets) {
|
||||
for (const result of results) {
|
||||
const existing = seenIds.get(result.id);
|
||||
if (!existing || result.score > existing.score) {
|
||||
seenIds.set(result.id, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const merged = [...seenIds.values()]
|
||||
.toSorted((a, b) => b.score - a.score)
|
||||
.filter((entry) => entry.score >= minScore)
|
||||
.slice(0, maxResults);
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
const keywordResults = hybrid.enabled
|
||||
? await this.searchKeyword(cleaned, candidates).catch(() => [])
|
||||
: [];
|
||||
@@ -253,6 +292,10 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
queryVec: number[],
|
||||
limit: number,
|
||||
): Promise<Array<MemorySearchResult & { id: string }>> {
|
||||
// This method should never be called without a provider
|
||||
if (!this.provider) {
|
||||
return [];
|
||||
}
|
||||
const results = await searchVector({
|
||||
db: this.db,
|
||||
vectorTable: VECTOR_TABLE,
|
||||
@@ -279,10 +322,12 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
return [];
|
||||
}
|
||||
const sourceFilter = this.buildSourceFilter();
|
||||
// In FTS-only mode (no provider), search all models; otherwise filter by current provider's model
|
||||
const providerModel = this.provider?.model;
|
||||
const results = await searchKeyword({
|
||||
db: this.db,
|
||||
ftsTable: FTS_TABLE,
|
||||
providerModel: this.provider.model,
|
||||
providerModel,
|
||||
query,
|
||||
limit,
|
||||
snippetMaxChars: SNIPPET_MAX_CHARS,
|
||||
@@ -446,6 +491,13 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
}
|
||||
return sources.map((source) => Object.assign({ source }, bySource.get(source)!));
|
||||
})();
|
||||
|
||||
// Determine search mode: "fts-only" if no provider, "hybrid" otherwise
|
||||
const searchMode = this.provider ? "hybrid" : "fts-only";
|
||||
const providerInfo = this.provider
|
||||
? { provider: this.provider.id, model: this.provider.model }
|
||||
: { provider: "none", model: undefined };
|
||||
|
||||
return {
|
||||
backend: "builtin",
|
||||
files: files?.c ?? 0,
|
||||
@@ -453,8 +505,8 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
dirty: this.dirty || this.sessionsDirty,
|
||||
workspaceDir: this.workspaceDir,
|
||||
dbPath: this.settings.store.path,
|
||||
provider: this.provider.id,
|
||||
model: this.provider.model,
|
||||
provider: providerInfo.provider,
|
||||
model: providerInfo.model,
|
||||
requestedProvider: this.requestedProvider,
|
||||
sources: Array.from(this.sources),
|
||||
extraPaths: this.settings.extraPaths,
|
||||
@@ -497,10 +549,18 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
lastError: this.batchFailureLastError,
|
||||
lastProvider: this.batchFailureLastProvider,
|
||||
},
|
||||
custom: {
|
||||
searchMode,
|
||||
providerUnavailableReason: this.providerUnavailableReason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async probeVectorAvailability(): Promise<boolean> {
|
||||
// FTS-only mode: vector search not available
|
||||
if (!this.provider) {
|
||||
return false;
|
||||
}
|
||||
if (!this.vector.enabled) {
|
||||
return false;
|
||||
}
|
||||
@@ -508,6 +568,13 @@ export class MemoryIndexManager implements MemorySearchManager {
|
||||
}
|
||||
|
||||
async probeEmbeddingAvailability(): Promise<MemoryEmbeddingProbeResult> {
|
||||
// FTS-only mode: embeddings not available but search still works
|
||||
if (!this.provider) {
|
||||
return {
|
||||
ok: false,
|
||||
error: this.providerUnavailableReason ?? "No embedding provider available (FTS-only mode)",
|
||||
};
|
||||
}
|
||||
try {
|
||||
await this.embedBatchWithRetry(["ping"]);
|
||||
return { ok: true };
|
||||
|
||||
78
src/memory/query-expansion.test.ts
Normal file
78
src/memory/query-expansion.test.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { expandQueryForFts, extractKeywords } from "./query-expansion.js";
|
||||
|
||||
describe("extractKeywords", () => {
|
||||
it("extracts keywords from English conversational query", () => {
|
||||
const keywords = extractKeywords("that thing we discussed about the API");
|
||||
expect(keywords).toContain("discussed");
|
||||
expect(keywords).toContain("api");
|
||||
// Should not include stop words
|
||||
expect(keywords).not.toContain("that");
|
||||
expect(keywords).not.toContain("thing");
|
||||
expect(keywords).not.toContain("we");
|
||||
expect(keywords).not.toContain("about");
|
||||
expect(keywords).not.toContain("the");
|
||||
});
|
||||
|
||||
it("extracts keywords from Chinese conversational query", () => {
|
||||
const keywords = extractKeywords("之前讨论的那个方案");
|
||||
expect(keywords).toContain("讨论");
|
||||
expect(keywords).toContain("方案");
|
||||
// Should not include stop words
|
||||
expect(keywords).not.toContain("之前");
|
||||
expect(keywords).not.toContain("的");
|
||||
expect(keywords).not.toContain("那个");
|
||||
});
|
||||
|
||||
it("extracts keywords from mixed language query", () => {
|
||||
const keywords = extractKeywords("昨天讨论的 API design");
|
||||
expect(keywords).toContain("讨论");
|
||||
expect(keywords).toContain("api");
|
||||
expect(keywords).toContain("design");
|
||||
});
|
||||
|
||||
it("returns specific technical terms", () => {
|
||||
const keywords = extractKeywords("what was the solution for the CFR bug");
|
||||
expect(keywords).toContain("solution");
|
||||
expect(keywords).toContain("cfr");
|
||||
expect(keywords).toContain("bug");
|
||||
});
|
||||
|
||||
it("handles empty query", () => {
|
||||
expect(extractKeywords("")).toEqual([]);
|
||||
expect(extractKeywords(" ")).toEqual([]);
|
||||
});
|
||||
|
||||
it("handles query with only stop words", () => {
|
||||
const keywords = extractKeywords("the a an is are");
|
||||
expect(keywords.length).toBe(0);
|
||||
});
|
||||
|
||||
it("removes duplicate keywords", () => {
|
||||
const keywords = extractKeywords("test test testing");
|
||||
const testCount = keywords.filter((k) => k === "test").length;
|
||||
expect(testCount).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("expandQueryForFts", () => {
|
||||
it("returns original query and extracted keywords", () => {
|
||||
const result = expandQueryForFts("that API we discussed");
|
||||
expect(result.original).toBe("that API we discussed");
|
||||
expect(result.keywords).toContain("api");
|
||||
expect(result.keywords).toContain("discussed");
|
||||
});
|
||||
|
||||
it("builds expanded OR query for FTS", () => {
|
||||
const result = expandQueryForFts("the solution for bugs");
|
||||
expect(result.expanded).toContain("OR");
|
||||
expect(result.expanded).toContain("solution");
|
||||
expect(result.expanded).toContain("bugs");
|
||||
});
|
||||
|
||||
it("returns original query when no keywords extracted", () => {
|
||||
const result = expandQueryForFts("the");
|
||||
expect(result.keywords.length).toBe(0);
|
||||
expect(result.expanded).toBe("the");
|
||||
});
|
||||
});
|
||||
357
src/memory/query-expansion.ts
Normal file
357
src/memory/query-expansion.ts
Normal file
@@ -0,0 +1,357 @@
|
||||
/**
|
||||
* Query expansion for FTS-only search mode.
|
||||
*
|
||||
* When no embedding provider is available, we fall back to FTS (full-text search).
|
||||
* FTS works best with specific keywords, but users often ask conversational queries
|
||||
* like "that thing we discussed yesterday" or "之前讨论的那个方案".
|
||||
*
|
||||
* This module extracts meaningful keywords from such queries to improve FTS results.
|
||||
*/
|
||||
|
||||
// Common stop words that don't add search value
|
||||
const STOP_WORDS_EN = new Set([
|
||||
// Articles and determiners
|
||||
"a",
|
||||
"an",
|
||||
"the",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
// Pronouns
|
||||
"i",
|
||||
"me",
|
||||
"my",
|
||||
"we",
|
||||
"our",
|
||||
"you",
|
||||
"your",
|
||||
"he",
|
||||
"she",
|
||||
"it",
|
||||
"they",
|
||||
"them",
|
||||
// Common verbs
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"can",
|
||||
"may",
|
||||
"might",
|
||||
// Prepositions
|
||||
"in",
|
||||
"on",
|
||||
"at",
|
||||
"to",
|
||||
"for",
|
||||
"of",
|
||||
"with",
|
||||
"by",
|
||||
"from",
|
||||
"about",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"over",
|
||||
// Conjunctions
|
||||
"and",
|
||||
"or",
|
||||
"but",
|
||||
"if",
|
||||
"then",
|
||||
"because",
|
||||
"as",
|
||||
"while",
|
||||
"when",
|
||||
"where",
|
||||
"what",
|
||||
"which",
|
||||
"who",
|
||||
"how",
|
||||
"why",
|
||||
// Time references (vague, not useful for FTS)
|
||||
"yesterday",
|
||||
"today",
|
||||
"tomorrow",
|
||||
"earlier",
|
||||
"later",
|
||||
"recently",
|
||||
"before",
|
||||
"ago",
|
||||
"just",
|
||||
"now",
|
||||
// Vague references
|
||||
"thing",
|
||||
"things",
|
||||
"stuff",
|
||||
"something",
|
||||
"anything",
|
||||
"everything",
|
||||
"nothing",
|
||||
// Question words
|
||||
"please",
|
||||
"help",
|
||||
"find",
|
||||
"show",
|
||||
"get",
|
||||
"tell",
|
||||
"give",
|
||||
]);
|
||||
|
||||
const STOP_WORDS_ZH = new Set([
|
||||
// Pronouns
|
||||
"我",
|
||||
"我们",
|
||||
"你",
|
||||
"你们",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"他们",
|
||||
"这",
|
||||
"那",
|
||||
"这个",
|
||||
"那个",
|
||||
"这些",
|
||||
"那些",
|
||||
// Auxiliary words
|
||||
"的",
|
||||
"了",
|
||||
"着",
|
||||
"过",
|
||||
"得",
|
||||
"地",
|
||||
"吗",
|
||||
"呢",
|
||||
"吧",
|
||||
"啊",
|
||||
"呀",
|
||||
"嘛",
|
||||
"啦",
|
||||
// Verbs (common, vague)
|
||||
"是",
|
||||
"有",
|
||||
"在",
|
||||
"被",
|
||||
"把",
|
||||
"给",
|
||||
"让",
|
||||
"用",
|
||||
"到",
|
||||
"去",
|
||||
"来",
|
||||
"做",
|
||||
"说",
|
||||
"看",
|
||||
"找",
|
||||
"想",
|
||||
"要",
|
||||
"能",
|
||||
"会",
|
||||
"可以",
|
||||
// Prepositions and conjunctions
|
||||
"和",
|
||||
"与",
|
||||
"或",
|
||||
"但",
|
||||
"但是",
|
||||
"因为",
|
||||
"所以",
|
||||
"如果",
|
||||
"虽然",
|
||||
"而",
|
||||
"也",
|
||||
"都",
|
||||
"就",
|
||||
"还",
|
||||
"又",
|
||||
"再",
|
||||
"才",
|
||||
"只",
|
||||
// Time (vague)
|
||||
"之前",
|
||||
"以前",
|
||||
"之后",
|
||||
"以后",
|
||||
"刚才",
|
||||
"现在",
|
||||
"昨天",
|
||||
"今天",
|
||||
"明天",
|
||||
"最近",
|
||||
// Vague references
|
||||
"东西",
|
||||
"事情",
|
||||
"事",
|
||||
"什么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"怎么",
|
||||
"为什么",
|
||||
"多少",
|
||||
// Question/request words
|
||||
"请",
|
||||
"帮",
|
||||
"帮忙",
|
||||
"告诉",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Check if a token looks like a meaningful keyword.
|
||||
* Returns false for short tokens, numbers-only, etc.
|
||||
*/
|
||||
function isValidKeyword(token: string): boolean {
|
||||
if (!token || token.length === 0) {
|
||||
return false;
|
||||
}
|
||||
// Skip very short English words (likely stop words or fragments)
|
||||
if (/^[a-zA-Z]+$/.test(token) && token.length < 3) {
|
||||
return false;
|
||||
}
|
||||
// Skip pure numbers (not useful for semantic search)
|
||||
if (/^\d+$/.test(token)) {
|
||||
return false;
|
||||
}
|
||||
// Skip tokens that are all punctuation
|
||||
if (/^[\p{P}\p{S}]+$/u.test(token)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple tokenizer that handles both English and Chinese text.
|
||||
* For Chinese, we do character-based splitting since we don't have a proper segmenter.
|
||||
* For English, we split on whitespace and punctuation.
|
||||
*/
|
||||
function tokenize(text: string): string[] {
|
||||
const tokens: string[] = [];
|
||||
const normalized = text.toLowerCase().trim();
|
||||
|
||||
// Split into segments (English words, Chinese character sequences, etc.)
|
||||
const segments = normalized.split(/[\s\p{P}]+/u).filter(Boolean);
|
||||
|
||||
for (const segment of segments) {
|
||||
// Check if segment contains CJK characters
|
||||
if (/[\u4e00-\u9fff]/.test(segment)) {
|
||||
// For Chinese, extract character n-grams (unigrams and bigrams)
|
||||
const chars = Array.from(segment).filter((c) => /[\u4e00-\u9fff]/.test(c));
|
||||
// Add individual characters
|
||||
tokens.push(...chars);
|
||||
// Add bigrams for better phrase matching
|
||||
for (let i = 0; i < chars.length - 1; i++) {
|
||||
tokens.push(chars[i] + chars[i + 1]);
|
||||
}
|
||||
} else {
|
||||
// For non-CJK, keep as single token
|
||||
tokens.push(segment);
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract keywords from a conversational query for FTS search.
|
||||
*
|
||||
* Examples:
|
||||
* - "that thing we discussed about the API" → ["discussed", "API"]
|
||||
* - "之前讨论的那个方案" → ["讨论", "方案"]
|
||||
* - "what was the solution for the bug" → ["solution", "bug"]
|
||||
*/
|
||||
export function extractKeywords(query: string): string[] {
|
||||
const tokens = tokenize(query);
|
||||
const keywords: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const token of tokens) {
|
||||
// Skip stop words
|
||||
if (STOP_WORDS_EN.has(token) || STOP_WORDS_ZH.has(token)) {
|
||||
continue;
|
||||
}
|
||||
// Skip invalid keywords
|
||||
if (!isValidKeyword(token)) {
|
||||
continue;
|
||||
}
|
||||
// Skip duplicates
|
||||
if (seen.has(token)) {
|
||||
continue;
|
||||
}
|
||||
seen.add(token);
|
||||
keywords.push(token);
|
||||
}
|
||||
|
||||
return keywords;
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand a query for FTS search.
|
||||
* Returns both the original query and extracted keywords for OR-matching.
|
||||
*
|
||||
* @param query - User's original query
|
||||
* @returns Object with original query and extracted keywords
|
||||
*/
|
||||
export function expandQueryForFts(query: string): {
|
||||
original: string;
|
||||
keywords: string[];
|
||||
expanded: string;
|
||||
} {
|
||||
const original = query.trim();
|
||||
const keywords = extractKeywords(original);
|
||||
|
||||
// Build expanded query: original terms OR extracted keywords
|
||||
// This ensures both exact matches and keyword matches are found
|
||||
const expanded = keywords.length > 0 ? `${original} OR ${keywords.join(" OR ")}` : original;
|
||||
|
||||
return { original, keywords, expanded };
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for an optional LLM-based query expander.
|
||||
* Can be provided to enhance keyword extraction with semantic understanding.
|
||||
*/
|
||||
export type LlmQueryExpander = (query: string) => Promise<string[]>;
|
||||
|
||||
/**
|
||||
* Expand query with optional LLM assistance.
|
||||
* Falls back to local extraction if LLM is unavailable or fails.
|
||||
*/
|
||||
export async function expandQueryWithLlm(
|
||||
query: string,
|
||||
llmExpander?: LlmQueryExpander,
|
||||
): Promise<string[]> {
|
||||
// If LLM expander is provided, try it first
|
||||
if (llmExpander) {
|
||||
try {
|
||||
const llmKeywords = await llmExpander(query);
|
||||
if (llmKeywords.length > 0) {
|
||||
return llmKeywords;
|
||||
}
|
||||
} catch {
|
||||
// LLM failed, fall back to local extraction
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local keyword extraction
|
||||
return extractKeywords(query);
|
||||
}
|
||||
Reference in New Issue
Block a user