Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 285 additions & 1 deletion mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,271 @@ def clear_output_converters(self) -> None:
self._conn.clear_output_converters()
logger.info("Cleared all output converters")

# ---- Session Context API ----

_SESSION_KEY_MAX_LEN: int = 128
_SESSION_VALUE_MAX_LEN: int = 4000

def set_session_context(
self,
*,
read_only: bool = False,
**context: Optional[str],
) -> None:
"""
Set session-level metadata on the current connection.

This stores name-value pairs in the SQL Server session context via
``sp_set_session_context``, making them visible to:

* ``SESSION_CONTEXT()`` in T-SQL queries, triggers, and stored procedures
* Extended Events sessions that capture session context
* ``sys.dm_exec_sessions`` (for *application_name*)
* Audit specifications that reference session context

Only keys that are passed will be set. Calling this method again
merges new values with previously-set ones; to clear a key pass
``None`` as its value.

Well-known keys (optional, not enforced):
``application_name``, ``module_name``, ``action_name``, ``user_id``

Args:
read_only: If ``True``, the keys become read-only for the
remainder of the session — subsequent calls cannot change them.
**context: Key-value pairs to store in the session context.
Pass ``None`` as a value to clear a key.

Raises:
InterfaceError: If the connection is closed.
ProgrammingError: If a key or value exceeds length limits.
DatabaseError: If ``sp_set_session_context`` execution fails.
Comment thread
kapilsamant marked this conversation as resolved.

Example::

conn.set_session_context(
application_name="BillingAPI",
module_name="InvoiceProcessor",
action_name="GenerateInvoice",
user_id="123",
)
# Values are now readable in T-SQL:
# SELECT SESSION_CONTEXT(N'application_name')

# Clear a key:
conn.set_session_context(user_id=None)
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Cannot set session context on a closed connection",
)

if not context:
return # nothing to do

# Validate lengths
for key, value in context.items():
if not isinstance(key, str) or not key:
raise ProgrammingError(
driver_error="Invalid session context key",
ddbc_error="Session context key must be a non-empty string",
)
if len(key) > self._SESSION_KEY_MAX_LEN:
raise ProgrammingError(
driver_error="Session context key too long",
ddbc_error=(
f"Session context key exceeds {self._SESSION_KEY_MAX_LEN} characters"
),
)
if value is not None:
if not isinstance(value, str):
raise ProgrammingError(
driver_error="Invalid session context value",
ddbc_error="Session context values must be strings or None",
)
if len(value) > self._SESSION_VALUE_MAX_LEN:
raise ProgrammingError(
driver_error="Session context value too long",
ddbc_error=(
f"Session context value exceeds {self._SESSION_VALUE_MAX_LEN} characters"
),
)

# Initialize local cache if first call
if not hasattr(self, "_session_context"):
self._session_context: Dict[str, str] = {}
if not hasattr(self, "_session_context_read_only_keys"):
self._session_context_read_only_keys: set = set()

# Reject attempts to clear read-only keys via value=None
for key, value in context.items():
if value is None and key in self._session_context_read_only_keys:
raise ProgrammingError(
driver_error="Cannot clear read-only session context key",
ddbc_error=(
f"Session context key '{key}' was set with read_only=True "
"and cannot be cleared"
),
)

# Build a single batch of sp_set_session_context calls to execute
# in one round trip instead of N separate calls.
batch_parts: list[str] = []
batch_params: list = []
for key, value in context.items():
if read_only:
batch_parts.append(
"EXEC sp_set_session_context @key=?, @value=?, @read_only=1"
)
else:
batch_parts.append(
"EXEC sp_set_session_context @key=?, @value=?"
)
batch_params.append(key)
batch_params.append(value)

cursor = self.cursor()
try:
cursor.execute("; ".join(batch_parts), *batch_params)
finally:
cursor.close()

# Update local cache after successful execution
for key, value in context.items():
if value is None:
self._session_context.pop(key, None)
else:
self._session_context[key] = value
if read_only:
self._session_context_read_only_keys.add(key)
logger.debug("Set session context: %s", sanitize_user_input(key))

logger.info(
"Session context set with %d key(s): %s",
len(context),
", ".join(sanitize_user_input(k) for k in context),
)

def get_session_context(self) -> Dict[str, str]:
"""
Return the current session context for keys previously set via
:meth:`set_session_context`.

