Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@ def add_node(
n += $metadata
"""

# serialization
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
# Note: `metadata["sources"]` is already JSON-serialized inside
# `_prepare_node_metadata` above. Do NOT re-encode here — doing so
# double-serializes the value and breaks `_parse_node` on read
# (see issue #1360).

with self.driver.session(database=self.db_name) as session:
session.run(
Expand Down Expand Up @@ -323,10 +323,9 @@ def add_nodes_batch(
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")

# Serialization for sources
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
# Note: `metadata["sources"]` is already JSON-serialized in
# `_prepare_node_metadata` above; do NOT re-encode here.
# See issue #1360 for the double-serialization regression.

prepared_nodes.append(
{
Expand Down
8 changes: 4 additions & 4 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def add_node(
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")

# serialization
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
# Note: `metadata["sources"]` is already JSON-serialized in
# `_prepare_node_metadata` above; do NOT re-encode here, otherwise
# the value is doubly serialized and `_parse_node` cannot decode it
# (see issue #1360).
# Extract required fields
embedding = metadata.pop("embedding", None)

Expand Down
115 changes: 115 additions & 0 deletions tests/graph_dbs/test_neo4j_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
is used for maximum efficiency.
"""

import json
import math
import os
import uuid
Expand Down Expand Up @@ -245,6 +246,120 @@ def test_parse_node_without_sources_key(self, shared_neo4j_db):
assert result["memory"] == "hello"


# ──────────────────────────────────────────────────────────────────────────────
# Unit tests: sources double-serialization regression (issue #1360 follow-up)
#
# `_prepare_node_metadata` JSON-encodes the sources list once. `add_node` /
# `add_nodes_batch` must NOT re-encode it, otherwise the value reaching Neo4j
# is a list of escaped JSON strings ('"{\"k\": \"v\"}"') and `_parse_node`'s
# `[0] == "{"` check skips deserialization, returning escaped strings to
# callers instead of dicts.
# ──────────────────────────────────────────────────────────────────────────────


class TestSourcesDoubleSerializationRegression:
    """Verify sources are JSON-serialized exactly once on the write path.

    Each test calls a write method on the ``shared_neo4j_db`` fixture, then
    inspects the arguments captured by the mocked session's ``run`` call to
    check how ``metadata["sources"]`` was encoded.
    """

    @staticmethod
    def _build_metadata(sources):
        """Return a minimal node-metadata dict carrying a copy of *sources*.

        The list is shallow-copied so that in-place replacement of list
        elements by the write path cannot alter the caller's expected value.
        """
        timestamp = datetime.now(timezone.utc).isoformat()
        return {
            "memory_type": "WorkingMemory",
            "embedding": [0.1, 0.2, 0.3],
            "sources": list(sources),
            "created_at": timestamp,
            "updated_at": timestamp,
        }

    @staticmethod
    def _assert_serialized_exactly_once(sources, original_sources):
        """Assert every stored source decodes to its original in one pass.

        The length is checked first: without it, a write path that dropped
        or truncated ``sources`` would make the element loop run fewer (or
        zero) times and the test would pass vacuously.
        """
        assert len(sources) == len(original_sources), (
            f"expected {len(original_sources)} sources, got {len(sources)}: {sources!r}"
        )
        for serialized, original in zip(sources, original_sources, strict=True):
            assert isinstance(serialized, str)
            # A doubly serialized dict starts with an escaped quote rather
            # than an opening brace.
            assert serialized.startswith("{") and serialized.endswith("}"), (
                f"sources element is doubly serialized: {serialized!r}"
            )
            # Each element must round-trip back to the original dict in a
            # single json.loads pass. Double-serialization breaks this.
            assert json.loads(serialized) == original

    def _extract_metadata_from_call(self, session_mock):
        """Pull the metadata kwarg from the most recent driver.run call.

        Raises:
            AssertionError: if the last call carries neither a ``metadata``
                nor a ``nodes`` keyword argument.
        """
        last_call = session_mock.run.call_args_list[-1]
        # add_node passes metadata as a keyword argument.
        if "metadata" in last_call.kwargs:
            return last_call.kwargs["metadata"]
        # add_nodes_batch passes nodes=[{..., "metadata": ...}] as a kwarg.
        if "nodes" in last_call.kwargs:
            return last_call.kwargs["nodes"][0]["metadata"]
        raise AssertionError(f"Unexpected driver.run call signature: {last_call}")

    def test_add_node_serializes_sources_exactly_once(self, shared_neo4j_db):
        """add_node must hand singly encoded sources to the driver."""
        session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
        original_sources = [{"key1": "value1"}, {"key2": "value2"}]

        shared_neo4j_db.add_node(
            id="test-sources-1",
            memory="test content",
            metadata=self._build_metadata(original_sources),
        )

        metadata = self._extract_metadata_from_call(session_mock)
        self._assert_serialized_exactly_once(metadata["sources"], original_sources)

    def test_add_nodes_batch_serializes_sources_exactly_once(self, shared_neo4j_db):
        """add_nodes_batch must hand singly encoded sources to the driver."""
        session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
        original_sources = [{"a": 1}, {"b": 2}]

        shared_neo4j_db.add_nodes_batch(
            nodes=[
                {
                    "id": "test-batch-sources-1",
                    "memory": "batch content",
                    "metadata": self._build_metadata(original_sources),
                }
            ],
        )

        metadata = self._extract_metadata_from_call(session_mock)
        self._assert_serialized_exactly_once(metadata["sources"], original_sources)

    def test_add_node_then_parse_node_round_trip(self, shared_neo4j_db):
        """add_node serializes once; _parse_node should decode back to dicts."""
        session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
        original_sources = [{"id": "src-1"}, {"id": "src-2"}]

        shared_neo4j_db.add_node(
            id="test-roundtrip-1",
            memory="round trip",
            metadata=self._build_metadata(original_sources),
        )

        stored_metadata = self._extract_metadata_from_call(session_mock)

        # Simulate Neo4j returning the stored node back to the caller.
        parsed = shared_neo4j_db._parse_node(
            {
                "id": "test-roundtrip-1",
                "memory": "round trip",
                "memory_type": "WorkingMemory",
                "sources": list(stored_metadata["sources"]),
                "created_at": datetime.now(timezone.utc),
                "updated_at": datetime.now(timezone.utc),
            }
        )
        # The full-list equality also covers length, so no separate check.
        assert parsed["metadata"]["sources"] == original_sources


# ──────────────────────────────────────────────────────────────────────────────
# Integration tests (require a running Neo4j 5.18+ with vector index)
#
Expand Down
Loading