From aff0ebba1eb290d625f959660ecce4bffe6c0a8d Mon Sep 17 00:00:00 2001 From: MemOS AutoDev Date: Mon, 29 Jun 2026 23:38:50 +0800 Subject: [PATCH] fix: bound search thread-pool via shared class-level executors (#1273) Replace per-request ContextThreadPoolExecutor instantiation across the tree-text search pipeline with two shared executors that each Searcher instance creates once in `__init__` (not per request), plus a 30 s per-future timeout. The previous code opened a fresh pool inside `with` blocks at four call sites (`_retrieve_paths`, `_retrieve_from_long_term_and_user`, `_retrieve_from_tool_memory`, `_deduplicate_rawfile_results`); whenever a worker hung on a slow Neo4j / embedding / HTTP call, `shutdown(wait=True)` never returned and the worker threads accumulated without bound across requests, eventually exhausting the container's pthread limit (reporter observed 8744 threads; /search returned HTTP 200 with empty results, /chat HTTP 503, even docker exec failing). Searcher now owns `_search_executor` (max_workers=10, prefix `search`) for outer A-F retrieval paths and `_search_subtask_executor` (max_workers=10, prefix `search-sub`) for the nested sub-paths and the rawfile-dedup. The two-pool split prevents nested-submission deadlock; total threads are bounded at 20 per Searcher regardless of request volume or downstream latency. Every `Future.result(...)` and the `as_completed(...)` iterator now carry `timeout=SEARCH_TASK_TIMEOUT_SECONDS` (30 s default); on TimeoutError the call site logs a warning and skips that contribution so the caller can still return. Note that the timeout is applied to each future in turn, so a request waiting on several hung futures can block for more than 30 s in total, and a timed-out future is not cancelled: its worker keeps occupying a slot in the bounded pool until the downstream call returns. 
Closes #1273 --- .../tree_text_memory/retrieve/searcher.py | 256 +++++++++--- .../textual/test_tree_searcher_thread_pool.py | 392 ++++++++++++++++++ 2 files changed, 581 insertions(+), 67 deletions(-) create mode 100644 tests/memories/textual/test_tree_searcher_thread_pool.py diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index cd27d92a1..22ce8a5d8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -3,6 +3,7 @@ import re import traceback +from concurrent.futures import TimeoutError as FutureTimeoutError from concurrent.futures import as_completed from memos.context.context import ContextThreadPoolExecutor @@ -42,6 +43,19 @@ "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, } +# Maximum time (seconds) a search subtask may block the calling request before +# we abandon waiting on it. The underlying worker thread keeps running, but it +# stays inside the bounded shared executor, so total thread count is capped. +# See issue #1273. +SEARCH_TASK_TIMEOUT_SECONDS: float = 30.0 + +# Bounds on the shared class-level executors. Two pools are used to prevent +# nested-submission starvation: outer search paths submit to +# `_search_executor`; the long-term/tool-memory/dedup paths run *on* those +# outer workers and submit their own work to `_search_subtask_executor`. +_SEARCH_EXECUTOR_MAX_WORKERS = 10 +_SEARCH_SUBTASK_EXECUTOR_MAX_WORKERS = 10 + class Searcher: def __init__( @@ -76,6 +90,20 @@ def __init__( self.manual_close_internet = manual_close_internet self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") + # Shared class-level executors for the search pipeline. 
They replace + # the per-request `with ContextThreadPoolExecutor(...)` pools that + # caused unbounded thread accumulation when downstream calls hung + # (issue #1273). The outer/inner split prevents nested-submission + # starvation: outer paths submit to `_search_executor`; sub-paths + # running on outer workers submit to `_search_subtask_executor`. + self._search_executor = ContextThreadPoolExecutor( + max_workers=_SEARCH_EXECUTOR_MAX_WORKERS, + thread_name_prefix="search", + ) + self._search_subtask_executor = ContextThreadPoolExecutor( + max_workers=_SEARCH_SUBTASK_EXECUTOR_MAX_WORKERS, + thread_name_prefix="search-sub", + ) def _maybe_rerank( self, @@ -379,15 +407,18 @@ def _retrieve_paths( rerank: bool = True, ): """Run A/B/C/D/E/F retrieval paths in parallel""" - tasks = [] + labelled_tasks: list[tuple[str, object]] = [] id_filter = { "user_id": info.get("user_id", None), "session_id": info.get("session_id", None), } id_filter = {k: v for k, v in id_filter.items() if v is not None} - with ContextThreadPoolExecutor(max_workers=5) as executor: - tasks.append( + # Shared, class-level executor — see Searcher.__init__ / issue #1273. 
+ executor = self._search_executor + labelled_tasks.append( + ( + "PATH-A:working", executor.submit( self._retrieve_from_working_memory, query, @@ -400,9 +431,12 @@ def _retrieve_paths( user_name, id_filter, rerank=rerank, - ) + ), ) - tasks.append( + ) + labelled_tasks.append( + ( + "PATH-B:long_term_and_user", executor.submit( self._retrieve_from_long_term_and_user, query, @@ -416,9 +450,12 @@ def _retrieve_paths( id_filter, mode=mode, rerank=rerank, - ) + ), ) - tasks.append( + ) + labelled_tasks.append( + ( + "PATH-C:internet", executor.submit( self._retrieve_from_internet, query, @@ -430,10 +467,13 @@ def _retrieve_paths( memory_type, user_name, rerank=rerank, - ) + ), ) - if self.use_fulltext: - tasks.append( + ) + if self.use_fulltext: + labelled_tasks.append( + ( + "PATH-KEYWORD:fulltext", executor.submit( self._retrieve_from_keyword, query, @@ -446,10 +486,13 @@ def _retrieve_paths( user_name, id_filter, rerank=rerank, - ) + ), ) - if search_tool_memory: - tasks.append( + ) + if search_tool_memory: + labelled_tasks.append( + ( + "PATH-D:tool", executor.submit( self._retrieve_from_tool_memory, query, @@ -463,10 +506,13 @@ def _retrieve_paths( id_filter, mode=mode, rerank=rerank, - ) + ), ) - if include_skill_memory: - tasks.append( + ) + if include_skill_memory: + labelled_tasks.append( + ( + "PATH-E:skill", executor.submit( self._retrieve_from_skill_memory, query, @@ -480,10 +526,13 @@ def _retrieve_paths( id_filter, mode=mode, rerank=rerank, - ) + ), ) - if include_preference_memory: - tasks.append( + ) + if include_preference_memory: + labelled_tasks.append( + ( + "PATH-F:preference", executor.submit( self._retrieve_from_preference_memory, query, @@ -497,11 +546,25 @@ def _retrieve_paths( id_filter, mode=mode, rerank=rerank, - ) + ), + ) + ) + results = [] + for label, task in labelled_tasks: + try: + results.extend(task.result(timeout=SEARCH_TASK_TIMEOUT_SECONDS)) + except FutureTimeoutError: + logger.warning( + "[SEARCH] subtask %s timed out after %ss; 
skipping its results", + label, + SEARCH_TASK_TIMEOUT_SECONDS, + ) + except Exception: + logger.warning( + "[SEARCH] subtask %s raised an exception; skipping its results\n%s", + label, + traceback.format_exc(), ) - results = [] - for t in tasks: - results.extend(t.result()) logger.info(f"[SEARCH] Total raw results: {len(results)}") return results @@ -752,7 +815,7 @@ def _retrieve_from_long_term_and_user( ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] - tasks = [] + labelled_tasks: list[tuple[str, object]] = [] # chain of thinking cot_embeddings = [] @@ -764,9 +827,12 @@ def _retrieve_from_long_term_and_user( else: cot_embeddings = query_embedding - with ContextThreadPoolExecutor(max_workers=3) as executor: - if memory_type in ["All", "AllSummaryMemory", "LongTermMemory"]: - tasks.append( + # Shared class-level executor — see Searcher.__init__ / issue #1273. + executor = self._search_subtask_executor + if memory_type in ["All", "AllSummaryMemory", "LongTermMemory"]: + labelled_tasks.append( + ( + "PATH-B:LongTermMemory", executor.submit( self.graph_retriever.retrieve, query=query, @@ -779,10 +845,13 @@ def _retrieve_from_long_term_and_user( user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, - ) + ), ) - if memory_type in ["All", "AllSummaryMemory", "UserMemory"]: - tasks.append( + ) + if memory_type in ["All", "AllSummaryMemory", "UserMemory"]: + labelled_tasks.append( + ( + "PATH-B:UserMemory", executor.submit( self.graph_retriever.retrieve, query=query, @@ -795,10 +864,13 @@ def _retrieve_from_long_term_and_user( user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, - ) + ), ) - if memory_type in ["RawFileMemory"]: - tasks.append( + ) + if memory_type in ["RawFileMemory"]: + labelled_tasks.append( + ( + "PATH-B:RawFileMemory", executor.submit( self.graph_retriever.retrieve, query=query, @@ -811,14 +883,28 @@ def _retrieve_from_long_term_and_user( user_name=user_name, 
id_filter=id_filter, use_fast_graph=self.use_fast_graph, - ) + ), ) + ) - # Collect results from all tasks - for task in tasks: - results.extend(task.result()) - results = self._deduplicate_rawfile_results(results, user_name=user_name) - results = self._filter_intermediate_content(results) + # Collect results from all tasks + for label, task in labelled_tasks: + try: + results.extend(task.result(timeout=SEARCH_TASK_TIMEOUT_SECONDS)) + except FutureTimeoutError: + logger.warning( + "[SEARCH] subtask %s timed out after %ss; skipping its results", + label, + SEARCH_TASK_TIMEOUT_SECONDS, + ) + except Exception: + logger.warning( + "[SEARCH] subtask %s raised an exception; skipping its results\n%s", + label, + traceback.format_exc(), + ) + results = self._deduplicate_rawfile_results(results, user_name=user_name) + results = self._filter_intermediate_content(results) return self._maybe_rerank( rerank, @@ -916,7 +1002,7 @@ def _retrieve_from_tool_memory( "ToolSchemaMemory": [], "ToolTrajectoryMemory": [], } - tasks = [] + labelled_tasks: list[tuple[str, object]] = [] # chain of thinking cot_embeddings = [] @@ -928,9 +1014,12 @@ def _retrieve_from_tool_memory( else: cot_embeddings = query_embedding - with ContextThreadPoolExecutor(max_workers=2) as executor: - if memory_type in ["All", "ToolSchemaMemory"]: - tasks.append( + # Shared class-level executor — see Searcher.__init__ / issue #1273. 
+ executor = self._search_subtask_executor + if memory_type in ["All", "ToolSchemaMemory"]: + labelled_tasks.append( + ( + "PATH-D:ToolSchemaMemory", executor.submit( self.graph_retriever.retrieve, query=query, @@ -943,10 +1032,13 @@ def _retrieve_from_tool_memory( user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, - ) + ), ) - if memory_type in ["All", "ToolTrajectoryMemory"]: - tasks.append( + ) + if memory_type in ["All", "ToolTrajectoryMemory"]: + labelled_tasks.append( + ( + "PATH-D:ToolTrajectoryMemory", executor.submit( self.graph_retriever.retrieve, query=query, @@ -959,16 +1051,32 @@ def _retrieve_from_tool_memory( user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, - ) + ), ) + ) - # Collect results from all tasks - for task in tasks: - rsp = task.result() - if rsp and rsp[0].metadata.memory_type == "ToolSchemaMemory": - results["ToolSchemaMemory"].extend(rsp) - elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory": - results["ToolTrajectoryMemory"].extend(rsp) + # Collect results from all tasks + for label, task in labelled_tasks: + try: + rsp = task.result(timeout=SEARCH_TASK_TIMEOUT_SECONDS) + except FutureTimeoutError: + logger.warning( + "[SEARCH] subtask %s timed out after %ss; skipping its results", + label, + SEARCH_TASK_TIMEOUT_SECONDS, + ) + continue + except Exception: + logger.warning( + "[SEARCH] subtask %s raised an exception; skipping its results\n%s", + label, + traceback.format_exc(), + ) + continue + if rsp and rsp[0].metadata.memory_type == "ToolSchemaMemory": + results["ToolSchemaMemory"].extend(rsp) + elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory": + results["ToolTrajectoryMemory"].extend(rsp) schema_reranked = self._maybe_rerank( rerank, @@ -1298,20 +1406,23 @@ def _deduplicate_rawfile_results(self, results, user_name: str | None = None): if not rawfile_items: return results - with ContextThreadPoolExecutor(max_workers=min(len(rawfile_items), 10)) 
as executor: - futures = [ - executor.submit( - self.graph_store.get_edges, - rawfile_item.id, - type="SUMMARY", - direction="OUTGOING", - user_name=user_name, - ) - for rawfile_item in rawfile_items - ] - for future in as_completed(futures): + # Shared class-level executor — see Searcher.__init__ / issue #1273. + executor = self._search_subtask_executor + futures = [ + executor.submit( + self.graph_store.get_edges, + rawfile_item.id, + type="SUMMARY", + direction="OUTGOING", + user_name=user_name, + ) + for rawfile_item in rawfile_items + ] + try: + completed_iter = as_completed(futures, timeout=SEARCH_TASK_TIMEOUT_SECONDS) + for future in completed_iter: try: - edges = future.result() + edges = future.result(timeout=SEARCH_TASK_TIMEOUT_SECONDS) for edge in edges: summary_target_id = edge.get("to") if summary_target_id: @@ -1319,8 +1430,19 @@ def _deduplicate_rawfile_results(self, results, user_name: str | None = None): logger.debug( f"[DEDUP] Marking summary node {summary_target_id} for removal (pointed by RawFileMemory)" ) + except FutureTimeoutError: + logger.warning( + "[DEDUP] get_edges timed out after %ss; skipping", + SEARCH_TASK_TIMEOUT_SECONDS, + ) except Exception as e: logger.warning(f"[DEDUP] Failed to get summary target ids: {e}") + except FutureTimeoutError: + logger.warning( + "[DEDUP] one or more get_edges futures timed out after %ss; " + "returning partial dedup", + SEARCH_TASK_TIMEOUT_SECONDS, + ) filtered_results = [] for item in results: diff --git a/tests/memories/textual/test_tree_searcher_thread_pool.py b/tests/memories/textual/test_tree_searcher_thread_pool.py new file mode 100644 index 000000000..10b7e3b90 --- /dev/null +++ b/tests/memories/textual/test_tree_searcher_thread_pool.py @@ -0,0 +1,392 @@ +"""Regression tests for issue #1273 — bound search thread-pool growth. + +The Searcher used to instantiate a fresh ContextThreadPoolExecutor on every +search request inside four methods. 
If any worker hung on a slow Neo4j / +embedding / HTTP call, the per-request `shutdown(wait=True)` blocked +forever, the worker threads could not be reclaimed, and subsequent requests +allocated new pools — unbounded thread accumulation up to the container's +pthread limit. These tests pin the new behaviour: shared class-level +executors, bounded thread count, and a per-future timeout that does not +leak the wait to the caller. +""" + +from __future__ import annotations + +import threading +import time + +from unittest.mock import MagicMock, patch + +from memos.context.context import ContextThreadPoolExecutor +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree_text_memory.retrieve import searcher as searcher_module +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.base import BaseReranker + + +def _build_searcher() -> Searcher: + dispatcher_llm = MagicMock() + graph_store = MagicMock() + embedder = MagicMock() + reranker = MagicMock(spec=BaseReranker) + s = Searcher(dispatcher_llm, graph_store, embedder, reranker) + s.task_goal_parser = MagicMock() + s.graph_retriever = MagicMock() + s.reasoner = MagicMock() + return s + + +def _make_item(content: str) -> TextualMemoryItem: + return TextualMemoryItem( + memory=content, + metadata=TreeNodeTextualMemoryMetadata(embedding=[0.1] * 5, usage=[]), + ) + + +def _prime_simple_pipeline(s: Searcher) -> None: + parsed_goal = MagicMock() + parsed_goal.memories = ["seed"] + parsed_goal.rephrased_query = None + parsed_goal.internet_search = False + s.task_goal_parser.parse.return_value = parsed_goal + s.embedder.embed.return_value = [[0.1] * 5, [0.2] * 5] + s.graph_retriever.retrieve.return_value = [_make_item("hit")] + s.reranker.rerank.return_value = [(_make_item("hit"), 0.5)] + + +def test_search_executors_are_class_level_singletons(): + """Both shared executors must be created once per Searcher and 
reused.""" + s = _build_searcher() + + assert hasattr(s, "_search_executor"), "Searcher must expose a class-level _search_executor" + assert hasattr(s, "_search_subtask_executor"), ( + "Searcher must expose a class-level _search_subtask_executor" + ) + + assert isinstance(s._search_executor, ContextThreadPoolExecutor) + assert isinstance(s._search_subtask_executor, ContextThreadPoolExecutor) + + # And they must NOT be the same object — the whole point of two pools is + # to avoid nested-submission deadlock. + assert s._search_executor is not s._search_subtask_executor + + +def test_search_executor_configuration(): + """Pool sizes and thread-name prefixes are pinned to the spec.""" + s = _build_searcher() + + assert s._search_executor._max_workers == 10 + assert s._search_executor._thread_name_prefix == "search" + + assert s._search_subtask_executor._max_workers == 10 + assert s._search_subtask_executor._thread_name_prefix == "search-sub" + + +def test_no_per_request_executor_creation(): + """No new ContextThreadPoolExecutor must be built during a search call. + + This is the load-bearing regression: the bug was that every request + constructed fresh pools. We patch the symbol the searcher module + binds and assert it stays at zero invocations during _retrieve_paths. + """ + s = _build_searcher() + _prime_simple_pipeline(s) + + with patch.object( + searcher_module, "ContextThreadPoolExecutor", wraps=ContextThreadPoolExecutor + ) as ctor: + s.search( + query="anything", + top_k=1, + info={"test": True}, + mode="fast", + memory_type="WorkingMemory", + ) + # No per-request executor allowed in the four refactored methods. 
+ assert ctor.call_count == 0, ( + f"Expected no per-request ContextThreadPoolExecutor construction; " + f"got {ctor.call_count} call(s): {ctor.call_args_list}" + ) + + +def test_thread_count_bounded_under_repeated_search(): + """16 sequential searches must not multiply 'search'-prefixed threads.""" + s = _build_searcher() + _prime_simple_pipeline(s) + + iterations = 16 + for _ in range(iterations): + s.search( + query="anything", + top_k=1, + info={"test": True}, + mode="fast", + memory_type="WorkingMemory", + ) + + # Even if the worker pool warmed all 10 slots, no more than the bound. + live_search_threads = [t for t in threading.enumerate() if t.name.startswith("search")] + search_only = [t for t in live_search_threads if t.name.startswith("search_")] + subtask_only = [t for t in live_search_threads if t.name.startswith("search-sub")] + + assert len(search_only) <= 10, ( + f"search executor leaked threads: {len(search_only)} > 10 after {iterations} requests" + ) + assert len(subtask_only) <= 10, ( + f"search-sub executor leaked threads: {len(subtask_only)} > 10 after {iterations} requests" + ) + + +def test_search_timeout_constant_present(): + """The default per-future timeout must be exposed as a module constant.""" + assert hasattr(searcher_module, "SEARCH_TASK_TIMEOUT_SECONDS") + assert isinstance(searcher_module.SEARCH_TASK_TIMEOUT_SECONDS, (int, float)) + # Reporter suggested 30 s; we accept any positive default <= 60 s. 
+ assert 0 < searcher_module.SEARCH_TASK_TIMEOUT_SECONDS <= 60 + + +def test_retrieve_paths_passes_timeout_to_future_result(): + """`_retrieve_paths` MUST call task.result with a timeout argument.""" + s = _build_searcher() + _prime_simple_pipeline(s) + parsed_goal = s.task_goal_parser.parse.return_value + + captured: list[float | None] = [] + + real_submit = s._search_executor.submit + + def tracking_submit(fn, *args, **kwargs): + future = real_submit(fn, *args, **kwargs) + original_result = future.result + + def wrapped_result(timeout=None): + captured.append(timeout) + return original_result(timeout=timeout) + + future.result = wrapped_result # type: ignore[method-assign] + return future + + with patch.object(s._search_executor, "submit", side_effect=tracking_submit): + s._retrieve_paths( + query="q", + parsed_goal=parsed_goal, + query_embedding=[[0.1] * 5], + info={"user_id": "u", "session_id": "sess"}, + top_k=1, + mode="fast", + memory_type="WorkingMemory", + ) + + assert captured, "Expected at least one future.result() call to be tracked." + assert all(t is not None and t > 0 for t in captured), ( + f"task.result() must be called with a positive timeout; got {captured}" + ) + + +def test_hanging_subtask_does_not_block_request_forever(monkeypatch): + """A hung path must not freeze the request; warning must be logged.""" + s = _build_searcher() + _prime_simple_pipeline(s) + parsed_goal = s.task_goal_parser.parse.return_value + + # Force a tiny timeout so the test runs fast. + monkeypatch.setattr(searcher_module, "SEARCH_TASK_TIMEOUT_SECONDS", 0.5) + + block_event = threading.Event() + + def hanging_retrieve(*args, **kwargs): + # Mimic a stuck downstream call. The Event is set in the finally + # block of the test to release the thread for clean teardown. 
+ block_event.wait(timeout=10) + return [] + + s.graph_retriever.retrieve.side_effect = hanging_retrieve + + warnings_seen: list[str] = [] + + def capture_warning(msg, *args, **kwargs): + try: + warnings_seen.append(msg % args if args else msg) + except Exception: + warnings_seen.append(str(msg)) + + monkeypatch.setattr(searcher_module.logger, "warning", capture_warning) + + started = time.monotonic() + try: + result = s._retrieve_paths( + query="q", + parsed_goal=parsed_goal, + query_embedding=[[0.1] * 5], + info={"user_id": "u", "session_id": "sess"}, + top_k=1, + mode="fast", + memory_type="WorkingMemory", + ) + finally: + block_event.set() + elapsed = time.monotonic() - started + + # Must not have waited the full 10 s; budget = timeout + slack. + assert elapsed < 5.0, f"_retrieve_paths blocked for {elapsed:.2f}s — timeout did not fire" + # Result is whatever non-timed-out paths produced (could be empty list). + assert isinstance(result, list) + assert any("timeout" in w.lower() or "timed out" in w.lower() for w in warnings_seen), ( + f"Expected a timeout warning to be logged; saw: {warnings_seen}" + ) + + +def test_searcher_executors_survive_multiple_retrieve_paths_calls(): + """Identity check: same executor instance across calls (no rebuild).""" + s = _build_searcher() + _prime_simple_pipeline(s) + + outer_ref = s._search_executor + inner_ref = s._search_subtask_executor + + for _ in range(3): + s.search( + query="anything", + top_k=1, + info={"test": True}, + mode="fast", + memory_type="WorkingMemory", + ) + + assert s._search_executor is outer_ref + assert s._search_subtask_executor is inner_ref + + +def test_long_term_path_uses_subtask_executor(monkeypatch): + """`_retrieve_from_long_term_and_user` MUST submit to the subtask pool.""" + s = _build_searcher() + _prime_simple_pipeline(s) + parsed_goal = s.task_goal_parser.parse.return_value + parsed_goal.context = [] + + submitted_executors: list[ContextThreadPoolExecutor] = [] + real_submit = 
s._search_subtask_executor.submit + + def tracking_submit(fn, *args, **kwargs): + submitted_executors.append(s._search_subtask_executor) + return real_submit(fn, *args, **kwargs) + + with patch.object(s._search_subtask_executor, "submit", side_effect=tracking_submit): + s._retrieve_from_long_term_and_user( + query="q", + parsed_goal=parsed_goal, + query_embedding=[[0.1] * 5], + top_k=1, + memory_type="LongTermMemory", + id_filter=None, + mode="fast", + ) + + assert submitted_executors, ( + "_retrieve_from_long_term_and_user must submit to the shared subtask " + "executor, not a fresh per-call pool" + ) + + +def test_tool_memory_path_uses_subtask_executor(): + """`_retrieve_from_tool_memory` MUST submit to the subtask pool.""" + s = _build_searcher() + _prime_simple_pipeline(s) + parsed_goal = s.task_goal_parser.parse.return_value + parsed_goal.context = [] + + tool_item = TextualMemoryItem( + memory="tool", + metadata=TreeNodeTextualMemoryMetadata( + embedding=[0.1] * 5, usage=[], memory_type="ToolSchemaMemory" + ), + ) + s.graph_retriever.retrieve.return_value = [tool_item] + + submitted = [] + real_submit = s._search_subtask_executor.submit + + def tracking_submit(fn, *args, **kwargs): + submitted.append(1) + return real_submit(fn, *args, **kwargs) + + with patch.object(s._search_subtask_executor, "submit", side_effect=tracking_submit): + s._retrieve_from_tool_memory( + query="q", + parsed_goal=parsed_goal, + query_embedding=[[0.1] * 5], + top_k=1, + memory_type="ToolSchemaMemory", + id_filter=None, + mode="fast", + ) + + assert submitted, "_retrieve_from_tool_memory must submit to the shared subtask executor" + + +def test_dedup_rawfile_uses_subtask_executor(): + """`_deduplicate_rawfile_results` MUST submit to the subtask pool when + there are RawFileMemory items to inspect.""" + s = _build_searcher() + s.graph_store.get_edges = MagicMock(return_value=[]) + + rawfile_item = TextualMemoryItem( + memory="raw", + metadata=TreeNodeTextualMemoryMetadata( + 
embedding=[0.1] * 5, usage=[], memory_type="RawFileMemory" + ), + ) + + submitted = [] + real_submit = s._search_subtask_executor.submit + + def tracking_submit(fn, *args, **kwargs): + submitted.append(1) + return real_submit(fn, *args, **kwargs) + + with patch.object(s._search_subtask_executor, "submit", side_effect=tracking_submit): + s._deduplicate_rawfile_results([rawfile_item], user_name="u") + + assert submitted, ( + "_deduplicate_rawfile_results must submit to the shared subtask " + "executor when RawFileMemory items are present" + ) + + +def test_dedup_rawfile_timeout_returns_partial_results(monkeypatch): + """If get_edges hangs, dedup MUST not block the caller forever.""" + s = _build_searcher() + monkeypatch.setattr(searcher_module, "SEARCH_TASK_TIMEOUT_SECONDS", 0.5) + + block_event = threading.Event() + + def hanging_get_edges(*args, **kwargs): + block_event.wait(timeout=10) + return [] + + s.graph_store.get_edges = MagicMock(side_effect=hanging_get_edges) + + rawfile_item = TextualMemoryItem( + memory="raw", + metadata=TreeNodeTextualMemoryMetadata( + embedding=[0.1] * 5, usage=[], memory_type="RawFileMemory" + ), + ) + + warnings_seen: list[str] = [] + monkeypatch.setattr( + searcher_module.logger, + "warning", + lambda msg, *a, **kw: warnings_seen.append(msg % a if a else str(msg)), + ) + + started = time.monotonic() + try: + out = s._deduplicate_rawfile_results([rawfile_item], user_name="u") + finally: + block_event.set() + elapsed = time.monotonic() - started + + assert elapsed < 5.0, f"dedup blocked for {elapsed:.2f}s — timeout did not fire" + # On timeout, original results are returned unfiltered (no edge data → no removal). + assert out == [rawfile_item]