For each known key the driver queries the server with
``SELECT SESSION_CONTEXT(N'<key>')``, so the returned values
reflect any server-side mutations (e.g. by triggers or stored
procedures) that occurred after the initial ``set_session_context``
call.

Returns:
dict: A ``{key: value}`` mapping. Keys whose server-side
value is ``NULL`` are omitted.

Raises:
InterfaceError: If the connection is closed.
DatabaseError: If the server query fails.
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Cannot get session context on a closed connection",
)
if not hasattr(self, "_session_context") or not self._session_context:
return {}

# Query each known key from the server in one batch using
# parameterised queries to avoid SQL injection.
# SESSION_CONTEXT() requires nvarchar, so we CAST the parameter.
keys = list(self._session_context.keys())
select_parts = ["SELECT SESSION_CONTEXT(CAST(? AS nvarchar(128)))" for _ in keys]
cursor = self.cursor()
try:
cursor.execute("; ".join(select_parts), *keys)
result: Dict[str, str] = {}
for key in keys:
row = cursor.fetchone()
if row is not None and row[0] is not None:
result[key] = str(row[0])
# Advance to next result set for the next SELECT.
if not cursor.nextset():
break
return result
finally:
cursor.close()

def clear_session_context(self, *keys: str) -> None:
"""
Clear one or more session context keys by sending ``NULL`` to the server.

If no keys are provided, all non-read-only keys that were previously
set via :meth:`set_session_context` are cleared.

Args:
*keys: Key names to clear. If omitted, all clearable keys are
cleared.

Raises:
InterfaceError: If the connection is closed.
ProgrammingError: If a specified key was set with ``read_only=True``.
DatabaseError: If ``sp_set_session_context`` execution fails.

Example::

conn.clear_session_context("user_id") # clear one key
conn.clear_session_context("user_id", "action_name") # clear several
conn.clear_session_context() # clear all
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Cannot clear session context on a closed connection",
)
if not hasattr(self, "_session_context") or not self._session_context:
return

read_only_keys: set = getattr(self, "_session_context_read_only_keys", set())

if keys:
# Validate that none of the requested keys are read-only.
read_only_requested = read_only_keys & set(keys)
if read_only_requested:
raise ProgrammingError(
driver_error="Cannot clear read-only session context keys",
ddbc_error=(
"The following keys are read-only and cannot be cleared: "
+ ", ".join(sorted(read_only_requested))
),
)
to_clear = [k for k in keys if k in self._session_context]
else:
to_clear = [k for k in self._session_context if k not in read_only_keys]

if not to_clear:
return

# Build a batch to NULL-out every key in one round trip.
batch_parts: list[str] = []
batch_params: list = []
for key in to_clear:
batch_parts.append("EXEC sp_set_session_context @key=?, @value=NULL")
batch_params.append(key)

cursor = self.cursor()
try:
cursor.execute("; ".join(batch_parts), *batch_params)
finally:
cursor.close()

for key in to_clear:
self._session_context.pop(key, None)

logger.info(
"Cleared %d session context key(s): %s",
len(to_clear),
", ".join(sanitize_user_input(k) for k in to_clear),
)

def execute(self, sql: str, *args: Any) -> Cursor:
"""
Creates a new Cursor object, calls its execute method, and returns the new cursor.
Expand Down Expand Up @@ -1636,6 +1901,25 @@ def close(self) -> None:
# references
self._cursors.clear()

# If pooling is enabled, clear session context so the next consumer
# of this ODBC handle does not inherit leftover metadata.
if self._pooling and hasattr(self, "_session_context") and self._session_context:
read_only_keys = getattr(self, "_session_context_read_only_keys", set())
if read_only_keys & set(self._session_context):
logger.warning(
"Pooled connection has read-only session context keys that "
"cannot be cleared: %s. These values will persist for the "
"next consumer of this connection.",
", ".join(sorted(read_only_keys & set(self._session_context))),
)
try:
self.clear_session_context()
except Exception:
logger.warning(
"Failed to clear session context before pool return",
exc_info=True,
)

# Close the connection even if cursor cleanup had issues
try:
if self._conn:
Expand Down Expand Up @@ -1721,4 +2005,4 @@ def __del__(self) -> None:
self.close()
except Exception as e:
# Dont raise exceptions from __del__ to avoid issues during garbage collection
logger.warning(f"Error during connection cleanup: {e}")
logger.warning(f"Error during connection cleanup: {e}")
Loading