diff --git a/apps/memos-local-plugin/core/capture/ALGORITHMS.md b/apps/memos-local-plugin/core/capture/ALGORITHMS.md index 15a15ac64..49c98a105 100644 --- a/apps/memos-local-plugin/core/capture/ALGORITHMS.md +++ b/apps/memos-local-plugin/core/capture/ALGORITHMS.md @@ -78,7 +78,8 @@ priority once reward arrives. ## V7 §3.2 batched variant — `batch-scorer.ts` The per-step path (`reflection-synth.ts` + `alpha-scorer.ts`) issues 2N -LLM calls per N-step episode. `batch-scorer.ts` collapses them into ONE: +LLM calls per N-step episode. `batch-scorer.ts` collapses up to +`batchThreshold` steps into one call: ``` inputs = [{idx, state, action, outcome, reflection, synth_allowed}, …] @@ -91,8 +92,8 @@ Dispatch (in `capture.ts`): | `cfg.batchMode` | `cfg.batchThreshold` | behavior | |-------------------|----------------------|----------| | `per_step` | (ignored) | legacy: 2N calls | -| `per_episode` | (ignored) | always batch | -| `auto` (default) | `12` | batch when `N ≤ 12`; else per-step | +| `per_episode` | chunk size | batch when `N ≤ threshold`; else chunk-batch | +| `auto` (default) | `12` | batch when `N ≤ 12`; else chunk-batch | The dispatcher also refuses to batch when no LLM is wired — same fallback path as missing-LLM in per-step mode. @@ -107,15 +108,15 @@ Failure handling: - LLM throws / facade gives up after `malformedRetries=1` → capture catches in `runBatchScoring`, surfaces a `{stage: "batch"}` warning, - and the per-step path runs as a fallback. + and the per-step path runs as a fallback for that chunk. - Validator rejects on length mismatch, missing/non-numeric `alpha`, non-boolean `usable`, non-string `reflection_text`. Same fallback. Bookkeeping (`CaptureResult.llmCalls`): -- `batchedReflection`: 0 or 1 per episode (1 on a successful batch). +- `batchedReflection`: number of successful batch/chunk calls. - `reflectionSynth` / `alphaScoring`: only nonzero when the per-step path - ran (either selected directly, or as fallback after a batch failure). 
+ ran (either selected directly, or as fallback after a chunk failure). Stable prompt fingerprint: diff --git a/apps/memos-local-plugin/core/capture/batch-scorer.ts b/apps/memos-local-plugin/core/capture/batch-scorer.ts index e7b8ab50f..da434c3b2 100644 --- a/apps/memos-local-plugin/core/capture/batch-scorer.ts +++ b/apps/memos-local-plugin/core/capture/batch-scorer.ts @@ -16,11 +16,12 @@ * `transferability` axes benefit directly. * * Trade-offs (encoded in capture.ts dispatch): - * - Prompt grows linearly with N steps. Capped via `batchThreshold`; - * long episodes degrade to the per-step path automatically. - * - One bad output value forces a single batched retry instead of N - * isolated retries — but the facade already does `malformedRetries` - * for us, and on hard failure capture.ts falls back to per-step. + * - Prompt grows linearly with N steps. Each call is capped at + * `batchThreshold`; long episodes run as several bounded chunks. + * - One bad chunk forces a single batched retry for that chunk instead + * of N isolated retries — but the facade already does + * `malformedRetries` for us, and on hard failure capture.ts falls + * back to per-step for that chunk only. * * Wire format ↔ prompt: * Send `{ host_context?, task_context?, steps: [{idx, state, action, outcome, reflection, synth_allowed}] }`. @@ -170,6 +171,7 @@ export async function batchScoreReflections( validate: (v) => validateBatchPayload(v, inputs.length), malformedRetries: 1, temperature: 0, + maxTokens: batchMaxTokens(inputs.length), }, ); @@ -321,6 +323,15 @@ function validateBatchPayload(v: unknown, expected: number): void { } } +function batchMaxTokens(stepCount: number): number { + // Batch output scales with step count; keep a per-step budget but cap below + // the 16k range that triggered avoidable reasoning spend on mimo replay. 
+ const perStepOutputBudget = 512; + const baseBudget = 768; + const ceiling = 8_192; + return Math.min(ceiling, baseBudget + Math.max(1, stepCount) * perStepOutputBudget); +} + function lastToolOutcome(step: NormalizedStep, max: number): string { const last = step.toolCalls[step.toolCalls.length - 1]; if (!last) return "(assistant-only step)"; diff --git a/apps/memos-local-plugin/core/capture/capture.ts b/apps/memos-local-plugin/core/capture/capture.ts index 9d52f749e..8c26276ac 100644 --- a/apps/memos-local-plugin/core/capture/capture.ts +++ b/apps/memos-local-plugin/core/capture/capture.ts @@ -435,14 +435,14 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { } // Batch reflection + α across every step of the now-closed - // episode. Falls back to per-step scoring when over the threshold - // or when batching fails / no LLM is wired. The reflect pass uses + // episode. Long episodes are chunk-batched at `batchThreshold`; + // failed chunks fall back to per-step scoring. The reflect pass uses // `reflectLlm` (skill-evolver model when configured) for higher // quality reflections; per-turn lite capture still uses `llm`. const reflectStart = now(); const rLlm = deps.reflectLlm ?? deps.llm; - const useBatch = shouldBatch(deps.cfg, normalized.length, rLlm !== null); - const contextEnabled = contextModeFor(deps.cfg, useBatch, normalized.length); + const scoringPlan = planScoring(deps.cfg, normalized.length, rLlm !== null); + const contextEnabled = contextModeFor(deps.cfg, scoringPlan, normalized.length); const taskSummary = contextEnabled.includeTask ? buildTaskReflectionSummary(input.episode, normalized, deps.cfg.taskContextMaxChars) : null; @@ -453,7 +453,10 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { episodeId: input.episode.id, sessionId: input.episode.sessionId, steps: normalized.length, - mode: useBatch ? "batch" : contextEnabled.includeDownstream ? 
"per_step_downstream" : "per_step", + mode: scoringPlan === "per_step" && contextEnabled.includeDownstream ? "per_step_downstream" : scoringPlan, + chunks: scoringPlan === "chunk_batch" + ? Math.ceil(normalized.length / Math.max(1, deps.cfg.batchThreshold)) + : undefined, reflectionContextMode: deps.cfg.reflectionContextMode, downstreamPreview: contextEnabled.includeDownstream, provider: rLlm?.provider ?? "none", @@ -461,10 +464,13 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { taskSummary: taskSummary ? taskSummary.slice(0, 240) : null, }); let scored: ScoredStep[] = []; - if (useBatch) { + if (scoringPlan === "batch") { scored = await runBatchScoring(normalized, rLlm!, deps, warnings, llmCalls, input.episode.id, taskSummary); } - if (!useBatch || scored.length === 0) { + if (scoringPlan === "chunk_batch") { + scored = await runChunkedBatchScoring(normalized, rLlm!, deps, warnings, llmCalls, input.episode.id, taskSummary); + } + if (scoringPlan === "per_step" || scored.length === 0) { scored = await runPerStepScoring( normalized, rLlm, @@ -1018,30 +1024,30 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { // ─── helpers ──────────────────────────────────────────────────────────────── /** - * Decide whether to use the batched reflection+α path. + * Decide which reflection+α path to use. * * `per_step` → never (legacy path). - * `per_episode` → always, when an LLM is available. - * `auto` → batch when step count fits inside `batchThreshold`. + * `per_episode` → batch up to threshold, then chunk-batch. + * `auto` → batch up to threshold, then chunk-batch. 
*/ -function shouldBatch(cfg: CaptureConfig, stepCount: number, hasLlm: boolean): boolean { - if (!hasLlm) return false; - if (stepCount === 0) return false; - if (cfg.batchMode === "per_step") return false; - if (cfg.batchMode === "per_episode") return true; - // "auto" - return stepCount <= cfg.batchThreshold; +type ScoringPlan = "per_step" | "batch" | "chunk_batch"; + +function planScoring(cfg: CaptureConfig, stepCount: number, hasLlm: boolean): ScoringPlan { + if (!hasLlm) return "per_step"; + if (stepCount === 0) return "per_step"; + if (cfg.batchMode === "per_step") return "per_step"; + return stepCount <= Math.max(1, cfg.batchThreshold) ? "batch" : "chunk_batch"; } function contextModeFor( cfg: CaptureConfig, - useBatch: boolean, + scoringPlan: ScoringPlan, stepCount: number, ): { includeTask: boolean; includeDownstream: boolean } { const mode = cfg.reflectionContextMode; const includeTask = mode === "task" || mode === "task_downstream"; const wantsDownstream = mode === "downstream" || mode === "task_downstream"; - const longPerStep = !useBatch && stepCount > cfg.batchThreshold; + const longPerStep = scoringPlan === "per_step" && stepCount > cfg.batchThreshold; const includeDownstream = wantsDownstream && cfg.longEpisodeReflectMode === "per_step_downstream" && @@ -1101,6 +1107,37 @@ async function runBatchScoring( } } +async function runChunkedBatchScoring( + normalized: NormalizedStep[], + llm: LlmClient, + deps: CaptureDeps, + warnings: CaptureResult["warnings"], + llmCalls: { reflectionSynth: number; alphaScoring: number; batchedReflection: number }, + episodeId: string, + taskSummary: string | null, +): Promise { + const chunkSize = Math.max(1, deps.cfg.batchThreshold); + const chunks: NormalizedStep[][] = []; + for (let start = 0; start < normalized.length; start += chunkSize) { + chunks.push(normalized.slice(start, start + chunkSize)); + } + const concurrency = Math.max(1, deps.cfg.llmConcurrency); + const scoredChunks = await runConcurrently(chunks, 
concurrency, async (chunk): Promise => { + const scored = await runBatchScoring(chunk, llm, deps, warnings, llmCalls, episodeId, taskSummary); + if (scored.length > 0) return scored; + return runPerStepScoring( + chunk, + llm, + deps, + warnings, + llmCalls, + episodeId, + buildReflectionContexts(chunk, taskSummary, chunk.map(() => [])), + ); + }); + return scoredChunks.flat(); +} + async function runPerStepScoring( normalized: NormalizedStep[], llm: LlmClient | null, diff --git a/apps/memos-local-plugin/core/llm/client.ts b/apps/memos-local-plugin/core/llm/client.ts index 7749a572f..d969d714d 100644 --- a/apps/memos-local-plugin/core/llm/client.ts +++ b/apps/memos-local-plugin/core/llm/client.ts @@ -71,6 +71,124 @@ export function createLlmClientWithProvider( let lastFallbackAt: number | null = null; let lastError: { at: number; message: string } | null = null; + // ─── Circuit breaker state (issue #1897) ───────────────────────────────── + // Per-client breaker that trips on terminal provider errors (401/402/403, + // "insufficient balance", "invalid api key", "unauthorized", "account + // suspended", "billing"). Short-circuits subsequent calls inside the + // facade so the broken provider is not contacted again until cool-down + // elapses. Half-open: the next call after `circuitOpenUntil` probes the + // provider; success closes the breaker, terminal failure re-opens it. + const breakerCfg = config.circuitBreaker ?? {}; + const breakerEnabled = breakerCfg.enabled !== false; + const breakerCooldownMs = Math.max(30_000, breakerCfg.cooldownMs ?? 300_000); + const breakerIsTerminal = breakerCfg.isTerminal ?? defaultIsTerminal; + const breakerNow = breakerCfg.now ?? 
Date.now; + let circuitOpenUntil: number | null = null; + let circuitOpenedReason: string | null = null; + let lastCircuitOpenStatusAt: number | null = null; + + function breakerIsOpen(): boolean { + if (!breakerEnabled) return false; + if (circuitOpenUntil === null) return false; + if (breakerNow() >= circuitOpenUntil) { + // Cool-down elapsed → transition to half-open. We do NOT clear + // `circuitOpenUntil` yet so the very first probe attempt that + // races with the cool-down boundary doesn't fall through to "no + // breaker" twice. The next call's success/failure handler resets + // or re-opens the breaker explicitly. + return false; + } + return true; + } + + function breakerTrip(err: unknown): void { + if (!breakerEnabled) return; + circuitOpenUntil = breakerNow() + breakerCooldownMs; + circuitOpenedReason = summarizeErrMessage(err); + // Reset the coalescer so the first suppressed call after a fresh + // trip always emits a `circuit_open` row. + lastCircuitOpenStatusAt = null; + facadeLog.warn("circuit_breaker.trip", { + provider: provider.name, + model: config.model, + until: circuitOpenUntil, + reason: circuitOpenedReason, + }); + } + + function breakerRecordSuccess(): void { + if (!breakerEnabled) return; + if (circuitOpenUntil !== null) { + facadeLog.info("circuit_breaker.close", { + provider: provider.name, + model: config.model, + }); + } + circuitOpenUntil = null; + circuitOpenedReason = null; + lastCircuitOpenStatusAt = null; + } + + /** + * Emit a coalesced `circuit_open` audit row. At most one row per + * `max(5 s, cooldownMs/12)` window per client — bounds audit-row spam while + * still surfacing the suppressed-call event in the Logs viewer. + * The first suppressed call after a fresh trip always emits. 
+ */ + function maybeEmitCircuitOpenStatus(opts: LlmCallOptions | undefined, op: string): void { + if (!config.onStatus) return; + const at = breakerNow(); + const coalesceWindow = Math.max(5_000, Math.floor(breakerCooldownMs / 12)); + if ( + lastCircuitOpenStatusAt !== null && + at - lastCircuitOpenStatusAt < coalesceWindow + ) { + return; + } + lastCircuitOpenStatusAt = at; + try { + config.onStatus({ + status: "circuit_open", + provider: provider.name, + model: config.model, + message: circuitOpenedReason ?? "(unknown reason)", + at, + durationMs: 0, + op, + episodeId: opts?.episodeId, + phase: opts?.phase, + }); + } catch { + /* status sink errors are non-fatal */ + } + } + + function throwBreakerOpen(): never { + throw makeBreakerOpenError(); + } + + function makeBreakerOpenError(): MemosError { + const until = circuitOpenUntil ?? breakerNow(); + return new MemosError( + ERROR_CODES.LLM_UNAVAILABLE, + `circuit_open: ${circuitOpenedReason ?? "terminal provider error"}`, + { + circuitOpen: true, + until, + provider: provider.name, + model: config.model, + }, + ); + } + + function canUseHostFallback(): boolean { + return ( + config.fallbackToHost === true && + provider.name !== "host" && + getHostLlmBridge() !== null + ); + } + /** * Mark a successful primary-provider call. We **do not** clear * `lastError` / `lastFallbackAt` here — the viewer picks the most @@ -151,6 +269,21 @@ export function createLlmClientWithProvider( opts: LlmCallOptions | undefined, op: string, ): Promise<{ completion: LlmCompletion }> { + // ── Circuit breaker short-circuit ── + // When the breaker is open we never reach the primary provider, so + // no request is generated against the broken paid API. We still + // emit (coalesced) `circuit_open` status rows so the Logs viewer / + // Overview can surface that suppression is happening. 
+ if (breakerIsOpen()) { + maybeEmitCircuitOpenStatus(opts, op); + if (canUseHostFallback()) { + return callHostFallback(makeBreakerOpenError(), messages, input, opts, op, { + keepBreakerOpen: true, + notifyError: false, + }); + } + throwBreakerOpen(); + } requests++; const startedAt = Date.now(); try { @@ -166,6 +299,7 @@ export function createLlmClientWithProvider( }; record(completion, op, messages); const okAt = markOk(); + breakerRecordSuccess(); notifyStatus({ status: "ok", provider: provider.name, @@ -179,45 +313,13 @@ export function createLlmClientWithProvider( return { completion }; } catch (err) { if (shouldFallback(err, config, provider.name)) { - const hostProv = new HostLlmProvider(); + const primaryTerminal = breakerIsTerminal(err); + if (primaryTerminal) breakerTrip(err); try { - const res = await hostProv.complete(messages, input, makeCtx(opts, asProviderLog(rootLogger.child({ channel: "llm.host" })))); - hostFallbacks++; - facadeLog.warn("host.fallback", { - from: provider.name, - op, - reason: summarizeErr(err), + return await callHostFallback(err, messages, input, opts, op, { + keepBreakerOpen: primaryTerminal, + notifyError: true, }); - const completion: LlmCompletion = { - text: res.text, - provider: provider.name, - model: config.model, - finishReason: res.finishReason, - usage: res.usage, - servedBy: "host_fallback", - durationMs: res.durationMs, - }; - record(completion, op, messages); - // The primary provider is still broken even though the host - // bridge saved this call. Tag the slot yellow (`lastFallbackAt`) - // and surface the upstream error to the user via the - // system_error log so they can see *why* fallback engaged. - const fallbackAt = markFallback(err); - notifyOnError(err); - notifyStatus({ - status: "fallback", - provider: provider.name, - model: config.model, - message: summarizeErrMessage(err), - code: err instanceof MemosError ? 
err.code : undefined, - at: fallbackAt, - durationMs: completion.durationMs, - fallbackProvider: "host", - op, - episodeId: opts?.episodeId, - phase: opts?.phase, - }); - return { completion }; } catch (hostErr) { failures++; const failAt = markFail(hostErr); @@ -225,6 +327,10 @@ export function createLlmClientWithProvider( primary: summarizeErr(err), host: summarizeErr(hostErr), }); + // Primary AND host bridge both failed. Trip on a terminal + // primary error (the one the operator typically needs to fix + // — host bridge failures are usually transient stdio issues). + if (breakerIsTerminal(err)) breakerTrip(err); notifyOnError(hostErr); notifyStatus({ status: "error", @@ -249,6 +355,7 @@ export function createLlmClientWithProvider( } failures++; const failAt = markFail(err); + if (breakerIsTerminal(err)) breakerTrip(err); notifyOnError(err); notifyStatus({ status: "error", @@ -415,6 +522,12 @@ export function createLlmClientWithProvider( const call = buildCallInput(opts, opts?.jsonMode === true); const ctx = makeCtx(opts, asProviderLog(providerLog)); + // Short-circuit stream calls when the breaker is open. We do not + // count a suppressed call against `requests` (no network hit). + if (breakerIsOpen()) { + maybeEmitCircuitOpenStatus(opts, opts?.op ?? 
"stream"); + throwBreakerOpen(); + } requests++; const start = Date.now(); let acc = ""; @@ -448,6 +561,7 @@ export function createLlmClientWithProvider( if (usage?.promptTokens) totalPromptTokens += usage.promptTokens; if (usage?.completionTokens) totalCompletionTokens += usage.completionTokens; const okAt = markOk(); + breakerRecordSuccess(); notifyStatus({ status: "ok", provider: provider.name, @@ -461,6 +575,7 @@ export function createLlmClientWithProvider( } catch (err) { failures++; const failAt = markFail(err); + if (breakerIsTerminal(err)) breakerTrip(err); facadeLog.error("stream.failed", { err: summarizeErr(err) }); notifyOnError(err); notifyStatus({ @@ -479,6 +594,59 @@ export function createLlmClientWithProvider( } } + async function callHostFallback( + primaryErr: unknown, + messages: LlmMessage[], + input: ProviderCallInput, + opts: LlmCallOptions | undefined, + op: string, + behavior: { keepBreakerOpen: boolean; notifyError: boolean }, + ): Promise<{ completion: LlmCompletion }> { + const hostProv = new HostLlmProvider(); + const res = await hostProv.complete( + messages, + input, + makeCtx(opts, asProviderLog(rootLogger.child({ channel: "llm.host" }))), + ); + hostFallbacks++; + facadeLog.warn("host.fallback", { + from: provider.name, + op, + reason: summarizeErr(primaryErr), + }); + const completion: LlmCompletion = { + text: res.text, + provider: provider.name, + model: config.model, + finishReason: res.finishReason, + usage: res.usage, + servedBy: "host_fallback", + durationMs: res.durationMs, + }; + record(completion, op, messages); + // The primary provider is still broken even though the host bridge + // saved this call. Keep the breaker open for terminal primary + // errors so later calls can go straight to host fallback without + // touching the paid provider again. 
+ const fallbackAt = markFallback(primaryErr); + if (!behavior.keepBreakerOpen) breakerRecordSuccess(); + if (behavior.notifyError) notifyOnError(primaryErr); + notifyStatus({ + status: "fallback", + provider: provider.name, + model: config.model, + message: summarizeErrMessage(primaryErr), + code: primaryErr instanceof MemosError ? primaryErr.code : undefined, + at: fallbackAt, + durationMs: completion.durationMs, + fallbackProvider: "host", + op, + episodeId: opts?.episodeId, + phase: opts?.phase, + }); + return { completion }; + } + const client: LlmClient = { provider: provider.name, model: config.model, @@ -497,6 +665,9 @@ export function createLlmClientWithProvider( lastOkAt, lastFallbackAt, lastError, + circuitOpen: breakerIsOpen(), + circuitOpenUntil, + circuitOpenedReason, }; }, resetStats(): void { @@ -509,6 +680,9 @@ export function createLlmClientWithProvider( lastOkAt = null; lastFallbackAt = null; lastError = null; + circuitOpenUntil = null; + circuitOpenedReason = null; + lastCircuitOpenStatusAt = null; }, async close(): Promise { await provider.close?.(); @@ -522,6 +696,10 @@ export function createLlmClientWithProvider( timeoutMs: config.timeoutMs, maxRetries: config.maxRetries, fallbackToHost: config.fallbackToHost, + circuitBreaker: { + enabled: breakerEnabled, + cooldownMs: breakerCooldownMs, + }, }); return client; @@ -562,6 +740,40 @@ function shouldFallback(err: unknown, config: LlmConfig, providerName: LlmProvid ); } +/** + * Default circuit-breaker classifier for terminal provider errors. + * + * A "terminal" error is one that will keep failing until the operator + * intervenes (top up balance, fix API key, fix model name). Retrying + * such an error just burns paid quota and pollutes the audit log, so + * the breaker opens and short-circuits further calls for the cool- + * down window. Issue #1897 reports the symptom — ~12,900 paid LLM + * requests in 24 h against a key with insufficient balance. + * + * Detection sources, in order: + * 1. 
`MemosError(LLM_UNAVAILABLE)` with `details.status` ∈ 401/402/403 + * — set by `core/llm/fetcher.ts::httpPostJson` for non-ok HTTP + * responses. + * 2. Well-known phrases, matched case-insensitively against the + * message of a `MemosError(LLM_UNAVAILABLE)` (so providers that return + * 400 for "Insufficient Balance" — looking at you, DeepSeek — are still recognized). + */ +function defaultIsTerminal(err: unknown): boolean { + if (!(err instanceof MemosError)) return false; + if (err.code !== ERROR_CODES.LLM_UNAVAILABLE) return false; + const status = Number((err.details as { status?: unknown } | undefined)?.status); + if (status === 401 || status === 402 || status === 403) return true; + const msg = (err.message ?? "").toLowerCase(); + return ( + msg.includes("insufficient balance") || + msg.includes("invalid api key") || + msg.includes("invalid_api_key") || + msg.includes("unauthorized") || + msg.includes("account suspended") || + msg.includes("billing") + ); +} + // ─── Logger adapter ────────────────────────────────────────────────────────── function asProviderLog(log: Logger): LlmProviderLogger { diff --git a/apps/memos-local-plugin/core/llm/index.ts b/apps/memos-local-plugin/core/llm/index.ts index 847c965f0..a295cd4a4 100644 --- a/apps/memos-local-plugin/core/llm/index.ts +++ b/apps/memos-local-plugin/core/llm/index.ts @@ -28,6 +28,7 @@ export { LocalOnlyLlmProvider } from "./providers/local-only.js"; export * from "./prompts/index.js"; export type { LlmCallOptions, + LlmCircuitBreakerConfig, LlmCompleteJsonOptions, LlmCompletion, LlmClient, @@ -40,6 +41,7 @@ export type { LlmProviderLogger, LlmProviderName, LlmRole, + LlmStatusDetail, LlmStreamChunk, LlmUsage, ProviderCallInput, diff --git a/apps/memos-local-plugin/core/llm/types.ts b/apps/memos-local-plugin/core/llm/types.ts index ddfe80c1e..4a835caee 100644 --- a/apps/memos-local-plugin/core/llm/types.ts +++ b/apps/memos-local-plugin/core/llm/types.ts @@ -47,6 +47,33 @@ export interface LlmConfig { * daemon can display status produced by a separate stdio 
bridge. */ onStatus?: (detail: LlmStatusDetail) => void; + /** + * Optional circuit breaker config. The breaker trips on terminal + * provider errors (HTTP 401/402/403, or well-known phrases like + * "insufficient balance" / "invalid api key" / "unauthorized" / + * "account suspended" / "billing") and short-circuits subsequent + * calls for a cool-down window. Defaults to enabled. See + * `apps/memos-local-plugin/openspec/changes/.../design.md` + * (issue #1897) for the full state machine. + */ + circuitBreaker?: LlmCircuitBreakerConfig; +} + +export interface LlmCircuitBreakerConfig { + /** Default true. Set false to restore legacy (no-breaker) behavior. */ + enabled?: boolean; + /** + * Cool-down window before the breaker enters half-open. Default + * 300_000 ms (5 minutes); minimum clamped to 30_000 ms. + */ + cooldownMs?: number; + /** + * Override the default classifier. Returns true if the error should + * trip the breaker (terminal / non-recoverable). + */ + isTerminal?: (err: unknown) => boolean; + /** Injected clock for tests. Default `Date.now`. */ + now?: () => number; } export interface LlmErrorDetail { @@ -67,7 +94,7 @@ export interface LlmErrorDetail { } export interface LlmStatusDetail { - status: "ok" | "fallback" | "error"; + status: "ok" | "fallback" | "error" | "circuit_open"; provider: LlmProviderName | string; model: string; message?: string; @@ -260,6 +287,17 @@ export interface LlmClientStats extends LastCallStatus { retries: number; totalPromptTokens: number; totalCompletionTokens: number; + /** + * True while the per-client circuit breaker is open (and any + * cooldown timer has not yet elapsed). When true, further calls are + * short-circuited inside the facade without touching the provider: + * host fallback when available, otherwise an immediate throw. See issue #1897. + */ + circuitOpen: boolean; + /** Epoch ms at which the open breaker becomes eligible for half-open probe. */ + circuitOpenUntil: number | null; + /** Free-text reason from the error that opened the breaker. 
*/ + circuitOpenedReason: string | null; } export interface LlmClient { diff --git a/apps/memos-local-plugin/tests/helpers/fake-llm.ts b/apps/memos-local-plugin/tests/helpers/fake-llm.ts index 22d9fc1a3..ec2c00261 100644 --- a/apps/memos-local-plugin/tests/helpers/fake-llm.ts +++ b/apps/memos-local-plugin/tests/helpers/fake-llm.ts @@ -20,7 +20,7 @@ export interface FakeLlmScript { complete?: Record string | Promise)>; completeJson?: Record< string, - unknown | ((input: unknown) => unknown | Promise) + unknown | ((input: unknown, opts?: unknown) => unknown | Promise) >; /** Override the served-by identifier. */ servedBy?: LlmProviderName | "host_fallback"; @@ -64,7 +64,7 @@ export function fakeLlm(script: FakeLlmScript = {}): LlmClient { throw new Error(`fakeLlm: no completeJson mock for op="${op}"`); } const value = (typeof entry === "function" - ? await (entry as (x: unknown) => unknown)(input) + ? await (entry as (x: unknown, o?: unknown) => unknown)(input, opts) : entry) as T; if (o?.validate) o.validate(value); return { diff --git a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts index d86290517..bc4b76f28 100644 --- a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts +++ b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts @@ -7,14 +7,15 @@ * 2. existing reflections are preserved verbatim; * 3. synth-disabled steps stay at α=0 even when the LLM tries to write * one for them; - * 4. `auto` mode falls back to per-step when stepCount > batchThreshold; - * 5. a malformed batched response degrades into the per-step path - * instead of crashing capture. + * 4. `auto` mode chunk-batches when stepCount > batchThreshold; + * 5. a malformed chunk degrades only that chunk into the per-step path + * instead of dropping the whole episode to per-step. 
*/ import { afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; import { createCaptureRunner, type CaptureRunner } from "../../../core/capture/capture.js"; +import { batchScoreReflections } from "../../../core/capture/batch-scorer.js"; import { createCaptureEventBus } from "../../../core/capture/events.js"; import { BATCH_REFLECTION_PROMPT, @@ -312,20 +313,31 @@ describe("capture/pipeline (batched ρ+α path)", () => { expect(t.alpha).toBe(0); // V7 disabledScore semantics }); - it("auto mode falls back to per-step when stepCount > batchThreshold", async () => { + it("auto mode chunk-batches when stepCount > batchThreshold", async () => { + const batchStates: string[][] = []; const llm = fakeLlm({ completeJson: { - // ONLY per-step alpha mock; if batched gets called, the test fails - // with "no completeJson mock for op=...batch...". - [alphaOp]: { alpha: 0.5, usable: true, reason: "ok" }, - }, - complete: { - "capture.reflection.synth": "I made this decision deliberately.", + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + batchStates.push(payload.steps.map((s) => s.state)); + return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `reflection ${step.state}`, + alpha: step.idx === 0 ? 0.2 : 0.4, + usable: true, + reason: "ok", + })), + }; + }, }, }); const runner = buildRunner({ batchMode: "auto", batchThreshold: 2 }, llm); - // 3 steps → above threshold → per-step path. + // 3 steps → above threshold → two bounded batch chunks. 
const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", @@ -341,10 +353,17 @@ describe("capture/pipeline (batched ρ+α path)", () => { const result = await runCapture(runner, ep); expect(result.traceIds).toHaveLength(3); - expect(result.llmCalls.batchedReflection).toBe(0); - // 3 synth + 3 alpha calls in per-step mode. - expect(result.llmCalls.reflectionSynth).toBe(3); - expect(result.llmCalls.alphaScoring).toBe(3); + expect(batchStates).toEqual([["a", "b"], ["c"]]); + expect(result.llmCalls.batchedReflection).toBe(2); + expect(result.llmCalls.reflectionSynth).toBe(0); + expect(result.llmCalls.alphaScoring).toBe(0); + + const rows = result.traceIds.map((id) => tmp.repos.traces.getById(id)!); + expect(rows.map((row) => row.reflection)).toEqual([ + "reflection a", + "reflection b", + "reflection c", + ]); }); it("long per-step downstream mode injects up to three following steps", async () => { @@ -368,7 +387,7 @@ describe("capture/pipeline (batched ρ+α path)", () => { }); const runner = buildRunner( { - batchMode: "auto", + batchMode: "per_step", batchThreshold: 2, reflectionContextMode: "task_downstream", longEpisodeReflectMode: "per_step_downstream", @@ -415,16 +434,27 @@ describe("capture/pipeline (batched ρ+α path)", () => { expect(step3Prompt).not.toContain("[step+3]"); }); - it("per_episode mode batches even when step count is large", async () => { - const scores = Array.from({ length: 5 }, (_, i) => ({ - idx: i, - reflection_text: `reflection #${i}`, - alpha: 0.4, - usable: true, - reason: "ok", - })); + it("per_episode mode chunk-batches when step count is large", async () => { + const chunkSizes: number[] = []; const llm = fakeLlm({ - completeJson: { [batchOp]: { scores } }, + completeJson: { + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + chunkSizes.push(payload.steps.length); + 
return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `reflection ${step.state}`, + alpha: 0.4, + usable: true, + reason: "ok", + })), + }; + }, + }, }); const runner = buildRunner({ batchMode: "per_episode", batchThreshold: 2 }, llm); @@ -436,10 +466,111 @@ describe("capture/pipeline (batched ρ+α path)", () => { const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", turns }); const result = await runCapture(runner, ep); expect(result.traceIds).toHaveLength(5); - expect(result.llmCalls.batchedReflection).toBe(1); + expect(chunkSizes).toEqual([2, 2, 1]); + expect(result.llmCalls.batchedReflection).toBe(3); expect(result.llmCalls.alphaScoring).toBe(0); }); + it("chunk-batch falls back to per-step only for the failed chunk", async () => { + const llm = fakeLlm({ + completeJson: { + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + if (payload.steps[0]?.state === "q2") { + throw new Error("chunk failed"); + } + return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `batch ${step.state}`, + alpha: step.state === "q4" ? 
0.5 : 0.2, + usable: true, + reason: "ok", + })), + }; + }, + [alphaOp]: { alpha: 0.9, usable: true, reason: "fallback" }, + }, + complete: { + "capture.reflection.synth": "per-step fallback reflection", + }, + }); + const runner = buildRunner({ batchMode: "auto", batchThreshold: 2 }, llm); + + const turns: EpisodeTurn[] = []; + for (let i = 0; i < 5; i++) { + turns.push(turn("user", `q${i}`, 1_000 + i * 100)); + turns.push(turn("assistant", `a${i}`, 1_050 + i * 100)); + } + const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", turns }); + + const result = await runCapture(runner, ep); + expect(result.traceIds).toHaveLength(5); + expect(result.llmCalls.batchedReflection).toBe(2); + expect(result.llmCalls.reflectionSynth).toBe(2); + expect(result.llmCalls.alphaScoring).toBe(2); + expect(result.warnings.filter((w) => w.stage === "batch")).toHaveLength(1); + + const rows = result.traceIds.map((id) => tmp.repos.traces.getById(id)!); + expect(rows.map((row) => row.reflection)).toEqual([ + "batch q0", + "batch q1", + "per-step fallback reflection", + "per-step fallback reflection", + "batch q4", + ]); + expect(rows.map((row) => row.alpha)).toEqual([0.2, 0.2, 0.9, 0.9, 0.5]); + }); + + it("batch scorer passes an explicit maxTokens budget", async () => { + let seenMaxTokens: number | undefined; + const llm = fakeLlm({ + completeJson: { + [batchOp]: (_input, opts) => { + seenMaxTokens = (opts as { maxTokens?: number }).maxTokens; + return { + scores: [ + { + idx: 0, + reflection_text: "I made a useful choice.", + alpha: 0.5, + usable: true, + reason: "ok", + }, + ], + }; + }, + }, + }); + + await batchScoreReflections( + llm, + [ + { + step: { + key: "s1", + ts: 1_000 as EpochMs, + type: "text", + userText: "q", + agentText: "a", + agentThinking: null, + toolCalls: [], + rawReflection: null, + meta: {}, + }, + existingReflection: null, + }, + ], + { synthReflections: true }, + ); + + expect(seenMaxTokens).toBeGreaterThan(0); + 
expect(seenMaxTokens).toBeLessThan(16_384); + }); + it("malformed batched response → falls back to per-step + emits warning", async () => { const llm = fakeLlm({ completeJson: { diff --git a/apps/memos-local-plugin/tests/unit/llm/client.test.ts b/apps/memos-local-plugin/tests/unit/llm/client.test.ts index cd5cf6106..7e904a2c6 100644 --- a/apps/memos-local-plugin/tests/unit/llm/client.test.ts +++ b/apps/memos-local-plugin/tests/unit/llm/client.test.ts @@ -14,6 +14,7 @@ import type { LlmProvider, LlmProviderCtx, LlmProviderName, + LlmStatusDetail, LlmStreamChunk, ProviderCallInput, ProviderCompletion, @@ -277,4 +278,241 @@ describe("llm/client", () => { const client = createLlmClientWithProvider(cfg(), fake); await expect(client.complete([] as LlmMessage[])).rejects.toBeInstanceOf(MemosError); }); + + // ─── Circuit breaker (issue #1897) ────────────────────────────────────── + describe("circuit breaker", () => { + function statusSink(): { rows: LlmStatusDetail[]; push: (d: LlmStatusDetail) => void } { + const rows: LlmStatusDetail[] = []; + return { rows, push: (d) => rows.push(d) }; + } + + it("trips on terminal 402 and short-circuits subsequent calls", async () => { + const sink = statusSink(); + let now = 1_000_000; + const tick = () => now; + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "HTTP 402 from openai_compatible", { + provider: "openai_compatible", + status: 402, + }), + ); + const client = createLlmClientWithProvider( + cfg({ + onStatus: sink.push, + circuitBreaker: { enabled: true, cooldownMs: 300_000, now: tick }, + }), + provider, + ); + // First call: real provider hit, fails terminally → breaker trips. + await expect(client.complete("first")).rejects.toBeInstanceOf(MemosError); + expect(provider.calls).toBe(1); + // Second call: should be short-circuited; provider must NOT be invoked. 
+ now += 100; + await expect(client.complete("second")).rejects.toMatchObject({ + code: ERROR_CODES.LLM_UNAVAILABLE, + details: { circuitOpen: true }, + }); + expect(provider.calls).toBe(1); + // Stats expose circuit state. + const stats = client.stats(); + expect(stats.circuitOpen).toBe(true); + expect(stats.circuitOpenUntil).toBe(1_000_000 + 300_000); + expect(stats.circuitOpenedReason).toMatch(/402/); + // Audit rows: at least one `error` and one `circuit_open`. + const statuses = sink.rows.map((r) => r.status); + expect(statuses).toContain("error"); + expect(statuses).toContain("circuit_open"); + }); + + it("trips on 'insufficient balance' message regardless of HTTP status", async () => { + const sink = statusSink(); + const provider = new ThrowingProvider( + new MemosError( + ERROR_CODES.LLM_UNAVAILABLE, + "HTTP 400 from openai_compatible: Insufficient Balance", + { provider: "openai_compatible", status: 400 }, + ), + ); + const client = createLlmClientWithProvider( + cfg({ onStatus: sink.push, circuitBreaker: { enabled: true } }), + provider, + ); + await expect(client.complete("x")).rejects.toBeInstanceOf(MemosError); + await expect(client.complete("y")).rejects.toMatchObject({ + details: { circuitOpen: true }, + }); + expect(provider.calls).toBe(1); + }); + + it("does NOT trip on generic LLM_UNAVAILABLE without terminal markers", async () => { + const sink = statusSink(); + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "transient network blip"), + ); + const client = createLlmClientWithProvider( + cfg({ onStatus: sink.push, circuitBreaker: { enabled: true } }), + provider, + ); + // Two consecutive failures with non-terminal classification → both + // calls reach the provider, breaker stays closed. 
+ await expect(client.complete("x")).rejects.toBeInstanceOf(MemosError); + await expect(client.complete("y")).rejects.toBeInstanceOf(MemosError); + expect(provider.calls).toBe(2); + expect(client.stats().circuitOpen).toBe(false); + }); + + it("coalesces circuit_open status rows within cooldown", async () => { + const sink = statusSink(); + let now = 1_000_000; + const tick = () => now; + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "401", { status: 401 }), + ); + const client = createLlmClientWithProvider( + cfg({ + onStatus: sink.push, + circuitBreaker: { enabled: true, cooldownMs: 300_000, now: tick }, + }), + provider, + ); + await expect(client.complete("trip")).rejects.toBeTruthy(); + // 20 suppressed calls within 1 second → at most a small number of + // `circuit_open` rows (we expect 1, but tolerate up to 2 in case the + // coalescer counts the very first short-circuit as a separate row). + for (let i = 0; i < 20; i++) { + now += 50; + await expect(client.complete(`spam-${i}`)).rejects.toBeTruthy(); + } + const openRows = sink.rows.filter((r) => r.status === "circuit_open"); + expect(openRows.length).toBeGreaterThanOrEqual(1); + expect(openRows.length).toBeLessThanOrEqual(2); + // Provider was only touched once (the very first call that tripped). 
+ expect(provider.calls).toBe(1); + }); + + it("half-open probes the provider after cooldown and closes on success", async () => { + const sink = statusSink(); + let now = 1_000_000; + const tick = () => now; + let attempt = 0; + const provider: LlmProvider = { + name: "openai_compatible", + async complete() { + attempt++; + if (attempt === 1) { + throw new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "401", { status: 401 }); + } + return { text: "ok", durationMs: 1 }; + }, + }; + const client = createLlmClientWithProvider( + cfg({ + onStatus: sink.push, + circuitBreaker: { enabled: true, cooldownMs: 60_000, now: tick }, + }), + provider, + ); + await expect(client.complete("trip")).rejects.toBeTruthy(); + expect(client.stats().circuitOpen).toBe(true); + // Suppressed call before cooldown elapses. + now += 30_000; + await expect(client.complete("suppressed")).rejects.toMatchObject({ + details: { circuitOpen: true }, + }); + expect(attempt).toBe(1); + // After cooldown, the next call probes the provider. + now += 31_000; // total 61_000 since trip + const r = await client.complete("probe"); + expect(r.text).toBe("ok"); + expect(attempt).toBe(2); + // Breaker closes on success. 
+ expect(client.stats().circuitOpen).toBe(false); + }); + + it("trips on terminal primary error even when host fallback rescues the call", async () => { + const sink = statusSink(); + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "402", { status: 402 }), + ); + let hostCalls = 0; + registerHostLlmBridge({ + id: "test.host", + async complete() { + hostCalls++; + return { text: "rescued", model: "host-m", durationMs: 1 }; + }, + }); + const client = createLlmClientWithProvider( + cfg({ + fallbackToHost: true, + onStatus: sink.push, + circuitBreaker: { enabled: true }, + }), + provider, + ); + const r = await client.complete("call-1"); + expect(r.servedBy).toBe("host_fallback"); + // The terminal primary error still opens the breaker even though + // host fallback rescued the user-visible call. + expect(client.stats().circuitOpen).toBe(true); + const r2 = await client.complete("call-2"); + expect(r2.servedBy).toBe("host_fallback"); + // The second call goes directly to host fallback and never touches + // the broken paid provider again. + expect(provider.calls).toBe(1); + expect(hostCalls).toBe(2); + expect(sink.rows.map((row) => row.status)).toContain("circuit_open"); + }); + + it("disabled when circuitBreaker.enabled=false (legacy behavior)", async () => { + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "402", { status: 402 }), + ); + const client = createLlmClientWithProvider( + cfg({ circuitBreaker: { enabled: false } }), + provider, + ); + await expect(client.complete("a")).rejects.toBeTruthy(); + await expect(client.complete("b")).rejects.toBeTruthy(); + await expect(client.complete("c")).rejects.toBeTruthy(); + // All three calls reached the provider. 
+ expect(provider.calls).toBe(3); + expect(client.stats().circuitOpen).toBe(false); + }); + + it("LlmClientStats exposes circuit fields when closed", async () => { + const fake = new FakeProvider("openai_compatible", () => ({ text: "ok", durationMs: 1 })); + const client = createLlmClientWithProvider(cfg(), fake); + await client.complete("x"); + const s = client.stats(); + expect(s.circuitOpen).toBe(false); + expect(s.circuitOpenUntil).toBeNull(); + expect(s.circuitOpenedReason).toBeNull(); + }); + + it("re-opens the breaker if the half-open probe fails terminally again", async () => { + const sink = statusSink(); + let now = 1_000_000; + const tick = () => now; + const provider = new ThrowingProvider( + new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "402", { status: 402 }), + ); + const client = createLlmClientWithProvider( + cfg({ + onStatus: sink.push, + circuitBreaker: { enabled: true, cooldownMs: 60_000, now: tick }, + }), + provider, + ); + await expect(client.complete("trip")).rejects.toBeTruthy(); + expect(client.stats().circuitOpen).toBe(true); + now += 61_000; + // Half-open probe still fails terminally → breaker re-opens. + await expect(client.complete("probe")).rejects.toBeTruthy(); + expect(client.stats().circuitOpen).toBe(true); + expect(client.stats().circuitOpenUntil).toBe(now + 60_000); + // Provider was touched twice total (initial trip + probe). + expect(provider.calls).toBe(2); + }); + }); }); diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 54f8f01e0..3ede965d3 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -1170,7 +1170,7 @@ def share_cube_with_user(self, cube_id: str, target_user_id: str) -> bool: bool: True if successful, False otherwise. 
""" # Validate current user has access to this cube - self._validate_cube_access(cube_id, target_user_id) + self._validate_cube_access(self.user_id, cube_id) # Validate target user exists if not self.user_manager.validate_user(target_user_id): diff --git a/tests/mem_os/test_format_utils.py b/tests/mem_os/test_format_utils.py new file mode 100644 index 000000000..b97178784 --- /dev/null +++ b/tests/mem_os/test_format_utils.py @@ -0,0 +1,75 @@ +""" +Test suite for src/memos/mem_os/utils/format_utils.py + +Focus: clean_json_response function defensive behavior +Related issue: #1525 +""" + +import pytest + +from memos.mem_os.utils.format_utils import clean_json_response + + +class TestCleanJsonResponse: + """Test clean_json_response function with various inputs.""" + + def test_clean_json_response_with_none_raises_value_error(self): + """Test that passing None raises ValueError with diagnostic message.""" + with pytest.raises(ValueError) as exc_info: + clean_json_response(None) + + error_message = str(exc_info.value) + assert "clean_json_response received None" in error_message + assert "upstream LLM call" in error_message + assert "timed_with_status" in error_message or "generate()" in error_message + + def test_clean_json_response_removes_json_code_block(self): + """Test removal of ```json markers.""" + input_str = '```json\n{"key": "value"}\n```' + expected = '{"key": "value"}' + assert clean_json_response(input_str) == expected + + def test_clean_json_response_removes_plain_code_block(self): + """Test removal of ``` markers without json keyword.""" + input_str = '```\n{"key": "value"}\n```' + expected = '{"key": "value"}' + assert clean_json_response(input_str) == expected + + def test_clean_json_response_strips_whitespace(self): + """Test that leading/trailing whitespace is stripped.""" + input_str = ' \n {"key": "value"} \n ' + expected = '{"key": "value"}' + assert clean_json_response(input_str) == expected + + def 
test_clean_json_response_handles_plain_json(self): + """Test that plain JSON without markdown is unchanged (except strip).""" + input_str = '{"key": "value"}' + expected = '{"key": "value"}' + assert clean_json_response(input_str) == expected + + def test_clean_json_response_handles_empty_string(self): + """Test that empty string is handled correctly.""" + assert clean_json_response("") == "" + + def test_clean_json_response_with_complex_json(self): + """Test with realistic LLM response containing nested JSON.""" + input_str = """```json +{ + "queries": [ + {"query": "test", "weight": 1.0}, + {"query": "example", "weight": 0.5} + ] +} +```""" + result = clean_json_response(input_str) + assert "```json" not in result + assert "```" not in result + assert '"queries"' in result + assert result.strip() == result # No leading/trailing whitespace + + def test_clean_json_response_preserves_internal_backticks(self): + """Test that backticks inside JSON content are preserved.""" + input_str = '```json\n{"code": "`example`"}\n```' + result = clean_json_response(input_str) + assert "`example`" in result + assert result.count("`") == 2 # Only internal backticks remain diff --git a/tests/mem_os/test_memos_core.py b/tests/mem_os/test_memos_core.py index 6d2408d05..b57b0b254 100644 --- a/tests/mem_os/test_memos_core.py +++ b/tests/mem_os/test_memos_core.py @@ -795,3 +795,146 @@ def test_search_nonexistent_cube( assert result["text_mem"] == [] assert result["act_mem"] == [] assert result["para_mem"] == [] + + +class TestShareCubeWithUser: + """Regression tests for share_cube_with_user (issue #1901). + + The original implementation called ``_validate_cube_access(cube_id, + target_user_id)``, which both (a) swapped the positional arguments and + (b) validated the wrong user. Every well-formed call therefore failed + with ``ValueError: User '' does not exist or is inactive`` even + though the calling user owned the cube. 
These tests pin down the correct + semantics: validate the *current* user against the cube being shared, + then delegate the share to ``user_manager.add_user_to_cube``. + """ + + def _build_mos( + self, + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ): + mock_llm_factory.from_config.return_value = mock_llm + mock_reader_factory.from_config.return_value = mock_mem_reader + mock_user_manager_class.return_value = mock_user_manager + return MOSCore(MOSConfig(**mock_config)) + + @patch("memos.mem_os.core.UserManager") + @patch("memos.mem_os.core.MemReaderFactory") + @patch("memos.mem_os.core.LLMFactory") + def test_share_cube_validates_current_user_not_target( + self, + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ): + """Cube access must be validated against the *current* user. + + Regression for #1901: previously the cube_id was passed where the + user_id was expected, causing ``_validate_user_exists`` to reject + every call because the cube UUID is obviously not a registered user. + """ + mock_user_manager.validate_user.return_value = True + mock_user_manager.validate_user_cube_access.return_value = True + mock_user_manager.add_user_to_cube.return_value = True + + mos = self._build_mos( + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ) + + cube_id = "cube-uuid-1234" + target_user_id = "target_user" + + result = mos.share_cube_with_user(cube_id=cube_id, target_user_id=target_user_id) + + assert result is True + # The cube-access check must be made against the *current* user, + # not the cube_id and not the target user. + mock_user_manager.validate_user_cube_access.assert_called_once_with(mos.user_id, cube_id) + # And the actual sharing must add the *target* user to the cube. 
+ mock_user_manager.add_user_to_cube.assert_called_once_with(target_user_id, cube_id) + + @patch("memos.mem_os.core.UserManager") + @patch("memos.mem_os.core.MemReaderFactory") + @patch("memos.mem_os.core.LLMFactory") + def test_share_cube_raises_when_current_user_lacks_access( + self, + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ): + """If the current user doesn't have access to the cube, refuse to share. + + The error message must reference the current user, not the cube_id + (which was the misleading symptom in #1901). + """ + mock_user_manager.validate_user.return_value = True + mock_user_manager.validate_user_cube_access.return_value = False + + mos = self._build_mos( + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ) + + with pytest.raises(ValueError, match="test_user"): + mos.share_cube_with_user(cube_id="cube-uuid-1234", target_user_id="target_user") + + mock_user_manager.add_user_to_cube.assert_not_called() + + @patch("memos.mem_os.core.UserManager") + @patch("memos.mem_os.core.MemReaderFactory") + @patch("memos.mem_os.core.LLMFactory") + def test_share_cube_raises_when_target_user_missing( + self, + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ): + """Target user must exist; ``validate_user`` is consulted independently.""" + # validate_user is used twice: once during MOSCore.__init__ for + # ``self.user_id`` (must succeed) and once for the target user (fail). 
+ mock_user_manager.validate_user.side_effect = lambda uid: uid == "test_user" + mock_user_manager.validate_user_cube_access.return_value = True + + mos = self._build_mos( + mock_llm_factory, + mock_reader_factory, + mock_user_manager_class, + mock_config, + mock_llm, + mock_mem_reader, + mock_user_manager, + ) + + with pytest.raises(ValueError, match="Target user 'missing_user'"): + mos.share_cube_with_user(cube_id="cube-uuid-1234", target_user_id="missing_user") + + mock_user_manager.add_user_to_cube.assert_not_called()