From eb52c66b1386c57ee8e484fca737e7c46da04153 Mon Sep 17 00:00:00 2001 From: Kapil Samant Date: Wed, 3 Jun 2026 18:07:58 +0530 Subject: [PATCH 1/3] FEAT: Set session level auditing --- mssql_python/connection.py | 169 ++++++++++++++++++++++++++++++++ tests/test_023_audit_context.py | 144 +++++++++++++++++++++++++++ 2 files changed, 313 insertions(+) create mode 100644 tests/test_023_audit_context.py diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692..abf85b22 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1113,6 +1113,175 @@ def clear_output_converters(self) -> None: if hasattr(self._conn, "clear_output_converters"): self._conn.clear_output_converters() logger.info("Cleared all output converters") + + + # ---- Session Metadata / Auditing API ---- + + # Maximum length for session context keys and values to prevent abuse. + _AUDIT_KEY_MAX_LEN: int = 128 + _AUDIT_VALUE_MAX_LEN: int = 8000 # SQL Server sp_set_session_context limit + + def set_audit_context( + self, + *, + application: Optional[str] = None, + module: Optional[str] = None, + action: Optional[str] = None, + user_id: Optional[str] = None, + read_only: bool = False, + **extra: str, + ) -> None: + """ + Set session-level auditing / tracing metadata on the current connection. + + This stores name-value pairs in the SQL Server session context via + ``sp_set_session_context``, making them visible to: + + * ``SESSION_CONTEXT()`` in T-SQL queries, triggers, and stored procedures + * Extended Events sessions that capture session context + * ``sys.dm_exec_sessions`` (for *application*) + * Audit specifications that reference session context + + All parameters are optional; only the ones provided will be set. + Calling this method again merges new values with previously-set ones; + to clear a key pass an empty string ``""``. + + Args: + application: Logical application name (sets ``application_name``). + module: Module or component name (sets ``module_name``). + action: Current action or operation (sets ``action_name``). + user_id: End-user identifier (sets ``user_id``). + read_only: If ``True``, the keys become read-only for the + remainder of the session — subsequent calls cannot change them. + **extra: Arbitrary additional key-value pairs to store in the + session context. + + Raises: + InterfaceError: If the connection is closed. + ProgrammingError: If a key or value exceeds length limits or + contains invalid characters. + DatabaseError: If ``sp_set_session_context`` execution fails. + + Example:: + + conn.set_audit_context( + application="BillingAPI", + module="InvoiceProcessor", + action="GenerateInvoice", + user_id="123", + ) + # Values are now readable in T-SQL: + # SELECT SESSION_CONTEXT(N'application_name') + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot set audit context on a closed connection", + ) + + # Build the mapping of keys to set + pairs: Dict[str, str] = {} + if application is not None: + pairs["application_name"] = application + if module is not None: + pairs["module_name"] = module + if action is not None: + pairs["action_name"] = action + if user_id is not None: + pairs["user_id"] = user_id + for key, value in extra.items(): + pairs[key] = value + + if not pairs: + return # nothing to do + + # Validate lengths + for key, value in pairs.items(): + if not isinstance(key, str) or not key: + raise ProgrammingError( + driver_error="Invalid audit context key", + ddbc_error="Session context key must be a non-empty string", + ) + if len(key) > self._AUDIT_KEY_MAX_LEN: + raise ProgrammingError( + driver_error="Audit context key too long", + ddbc_error=( + f"Session context key exceeds {self._AUDIT_KEY_MAX_LEN} characters" + ), + ) + if not isinstance(value, str): + raise ProgrammingError( + driver_error="Invalid audit context value", + ddbc_error="Session context values must be strings", + ) + if len(value) > self._AUDIT_VALUE_MAX_LEN: + raise ProgrammingError( + driver_error="Audit context value too long", + ddbc_error=( + f"Session context value exceeds {self._AUDIT_VALUE_MAX_LEN} characters" + ), + ) + + # Initialize local cache if first call + if not hasattr(self, "_audit_context"): + self._audit_context: Dict[str, str] = {} + + # Execute sp_set_session_context for each pair using parameterized queries + cursor = self.cursor() + try: + for key, value in pairs.items(): + # Empty string means "clear"; sp_set_session_context requires NULL + sql_value = None if value == "" else value + if read_only: + cursor.execute( + "EXEC sp_set_session_context @key=?, @value=?, @read_only=1", + key, + sql_value, + ) + else: + cursor.execute( + "EXEC sp_set_session_context @key=?, @value=?", + key, + sql_value, + ) + if value == "": + self._audit_context.pop(key, None) + else: + self._audit_context[key] = value + logger.debug("Set session context: %s", sanitize_user_input(key)) + finally: + cursor.close() + + logger.info( + "Audit context set with %d key(s): %s", + len(pairs), + ", ".join(sanitize_user_input(k) for k in pairs), + ) + + def get_audit_context(self) -> Dict[str, str]: + """ + Return a copy of the session audit context previously set via + :meth:`set_audit_context`. + + This returns the *locally cached* values — it does not round-trip to + the server. To verify server-side values, query + ``SESSION_CONTEXT(N'')`` directly. + + Returns: + dict: A ``{key: value}`` mapping of the current session context. + + Raises: + InterfaceError: If the connection is closed. + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot get audit context on a closed connection", + ) + if not hasattr(self, "_audit_context"): + return {} + return dict(self._audit_context) + def execute(self, sql: str, *args: Any) -> Cursor: """ diff --git a/tests/test_023_audit_context.py b/tests/test_023_audit_context.py new file mode 100644 index 00000000..39c91c91 --- /dev/null +++ b/tests/test_023_audit_context.py @@ -0,0 +1,144 @@ +""" +Tests for the session metadata / auditing API (set_audit_context / get_audit_context). + +Functions: +- test_set_and_get_audit_context: Set named fields and verify local cache. +- test_audit_context_server_roundtrip: Verify values are readable via SESSION_CONTEXT(). +- test_audit_context_extra_keys: Test arbitrary extra key-value pairs. +- test_audit_context_merge: Successive calls merge, not replace. +- test_audit_context_empty_call: Calling with no arguments is a no-op. +- test_audit_context_clear_value: Setting a key to "" clears it server-side. +- test_audit_context_read_only: read_only=True prevents subsequent changes. +- test_audit_context_closed_connection: Raises InterfaceError when connection is closed. +- test_audit_context_key_too_long: Raises ProgrammingError for oversized keys. +- test_audit_context_value_too_long: Raises ProgrammingError for oversized values. +- test_audit_context_non_string_value: Raises ProgrammingError for non-string values. +""" + +import pytest +from mssql_python import connect +from mssql_python.exceptions import InterfaceError, ProgrammingError, DatabaseError + + +@pytest.fixture() +def audit_conn(conn_str): + """Dedicated connection for audit context tests (module-scoped fixtures + would share session state, so we create a fresh connection per test).""" + conn = connect(conn_str) + yield conn + conn.close() + + +class TestAuditContext: + """Tests for Connection.set_audit_context / get_audit_context.""" + + def test_set_and_get_audit_context(self, audit_conn): + """Named fields are reflected in the local cache.""" + audit_conn.set_audit_context( + application="BillingAPI", + module="InvoiceProcessor", + action="GenerateInvoice", + user_id="123", + ) + ctx = audit_conn.get_audit_context() + assert ctx["application_name"] == "BillingAPI" + assert ctx["module_name"] == "InvoiceProcessor" + assert ctx["action_name"] == "GenerateInvoice" + assert ctx["user_id"] == "123" + + def test_audit_context_server_roundtrip(self, audit_conn): + """Values set via set_audit_context are readable with SESSION_CONTEXT().""" + audit_conn.set_audit_context(application="RoundTrip", user_id="42") + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'application_name')") + row = cursor.fetchone() + assert row[0] == "RoundTrip" + + cursor.execute("SELECT SESSION_CONTEXT(N'user_id')") + row = cursor.fetchone() + assert row[0] == "42" + finally: + cursor.close() + + def test_audit_context_extra_keys(self, audit_conn): + """Arbitrary extra keys are stored via sp_set_session_context.""" + audit_conn.set_audit_context(tenant_id="ACME", correlation_id="abc-def") + ctx = audit_conn.get_audit_context() + assert ctx["tenant_id"] == "ACME" + assert ctx["correlation_id"] == "abc-def" + + # Verify server-side + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'tenant_id')") + assert cursor.fetchone()[0] == "ACME" + finally: + cursor.close() + + def test_audit_context_merge(self, audit_conn): + """Successive calls merge values, not replace.""" + audit_conn.set_audit_context(application="App1") + audit_conn.set_audit_context(module="Mod1") + ctx = audit_conn.get_audit_context() + assert ctx["application_name"] == "App1" + assert ctx["module_name"] == "Mod1" + + def test_audit_context_overwrite(self, audit_conn): + """A second call with the same key overwrites the previous value.""" + audit_conn.set_audit_context(action="First") + audit_conn.set_audit_context(action="Second") + assert audit_conn.get_audit_context()["action_name"] == "Second" + + def test_audit_context_empty_call(self, audit_conn): + """Calling with no arguments is a silent no-op.""" + audit_conn.set_audit_context() + assert audit_conn.get_audit_context() == {} + + def test_audit_context_clear_value(self, audit_conn): + """Setting a key to '' clears it (sends NULL to the server).""" + audit_conn.set_audit_context(user_id="99") + audit_conn.set_audit_context(user_id="") + assert "user_id" not in audit_conn.get_audit_context() + + def test_audit_context_read_only(self, audit_conn): + """read_only=True makes the key immutable for the session.""" + audit_conn.set_audit_context(action="Locked", read_only=True) + # Attempting to change a read-only key should raise a DatabaseError + # from SQL Server (error 15664). + with pytest.raises(DatabaseError): + audit_conn.set_audit_context(action="Changed") + + def test_audit_context_closed_connection_set(self, audit_conn): + """set_audit_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.set_audit_context(application="X") + + def test_audit_context_closed_connection_get(self, audit_conn): + """get_audit_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.get_audit_context() + + def test_audit_context_key_too_long(self, audit_conn): + """Keys longer than 128 characters are rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_audit_context(**{"x" * 200: "v"}) + + def test_audit_context_value_too_long(self, audit_conn): + """Values longer than 8000 characters are rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_audit_context(user_id="v" * 8001) + + def test_audit_context_non_string_value(self, audit_conn): + """Non-string values are rejected with ProgrammingError.""" + with pytest.raises((ProgrammingError, TypeError)): + audit_conn.set_audit_context(user_id=123) # type: ignore[arg-type] + + def test_get_audit_context_returns_copy(self, audit_conn): + """get_audit_context returns a copy, not the internal dict.""" + audit_conn.set_audit_context(application="Copy") + ctx = audit_conn.get_audit_context() + ctx["application_name"] = "Mutated" + assert audit_conn.get_audit_context()["application_name"] == "Copy" From 1d68115a8cbc35aed8ffcc19155b0f370129d755 Mon Sep 17 00:00:00 2001 From: Kapil Samant Date: Wed, 3 Jun 2026 22:42:10 +0530 Subject: [PATCH 2/3] FEAT: Set session level auditing - fix tests --- mssql_python/connection.py | 9 +++------ tests/test_023_audit_context.py | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index abf85b22..4578cc73 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1113,13 +1113,12 @@ def clear_output_converters(self) -> None: if hasattr(self._conn, "clear_output_converters"): self._conn.clear_output_converters() logger.info("Cleared all output converters") - - + # ---- Session Metadata / Auditing API ---- # Maximum length for session context keys and values to prevent abuse. _AUDIT_KEY_MAX_LEN: int = 128 - _AUDIT_VALUE_MAX_LEN: int = 8000 # SQL Server sp_set_session_context limit + _AUDIT_VALUE_MAX_LEN: int = 4000 def set_audit_context( self, @@ -1158,8 +1157,7 @@ def set_audit_context( Raises: InterfaceError: If the connection is closed. - ProgrammingError: If a key or value exceeds length limits or - contains invalid characters. + ProgrammingError: If a key or value exceeds length limits DatabaseError: If ``sp_set_session_context`` execution fails. Example:: @@ -1281,7 +1279,6 @@ def get_audit_context(self) -> Dict[str, str]: if not hasattr(self, "_audit_context"): return {} return dict(self._audit_context) - def execute(self, sql: str, *args: Any) -> Cursor: """ diff --git a/tests/test_023_audit_context.py b/tests/test_023_audit_context.py index 39c91c91..8c1c723d 100644 --- a/tests/test_023_audit_context.py +++ b/tests/test_023_audit_context.py @@ -127,13 +127,13 @@ def test_audit_context_key_too_long(self, audit_conn): audit_conn.set_audit_context(**{"x" * 200: "v"}) def test_audit_context_value_too_long(self, audit_conn): - """Values longer than 8000 characters are rejected.""" + """Values longer than 4000 characters are rejected.""" with pytest.raises(ProgrammingError): - audit_conn.set_audit_context(user_id="v" * 8001) + audit_conn.set_audit_context(user_id="v" * 4001) def test_audit_context_non_string_value(self, audit_conn): """Non-string values are rejected with ProgrammingError.""" - with pytest.raises((ProgrammingError, TypeError)): + with pytest.raises(ProgrammingError): audit_conn.set_audit_context(user_id=123) # type: ignore[arg-type] def test_get_audit_context_returns_copy(self, audit_conn): From a7123e8d210cca33d6ef6cd18a48bf61e50a912b Mon Sep 17 00:00:00 2001 From: Kapil Samant Date: Fri, 5 Jun 2026 18:34:52 +0530 Subject: [PATCH 3/3] FEAT: Set session level auditing - apply review comments --- mssql_python/connection.py | 302 ++++++++++++++++++++++---------- tests/test_023_audit_context.py | 255 ++++++++++++++++++++------- 2 files changed, 399 insertions(+), 158 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 4578cc73..01730c4b 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1114,171 +1114,270 @@ def clear_output_converters(self) -> None: self._conn.clear_output_converters() logger.info("Cleared all output converters") - # ---- Session Metadata / Auditing API ---- + # ---- Session Context API ---- - # Maximum length for session context keys and values to prevent abuse. - _AUDIT_KEY_MAX_LEN: int = 128 - _AUDIT_VALUE_MAX_LEN: int = 4000 + _SESSION_KEY_MAX_LEN: int = 128 + _SESSION_VALUE_MAX_LEN: int = 4000 - def set_audit_context( + def set_session_context( self, *, - application: Optional[str] = None, - module: Optional[str] = None, - action: Optional[str] = None, - user_id: Optional[str] = None, read_only: bool = False, - **extra: str, + **context: Optional[str], ) -> None: """ - Set session-level auditing / tracing metadata on the current connection. + Set session-level metadata on the current connection. This stores name-value pairs in the SQL Server session context via ``sp_set_session_context``, making them visible to: * ``SESSION_CONTEXT()`` in T-SQL queries, triggers, and stored procedures * Extended Events sessions that capture session context - * ``sys.dm_exec_sessions`` (for *application*) + * ``sys.dm_exec_sessions`` (for *application_name*) * Audit specifications that reference session context - All parameters are optional; only the ones provided will be set. - Calling this method again merges new values with previously-set ones; - to clear a key pass an empty string ``""``. + Only keys that are passed will be set. Calling this method again + merges new values with previously-set ones; to clear a key pass + ``None`` as its value. + + Well-known keys (optional, not enforced): + ``application_name``, ``module_name``, ``action_name``, ``user_id`` Args: - application: Logical application name (sets ``application_name``). - module: Module or component name (sets ``module_name``). - action: Current action or operation (sets ``action_name``). - user_id: End-user identifier (sets ``user_id``). read_only: If ``True``, the keys become read-only for the remainder of the session — subsequent calls cannot change them. - **extra: Arbitrary additional key-value pairs to store in the - session context. + **context: Key-value pairs to store in the session context. + Pass ``None`` as a value to clear a key. Raises: InterfaceError: If the connection is closed. - ProgrammingError: If a key or value exceeds length limits + ProgrammingError: If a key or value exceeds length limits. DatabaseError: If ``sp_set_session_context`` execution fails. Example:: - conn.set_audit_context( - application="BillingAPI", - module="InvoiceProcessor", - action="GenerateInvoice", + conn.set_session_context( + application_name="BillingAPI", + module_name="InvoiceProcessor", + action_name="GenerateInvoice", user_id="123", ) # Values are now readable in T-SQL: # SELECT SESSION_CONTEXT(N'application_name') + + # Clear a key: + conn.set_session_context(user_id=None) """ if self._closed: raise InterfaceError( driver_error="Connection is closed", - ddbc_error="Cannot set audit context on a closed connection", + ddbc_error="Cannot set session context on a closed connection", ) - # Build the mapping of keys to set - pairs: Dict[str, str] = {} - if application is not None: - pairs["application_name"] = application - if module is not None: - pairs["module_name"] = module - if action is not None: - pairs["action_name"] = action - if user_id is not None: - pairs["user_id"] = user_id - for key, value in extra.items(): - pairs[key] = value - - if not pairs: + if not context: return # nothing to do # Validate lengths - for key, value in pairs.items(): + for key, value in context.items(): if not isinstance(key, str) or not key: raise ProgrammingError( - driver_error="Invalid audit context key", + driver_error="Invalid session context key", ddbc_error="Session context key must be a non-empty string", ) - if len(key) > self._AUDIT_KEY_MAX_LEN: + if len(key) > self._SESSION_KEY_MAX_LEN: raise ProgrammingError( - driver_error="Audit context key too long", + driver_error="Session context key too long", ddbc_error=( - f"Session context key exceeds {self._AUDIT_KEY_MAX_LEN} characters" + f"Session context key exceeds {self._SESSION_KEY_MAX_LEN} characters" ), ) - if not isinstance(value, str): - raise ProgrammingError( - driver_error="Invalid audit context value", - ddbc_error="Session context values must be strings", - ) - if len(value) > self._AUDIT_VALUE_MAX_LEN: + if value is not None: + if not isinstance(value, str): + raise ProgrammingError( + driver_error="Invalid session context value", + ddbc_error="Session context values must be strings or None", + ) + if len(value) > self._SESSION_VALUE_MAX_LEN: + raise ProgrammingError( + driver_error="Session context value too long", + ddbc_error=( + f"Session context value exceeds {self._SESSION_VALUE_MAX_LEN} characters" + ), + ) + + # Initialize local cache if first call + if not hasattr(self, "_session_context"): + self._session_context: Dict[str, str] = {} + if not hasattr(self, "_session_context_read_only_keys"): + self._session_context_read_only_keys: set = set() + + # Reject attempts to clear read-only keys via value=None + for key, value in context.items(): + if value is None and key in self._session_context_read_only_keys: raise ProgrammingError( - driver_error="Audit context value too long", + driver_error="Cannot clear read-only session context key", ddbc_error=( - f"Session context value exceeds {self._AUDIT_VALUE_MAX_LEN} characters" + f"Session context key '{key}' was set with read_only=True " + "and cannot be cleared" ), ) - # Initialize local cache if first call - if not hasattr(self, "_audit_context"): - self._audit_context: Dict[str, str] = {} + # Build a single batch of sp_set_session_context calls to execute + # in one round trip instead of N separate calls. + batch_parts: list[str] = [] + batch_params: list = [] + for key, value in context.items(): + if read_only: + batch_parts.append( + "EXEC sp_set_session_context @key=?, @value=?, @read_only=1" + ) + else: + batch_parts.append( + "EXEC sp_set_session_context @key=?, @value=?" + ) + batch_params.append(key) + batch_params.append(value) - # Execute sp_set_session_context for each pair using parameterized queries cursor = self.cursor() try: - for key, value in pairs.items(): - # Empty string means "clear"; sp_set_session_context requires NULL - sql_value = None if value == "" else value - if read_only: - cursor.execute( - "EXEC sp_set_session_context @key=?, @value=?, @read_only=1", - key, - sql_value, - ) - else: - cursor.execute( - "EXEC sp_set_session_context @key=?, @value=?", - key, - sql_value, - ) - if value == "": - self._audit_context.pop(key, None) - else: - self._audit_context[key] = value - logger.debug("Set session context: %s", sanitize_user_input(key)) + cursor.execute("; ".join(batch_parts), *batch_params) finally: cursor.close() + # Update local cache after successful execution + for key, value in context.items(): + if value is None: + self._session_context.pop(key, None) + else: + self._session_context[key] = value + if read_only: + self._session_context_read_only_keys.add(key) + logger.debug("Set session context: %s", sanitize_user_input(key)) + logger.info( - "Audit context set with %d key(s): %s", - len(pairs), - ", ".join(sanitize_user_input(k) for k in pairs), + "Session context set with %d key(s): %s", + len(context), + ", ".join(sanitize_user_input(k) for k in context), ) - def get_audit_context(self) -> Dict[str, str]: + def get_session_context(self) -> Dict[str, str]: """ - Return a copy of the session audit context previously set via - :meth:`set_audit_context`. + Return the current session context for keys previously set via + :meth:`set_session_context`. - This returns the *locally cached* values — it does not round-trip to - the server. To verify server-side values, query - ``SESSION_CONTEXT(N'')`` directly. + For each known key the driver queries the server with + ``SELECT SESSION_CONTEXT(N'')``, so the returned values + reflect any server-side mutations (e.g. by triggers or stored + procedures) that occurred after the initial ``set_session_context`` + call. Returns: - dict: A ``{key: value}`` mapping of the current session context. + dict: A ``{key: value}`` mapping. Keys whose server-side + value is ``NULL`` are omitted. Raises: InterfaceError: If the connection is closed. + DatabaseError: If the server query fails. """ if self._closed: raise InterfaceError( driver_error="Connection is closed", - ddbc_error="Cannot get audit context on a closed connection", + ddbc_error="Cannot get session context on a closed connection", ) - if not hasattr(self, "_audit_context"): + if not hasattr(self, "_session_context") or not self._session_context: return {} - return dict(self._audit_context) + + # Query each known key from the server in one batch using + # parameterised queries to avoid SQL injection. + # SESSION_CONTEXT() requires nvarchar, so we CAST the parameter. + keys = list(self._session_context.keys()) + select_parts = ["SELECT SESSION_CONTEXT(CAST(? AS nvarchar(128)))" for _ in keys] + cursor = self.cursor() + try: + cursor.execute("; ".join(select_parts), *keys) + result: Dict[str, str] = {} + for key in keys: + row = cursor.fetchone() + if row is not None and row[0] is not None: + result[key] = str(row[0]) + # Advance to next result set for the next SELECT. + if not cursor.nextset(): + break + return result + finally: + cursor.close() + + def clear_session_context(self, *keys: str) -> None: + """ + Clear one or more session context keys by sending ``NULL`` to the server. + + If no keys are provided, all non-read-only keys that were previously + set via :meth:`set_session_context` are cleared. + + Args: + *keys: Key names to clear. If omitted, all clearable keys are + cleared. + + Raises: + InterfaceError: If the connection is closed. + ProgrammingError: If a specified key was set with ``read_only=True``. + DatabaseError: If ``sp_set_session_context`` execution fails. + + Example:: + + conn.clear_session_context("user_id") # clear one key + conn.clear_session_context("user_id", "action_name") # clear several + conn.clear_session_context() # clear all + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot clear session context on a closed connection", + ) + if not hasattr(self, "_session_context") or not self._session_context: + return + + read_only_keys: set = getattr(self, "_session_context_read_only_keys", set()) + + if keys: + # Validate that none of the requested keys are read-only. + read_only_requested = read_only_keys & set(keys) + if read_only_requested: + raise ProgrammingError( + driver_error="Cannot clear read-only session context keys", + ddbc_error=( + "The following keys are read-only and cannot be cleared: " + + ", ".join(sorted(read_only_requested)) + ), + ) + to_clear = [k for k in keys if k in self._session_context] + else: + to_clear = [k for k in self._session_context if k not in read_only_keys] + + if not to_clear: + return + + # Build a batch to NULL-out every key in one round trip. + batch_parts: list[str] = [] + batch_params: list = [] + for key in to_clear: + batch_parts.append("EXEC sp_set_session_context @key=?, @value=NULL") + batch_params.append(key) + + cursor = self.cursor() + try: + cursor.execute("; ".join(batch_parts), *batch_params) + finally: + cursor.close() + + for key in to_clear: + self._session_context.pop(key, None) + + logger.info( + "Cleared %d session context key(s): %s", + len(to_clear), + ", ".join(sanitize_user_input(k) for k in to_clear), + ) def execute(self, sql: str, *args: Any) -> Cursor: """ @@ -1802,6 +1901,25 @@ def close(self) -> None: # references self._cursors.clear() + # If pooling is enabled, clear session context so the next consumer + # of this ODBC handle does not inherit leftover metadata. + if self._pooling and hasattr(self, "_session_context") and self._session_context: + read_only_keys = getattr(self, "_session_context_read_only_keys", set()) + if read_only_keys & set(self._session_context): + logger.warning( + "Pooled connection has read-only session context keys that " + "cannot be cleared: %s. These values will persist for the " + "next consumer of this connection.", + ", ".join(sorted(read_only_keys & set(self._session_context))), + ) + try: + self.clear_session_context() + except Exception: + logger.warning( + "Failed to clear session context before pool return", + exc_info=True, + ) + # Close the connection even if cursor cleanup had issues try: if self._conn: @@ -1887,4 +2005,4 @@ def __del__(self) -> None: self.close() except Exception as e: # Dont raise exceptions from __del__ to avoid issues during garbage collection - logger.warning(f"Error during connection cleanup: {e}") + logger.warning(f"Error during connection cleanup: {e}") \ No newline at end of file diff --git a/tests/test_023_audit_context.py b/tests/test_023_audit_context.py index 8c1c723d..c5c6ac80 100644 --- a/tests/test_023_audit_context.py +++ b/tests/test_023_audit_context.py @@ -1,18 +1,27 @@ """ -Tests for the session metadata / auditing API (set_audit_context / get_audit_context). +Tests for the session context API +(set_session_context / get_session_context / clear_session_context). Functions: -- test_set_and_get_audit_context: Set named fields and verify local cache. -- test_audit_context_server_roundtrip: Verify values are readable via SESSION_CONTEXT(). -- test_audit_context_extra_keys: Test arbitrary extra key-value pairs. -- test_audit_context_merge: Successive calls merge, not replace. -- test_audit_context_empty_call: Calling with no arguments is a no-op. -- test_audit_context_clear_value: Setting a key to "" clears it server-side. -- test_audit_context_read_only: read_only=True prevents subsequent changes. -- test_audit_context_closed_connection: Raises InterfaceError when connection is closed. -- test_audit_context_key_too_long: Raises ProgrammingError for oversized keys. -- test_audit_context_value_too_long: Raises ProgrammingError for oversized values. -- test_audit_context_non_string_value: Raises ProgrammingError for non-string values. +- test_set_and_get_session_context: Set named fields and verify via server query. +- test_session_context_server_roundtrip: Verify values are readable via SESSION_CONTEXT(). +- test_session_context_extra_keys: Test arbitrary extra key-value pairs. +- test_session_context_merge: Successive calls merge, not replace. +- test_session_context_empty_call: Calling with no arguments is a no-op. +- test_session_context_clear_value: Setting a key to None clears it server-side. +- test_session_context_read_only: read_only=True prevents subsequent changes. +- test_session_context_closed_connection: Raises InterfaceError when connection is closed. +- test_session_context_key_too_long: Raises ProgrammingError for oversized keys. +- test_session_context_value_too_long: Raises ProgrammingError for oversized values. +- test_session_context_non_string_value: Raises ProgrammingError for non-string values. +- test_clear_session_context_single_key: Clear one key. +- test_clear_session_context_multiple_keys: Clear several keys. +- test_clear_session_context_all: Clear all keys. +- test_clear_session_context_read_only_raises: Clearing a read-only key raises ProgrammingError. +- test_clear_session_context_closed: Raises InterfaceError on closed connection. +- test_clear_session_context_noop: No-op when nothing has been set. +- test_get_session_context_reflects_server: Getter fetches live server values. +- test_pool_return_clears_context: Session context is cleared when pooled connection is closed. """ import pytest @@ -22,33 +31,33 @@ @pytest.fixture() def audit_conn(conn_str): - """Dedicated connection for audit context tests (module-scoped fixtures + """Dedicated connection for session context tests (module-scoped fixtures would share session state, so we create a fresh connection per test).""" conn = connect(conn_str) yield conn conn.close() -class TestAuditContext: - """Tests for Connection.set_audit_context / get_audit_context.""" +class TestSessionContext: + """Tests for Connection.set_session_context / get_session_context.""" - def test_set_and_get_audit_context(self, audit_conn): + def test_set_and_get_session_context(self, audit_conn): """Named fields are reflected in the local cache.""" - audit_conn.set_audit_context( - application="BillingAPI", - module="InvoiceProcessor", - action="GenerateInvoice", + audit_conn.set_session_context( + application_name="BillingAPI", + module_name="InvoiceProcessor", + action_name="GenerateInvoice", user_id="123", ) - ctx = audit_conn.get_audit_context() + ctx = audit_conn.get_session_context() assert ctx["application_name"] == "BillingAPI" assert ctx["module_name"] == "InvoiceProcessor" assert ctx["action_name"] == "GenerateInvoice" assert ctx["user_id"] == "123" - def test_audit_context_server_roundtrip(self, audit_conn): - """Values set via set_audit_context are readable with SESSION_CONTEXT().""" - audit_conn.set_audit_context(application="RoundTrip", user_id="42") + def test_session_context_server_roundtrip(self, audit_conn): + """Values set via set_session_context are readable with SESSION_CONTEXT().""" + audit_conn.set_session_context(application_name="RoundTrip", user_id="42") cursor = audit_conn.cursor() try: cursor.execute("SELECT SESSION_CONTEXT(N'application_name')") @@ -61,10 +70,10 @@ def test_audit_context_server_roundtrip(self, audit_conn): finally: cursor.close() - def test_audit_context_extra_keys(self, audit_conn): + def test_session_context_extra_keys(self, audit_conn): """Arbitrary extra keys are stored via sp_set_session_context.""" - audit_conn.set_audit_context(tenant_id="ACME", correlation_id="abc-def") - ctx = audit_conn.get_audit_context() + audit_conn.set_session_context(tenant_id="ACME", correlation_id="abc-def") + ctx = audit_conn.get_session_context() assert ctx["tenant_id"] == "ACME" assert ctx["correlation_id"] == "abc-def" @@ -76,69 +85,183 @@ def test_audit_context_extra_keys(self, audit_conn): finally: cursor.close() - def test_audit_context_merge(self, audit_conn): + def test_session_context_merge(self, audit_conn): """Successive calls merge values, not replace.""" - audit_conn.set_audit_context(application="App1") - audit_conn.set_audit_context(module="Mod1") - ctx = audit_conn.get_audit_context() + audit_conn.set_session_context(application_name="App1") + audit_conn.set_session_context(module_name="Mod1") + ctx = audit_conn.get_session_context() assert ctx["application_name"] == "App1" assert ctx["module_name"] == "Mod1" - def test_audit_context_overwrite(self, audit_conn): + def test_session_context_overwrite(self, audit_conn): """A second call with the same key overwrites the previous value.""" - audit_conn.set_audit_context(action="First") - audit_conn.set_audit_context(action="Second") - assert audit_conn.get_audit_context()["action_name"] == "Second" + audit_conn.set_session_context(action_name="First") + audit_conn.set_session_context(action_name="Second") + assert audit_conn.get_session_context()["action_name"] == "Second" - def test_audit_context_empty_call(self, audit_conn): + def test_session_context_empty_call(self, audit_conn): """Calling with no arguments is a silent no-op.""" - audit_conn.set_audit_context() - assert audit_conn.get_audit_context() == {} + audit_conn.set_session_context() + assert audit_conn.get_session_context() == {} - def test_audit_context_clear_value(self, audit_conn): - """Setting a key to '' clears it (sends NULL to the server).""" - audit_conn.set_audit_context(user_id="99") - audit_conn.set_audit_context(user_id="") - assert "user_id" not in audit_conn.get_audit_context() + def test_session_context_clear_value(self, audit_conn): + """Setting a key to None clears it (sends NULL to the server).""" + audit_conn.set_session_context(user_id="99") + audit_conn.set_session_context(user_id=None) + assert "user_id" not in audit_conn.get_session_context() - def test_audit_context_read_only(self, audit_conn): + def test_session_context_read_only(self, audit_conn): """read_only=True makes the key immutable for the session.""" - audit_conn.set_audit_context(action="Locked", read_only=True) + audit_conn.set_session_context(action_name="Locked", read_only=True) # Attempting to change a read-only key should raise a DatabaseError # from SQL Server (error 15664). with pytest.raises(DatabaseError): - audit_conn.set_audit_context(action="Changed") + audit_conn.set_session_context(action_name="Changed") - def test_audit_context_closed_connection_set(self, audit_conn): - """set_audit_context raises InterfaceError on a closed connection.""" + def test_session_context_read_only_clear_via_none(self, audit_conn): + """Setting a read-only key to None raises ProgrammingError.""" + audit_conn.set_session_context(action_name="Locked", read_only=True) + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(action_name=None) + + def test_session_context_closed_connection_set(self, audit_conn): + """set_session_context raises InterfaceError on a closed connection.""" audit_conn.close() with pytest.raises(InterfaceError): - audit_conn.set_audit_context(application="X") + audit_conn.set_session_context(application_name="X") - def test_audit_context_closed_connection_get(self, audit_conn): - """get_audit_context raises InterfaceError on a closed connection.""" + def test_session_context_closed_connection_get(self, audit_conn): + """get_session_context raises InterfaceError on a closed connection.""" audit_conn.close() with pytest.raises(InterfaceError): - audit_conn.get_audit_context() + audit_conn.get_session_context() - def test_audit_context_key_too_long(self, audit_conn): + def test_session_context_key_max_length(self, audit_conn): + """A key at exactly 128 characters is accepted.""" + key = "k" * 128 + audit_conn.set_session_context(**{key: "val"}) + ctx = audit_conn.get_session_context() + assert ctx[key] == "val" + + def test_session_context_key_one_over_max(self, audit_conn): + """A key at 129 characters is rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(**{"k" * 129: "v"}) + + def test_session_context_key_too_long(self, audit_conn): """Keys longer than 128 characters are rejected.""" with pytest.raises(ProgrammingError): - audit_conn.set_audit_context(**{"x" * 200: "v"}) + audit_conn.set_session_context(**{"x" * 200: "v"}) + + def test_session_context_value_max_length(self, audit_conn): + """A value at exactly 4000 characters is accepted.""" + val = "v" * 4000 + audit_conn.set_session_context(user_id=val) + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == val - def test_audit_context_value_too_long(self, audit_conn): - """Values longer than 4000 characters are rejected.""" + def test_session_context_value_one_over_max(self, audit_conn): + """A value at 4001 characters is rejected.""" with pytest.raises(ProgrammingError): - audit_conn.set_audit_context(user_id="v" * 4001) + audit_conn.set_session_context(user_id="v" * 4001) - def test_audit_context_non_string_value(self, audit_conn): + def test_session_context_non_string_value(self, audit_conn): """Non-string values are rejected with ProgrammingError.""" with pytest.raises(ProgrammingError): - audit_conn.set_audit_context(user_id=123) # type: ignore[arg-type] - - def test_get_audit_context_returns_copy(self, audit_conn): - """get_audit_context returns a copy, not the internal dict.""" - audit_conn.set_audit_context(application="Copy") - ctx = audit_conn.get_audit_context() - ctx["application_name"] = "Mutated" - assert audit_conn.get_audit_context()["application_name"] == "Copy" + audit_conn.set_session_context(user_id=123) # type: ignore[arg-type] + + def test_get_session_context_returns_fresh(self, audit_conn): + """get_session_context queries the server and returns a fresh dict each time.""" + audit_conn.set_session_context(application_name="Fresh") + ctx1 = audit_conn.get_session_context() + ctx2 = audit_conn.get_session_context() + assert ctx1 == ctx2 + assert ctx1 is not ctx2 # distinct dict objects + + def test_get_session_context_reflects_server(self, audit_conn): + """Getter fetches live values from the server, not stale cache.""" + audit_conn.set_session_context(user_id="original") + # Mutate the value directly via T-SQL (bypassing the Python API) + cursor = audit_conn.cursor() + try: + cursor.execute( + "EXEC sp_set_session_context @key=N'user_id', @value=N'mutated'" + ) + finally: + cursor.close() + # The getter should reflect the server-side mutation + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == "mutated" + + # ---- clear_session_context tests ---- + + def test_clear_session_context_single_key(self, audit_conn): + """Clearing a single key removes it from the server.""" + audit_conn.set_session_context(user_id="1", module_name="Mod") + audit_conn.clear_session_context("user_id") + ctx = audit_conn.get_session_context() + assert "user_id" not in ctx + assert ctx["module_name"] == "Mod" + + def test_clear_session_context_multiple_keys(self, audit_conn): + """Clearing multiple keys removes them all.""" + audit_conn.set_session_context( + user_id="1", module_name="Mod", action_name="Act" + ) + audit_conn.clear_session_context("user_id", "action_name") + ctx = audit_conn.get_session_context() + assert "user_id" not in ctx + assert "action_name" not in ctx + assert ctx["module_name"] == "Mod" + + def test_clear_session_context_all(self, audit_conn): + """Calling with no args clears all non-read-only keys.""" + audit_conn.set_session_context(user_id="1", module_name="Mod") + audit_conn.clear_session_context() + assert audit_conn.get_session_context() == {} + + def test_clear_session_context_read_only_raises(self, audit_conn): + """Explicitly clearing a read-only key raises ProgrammingError.""" + audit_conn.set_session_context(user_id="locked", read_only=True) + with pytest.raises(ProgrammingError): + audit_conn.clear_session_context("user_id") + + def test_clear_session_context_all_skips_read_only(self, audit_conn): + """clear_session_context() without args skips read-only keys.""" + audit_conn.set_session_context(user_id="locked", read_only=True) + audit_conn.set_session_context(module_name="clearable") + audit_conn.clear_session_context() + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == "locked" + assert "module_name" not in ctx + + def test_clear_session_context_closed(self, audit_conn): + """clear_session_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.clear_session_context("user_id") + + def test_clear_session_context_noop(self, audit_conn): + """Clearing when nothing has been set is a silent no-op.""" + audit_conn.clear_session_context() # should not raise + + # ---- Pool return tests ---- + + def test_pool_return_clears_context(self, conn_str): + """When pooling is enabled, close() clears session context server-side.""" + conn = connect(conn_str) + conn.set_session_context(application_name="PoolTest", module_name="Mod") + # Simulate pooling enabled + conn._pooling = True + conn.close() + # After close the cache should be empty + assert not getattr(conn, "_session_context", {}) + + def test_pool_return_skips_without_pooling(self, conn_str): + """Without pooling, close() does not attempt to clear session context.""" + conn = connect(conn_str) + conn.set_session_context(application_name="NoPools") + conn._pooling = False + conn.close() + # Cache is left as-is (object is closed, no pool reuse) + assert conn._session_context.get("application_name") == "NoPools" \ No newline at end of file