Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
[project]
name = "jaxborg"
version = "0.1.0"
description = "Add your description here"
description = "JAX reimplementation of the CC4 enterprise cyber-defense environment, plus a parity harness against CybORG."
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"chex",
"cyborg",
"distrax",
"flax",
"jaxmarl",
"mlflow>=3.10.0",
Expand All @@ -16,24 +15,22 @@ dependencies = [
"torch>=2.0",
]

# Pick exactly one — leaving both unspecified pulls jax in only via the
# transitive deps of jaxmarl/flax (CPU jaxlib), which silently masks the
# wrong-backend trap on GPU hosts. CI uses `--extra cpu`; GPU hosts use
# `--extra cuda`.
# Pick exactly one extra. Leaving both unspecified lets jax come in only
# via the transitive deps of jaxmarl/flax (CPU jaxlib), which silently
# masks the wrong-backend trap on GPU hosts. CI uses `--extra cpu`; GPU
# hosts use `--extra cuda`.
#
# `cuda` lists the cuda12 plugin packages directly rather than via
# `jax[cuda12]` because the latter pulls in `jax-cuda12-plugin[with-cuda]`,
# which requires nvidia-cudnn-cu12>=9.1 — collides with cyborg's pinned
# torch==2.2.0 (needs cudnn 8.9.2.26 in the cu12 namespace). Pinning to
# the cu13 namespace avoids the conflict.
# JAX is pinned to >=0.7.1 for the persistent-cache miscompile fix
# (jax-ml/jax#31733). 0.5.x silently returns stale ``.at[].set()`` updates
# inside ``lax.fori_loop`` / ``lax.cond`` on cache hit — corrupts the env
# step. 0.4.38 silently no-ops the cache (#25768). 0.7.1 is the line the
# upstream reporter explicitly confirmed fixed. Reaching it also required
# dropping ``distrax`` (its transitive ``tensorflow-probability<=0.25``
# imports a JAX-0.7-removed symbol); see
# ``src/jaxborg/policies/categorical.py`` for the local replacement.
[project.optional-dependencies]
cpu = ["jax[cpu]==0.4.38"]
cuda = [
"jax-cuda12-pjrt==0.4.38",
"jax-cuda12-plugin==0.4.38",
"nvidia-cublas>=13.2",
"nvidia-cudnn-cu13>=9.19",
]
cpu = ["jax[cpu]>=0.7.1"]
cuda = ["jax[cuda12]>=0.7.1"]

[project.scripts]
export-generated-topology = "jaxborg.scenarios.cc4.topology_cli:export_generated_main"
Expand All @@ -46,9 +43,27 @@ build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/jaxborg"]

# Overrides for upstream pins that lag the rest of the stack. Each line
# answers "what package pins this lower than we can use, and why is the
# bump safe":
# * matplotlib: cyborg pins ==3.8.2. Bumped for newer features and to
# match the rest of the matplotlib ecosystem.
# * jax / jaxlib: jaxmarl pins <=0.4.38 ("until we can test 0.5",
# commit 3ffa5b8f, upstream issue FLAIROx/JaxMARL#175). See the
# [project.optional-dependencies] note above for why we need 0.7+.
# * numpy: cyborg pins ==1.26.4. ml-dtypes (transitive via JAX 0.7+)
# requires numpy>=2.1; cyborg's runtime is numpy-2-compatible
# empirically (fast + slow suites green).
# * scipy: jaxmarl pins <=1.12 (numpy-1-era upper bound).
# * torch: cyborg pins ==2.2.0 (numpy-1-only).
[tool.uv]
override-dependencies = [
"matplotlib>=3.9",
"jax>=0.7.1",
"jaxlib>=0.7.1",
"numpy>=2.1,<2.4",
"scipy>=1.14",
"torch>=2.4",
]

[tool.uv.sources]
Expand All @@ -60,9 +75,6 @@ dev = [
"ruff",
"pytest-xdist",
]
# The previous `[dependency-groups.cuda]` is superseded by the `cuda` extra
# in `[project.optional-dependencies]` above. `jax[cuda12]` brings in the
# matching cuda12 pjrt + plugin so we don't have to pin them ourselves.

[tool.pytest.ini_options]
pythonpath = ["."]
Expand Down
4 changes: 2 additions & 2 deletions scripts/dev/parity/jax_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

import time

import distrax
import jax
import jax.numpy as jnp
import numpy as np

from dataclasses import replace

from jaxborg.actions.masking import compute_blue_action_mask
from jaxborg.policies.categorical import Categorical as JaxCategorical
from jaxborg.constants import COMPROMISE_PRIVILEGED, COMPROMISE_USER, NUM_BLUE_AGENTS, NUM_RED_AGENTS
from jaxborg.evaluation.jax_env_factory import make_jax_env
from jaxborg.scenarios.cc4.game_variants import CC4_STOCK
Expand Down Expand Up @@ -74,7 +74,7 @@ def _env_step(carry, _):
actions_arr = jnp.argmax(logits, axis=-1)
else:
act_keys = jax.random.split(_rng, NUM_BLUE_AGENTS)
actions_arr = jax.vmap(lambda lg, k: distrax.Categorical(logits=lg).sample(seed=k))(logits, act_keys)
actions_arr = jax.vmap(lambda lg, k: JaxCategorical(logits=lg).sample(seed=k))(logits, act_keys)

actions = {f"blue_{i}": actions_arr[i] for i in range(NUM_BLUE_AGENTS)}

Expand Down
5 changes: 3 additions & 2 deletions scripts/dev/parity/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

configure_runtime()

import distrax
import jax
import jax.numpy as jnp

from jaxborg.policies.categorical import Categorical as JaxCategorical


def load_checkpoint(path):
"""Load a recipe-driven JAX checkpoint via jax_runner (sidecar required)."""
Expand Down Expand Up @@ -48,7 +49,7 @@ def batched_step(obs_stack, mask_stack, _keys):
@jax.jit
def batched_step(obs_stack, mask_stack, keys):
logits = jax.vmap(_fwd)(obs_stack, mask_stack)
actions = jax.vmap(lambda lg, k: distrax.Categorical(logits=lg).sample(seed=k))(logits, keys)
actions = jax.vmap(lambda lg, k: JaxCategorical(logits=lg).sample(seed=k))(logits, keys)
return actions, logits

return batched_step
4 changes: 2 additions & 2 deletions scripts/eval/generate_cynex_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def get_action(self, obs, action_mask, deterministic=False):
def _load_jax_model(path: str):
"""Load a JAXborg JAX/Flax checkpoint. Returns (policy, params, policy_kind)."""
# Lazy imports so the script works without JAX when only using --model-pt
import distrax
import jax

from jaxborg.evaluation.jax_runner import load_jax_checkpoint
from jaxborg.policies.categorical import Categorical as JaxCategorical

policy, params, recipe = load_jax_checkpoint(path)
print(f"Loaded JAX checkpoint from {path} (arch={recipe['arch']['name']})")
Expand All @@ -110,7 +110,7 @@ def _fwd(o, m):
@jax.jit
def batched_step(obs_stack, mask_stack, keys):
logits = jax.vmap(_fwd)(obs_stack, mask_stack)
actions = jax.vmap(lambda lg, k: distrax.Categorical(logits=lg).sample(seed=k))(logits, keys)
actions = jax.vmap(lambda lg, k: JaxCategorical(logits=lg).sample(seed=k))(logits, keys)
return actions, logits

return batched_step, params
Expand Down
2 changes: 1 addition & 1 deletion src/jaxborg/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
frameworks (Flax functional vs torch nn.Module) are too different to wrap
in a single concrete type without leaking abstractions. The contract that
*does* unify them is: the JAX module's `__call__(obs, avail_actions)`
returns `(distrax.Categorical, value)`; the torch module's
returns `(policies.categorical.Categorical, value)`; the torch module's
`get_action_and_value(obs, mask, action=None)` returns
`(action, log_prob, entropy, value)`. Algorithm scripts on each backend
already speak their backend's framework — they don't try to be backend-
Expand Down
37 changes: 37 additions & 0 deletions src/jaxborg/policies/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Minimal categorical distribution for JAX policies.

Drop-in replacement for the subset of ``distrax.Categorical`` that jaxborg
uses (``.logits``, ``.sample(seed=)``, ``.log_prob(a)``, ``.entropy()``).
Implemented directly on top of ``jax.random.categorical`` and
``jax.nn.log_softmax`` so we don't drag in ``distrax`` (and through it
``tensorflow-probability``, which pins us below JAX 0.7 because of a
removed-symbol import — see pyproject.toml jax section).
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class Categorical:
"""Categorical distribution over the last axis of ``logits``.

Pytree-compatible (flax struct) so it can be returned from ``nn.Module``
forward passes and threaded through ``jit`` / ``vmap`` / ``scan``.
"""

logits: jax.Array

def sample(self, seed: jax.Array) -> jax.Array:
return jax.random.categorical(seed, self.logits, axis=-1)

def log_prob(self, action: jax.Array) -> jax.Array:
log_p = jax.nn.log_softmax(self.logits, axis=-1)
return jnp.take_along_axis(log_p, action[..., None], axis=-1).squeeze(-1)

def entropy(self) -> jax.Array:
log_p = jax.nn.log_softmax(self.logits, axis=-1)
return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)
4 changes: 2 additions & 2 deletions src/jaxborg/policies/separate_actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from __future__ import annotations

import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
Expand All @@ -19,6 +18,7 @@
from torch.distributions import Categorical

from .base import BUFFER_LAYOUT_FLAT
from .categorical import Categorical as JaxCategorical


class _ActorTrunk(nn.Module):
Expand All @@ -41,7 +41,7 @@ def __call__(self, x, avail_actions=None):
logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(h)
if avail_actions is not None:
logits = logits - (1 - avail_actions) * 1e10
return distrax.Categorical(logits=logits)
return JaxCategorical(logits=logits)


class _CriticTrunk(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions src/jaxborg/policies/shared_actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
Expand All @@ -17,6 +16,7 @@
from torch.distributions import Categorical

from .base import BUFFER_LAYOUT_FLAT
from .categorical import Categorical as JaxCategorical


class _JaxSharedActorCritic(nn.Module):
Expand All @@ -41,7 +41,7 @@ def __call__(self, x, avail_actions=None):
logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(h)
if avail_actions is not None:
logits = logits - (1 - avail_actions) * 1e10
pi = distrax.Categorical(logits=logits)
pi = JaxCategorical(logits=logits)

value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(h)
value = jnp.squeeze(value, axis=-1)
Expand Down
7 changes: 4 additions & 3 deletions tests/differential/test_ppo_update_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def _jax_loss(params, network, obs, actions, mask, old_logp, mb_adv, mb_targets,
def test_forward_pass_parity():
"""torch and flax forward passes on identical params + inputs match.

Compares per-action log-probs and values rather than raw logits because
`distrax.Categorical.logits` returns log-softmax-normalized logits while
torch's logits are raw — the difference would mask real divergence.
Compares per-action log-probs and values rather than raw logits: torch
and flax may emit raw logits that differ by a constant shift along the
action axis (both still produce identical softmax and log_softmax), so
a raw-logits comparison would flag a non-divergence as a divergence.
"""
torch.manual_seed(0)
agent = TinyPPOAgent()
Expand Down
81 changes: 37 additions & 44 deletions tests/subsystems/test_fsm_red_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.actions.encoding import (
BLUE_SLEEP,
RED_EXPLOIT_HARAKA_START,
RED_EXPLOIT_HTTP_START,
RED_SLEEP,
)
from jaxborg.constants import GLOBAL_MAX_HOSTS, NUM_BLUE_AGENTS, NUM_RED_AGENTS, SERVICE_IDS
from jaxborg.constants import GLOBAL_MAX_HOSTS, NUM_RED_AGENTS, SERVICE_IDS
from jaxborg.scenarios.cc4.red_fsm import (
FSM_ACT_AGGRESSIVE_SCAN,
FSM_ACT_DISCOVER,
Expand All @@ -38,6 +37,7 @@
_pick_discover_subnet,
_pick_exploit_action,
determine_fsm_success,
fsm_red_apply_delayed_update,
fsm_red_get_action,
fsm_red_init_states,
fsm_red_process_session_removal,
Expand Down Expand Up @@ -116,50 +116,43 @@ def test_initial_active_red_agent_is_not_forced_to_sleep(self):
assert type(cyborg_action).__name__ != "Sleep"
assert int(red_actions[0]) != RED_SLEEP

def test_fsm_hidden_state_applies_after_completion_step(self, tmp_path):
"""FSM hidden state should update on the next decision step, not immediately on completion."""
from CybORG.Agents import EnterpriseGreenAgent
from CybORG.Agents.Wrappers import BlueFlatWrapper
from CybORG.Simulator.Actions import Sleep
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator
def test_fsm_hidden_state_applies_after_completion_step(self):
"""FSM hidden state updates on the next decision step, not immediately on completion.

from jaxborg.parity.fsm_red_env import FsmRedCC4Env
from jaxborg.scenarios.cc4.topology import build_const_from_cyborg, save_topology
CybORG's FiniteStateRedAgent reads new host_states when it next picks
an action, not on the simulation step that finishes a duration action.
jaxborg mirrors this with a two-stage commit: schedule_post_step_update
stages the next state into ``red_fsm_delayed_states``;
``apply_delayed_update`` runs at the start of the next step and copies
the staged state into ``fsm_host_states``.

scenario = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=FiniteStateRedAgent,
steps=500,
)
cyborg_env = BlueFlatWrapper(env=CybORG(scenario, "sim", seed=0), pad_spaces=True)
cyborg_env.reset()
cyborg_agent = cyborg_env.env.environment_controller.agent_interfaces["red_agent_0"].agent
topology_path = tmp_path / "cyborg_seed_0.npz"
save_topology(
build_const_from_cyborg(cyborg_env.env),
topology_path,
metadata={"source": "cyborg", "source_seed": 0},
)

jax_env = FsmRedCC4Env(num_steps=500, topology_path=topology_path)
key = jax.random.PRNGKey(0)
_, env_state = jax_env.reset(key)
start_host = int(env_state.const.red_start_hosts[0])

for expected in (FSM_U, FSM_U, FSM_R):
_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})
key, step_key = jax.random.split(key)
_, env_state, _, _, _ = jax_env.step(
step_key,
env_state,
{f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)},
)

cyborg_states = [info["state"] for info in cyborg_agent.host_states.values() if info.get("hostname")]
expected_cyborg = {FSM_U: "U", FSM_R: "R"}[expected]
assert cyborg_states == [expected_cyborg]
assert int(env_state.state.fsm_host_states[0, start_host]) == expected
This is a pure-state regression check — no env step, no RNG, no
CybORG comparison — so it doesn't drift when an unrelated detail
(PRNG layout, CybORG seed mapping, etc.) changes.
"""
host = 5
state = create_initial_state()
state = state.replace(
fsm_host_states=state.fsm_host_states.at[0, host].set(FSM_S),
red_fsm_delayed_states=state.fsm_host_states.at[0, host].set(FSM_U),
red_fsm_delayed_pending=jnp.bool_(True),
)

# Same step as the action that completed: visible FSM state must not
# have changed yet, even though the next state is staged.
assert int(state.fsm_host_states[0, host]) == FSM_S

# Next decision step: apply_delayed_update commits the staged state
# and clears the pending flag.
applied = fsm_red_apply_delayed_update(state)
assert int(applied.fsm_host_states[0, host]) == FSM_U
assert not bool(applied.red_fsm_delayed_pending)

# And without a pending update, apply_delayed_update is a no-op —
# the visible FSM state stays put.
no_pending = state.replace(red_fsm_delayed_pending=jnp.bool_(False))
unchanged = fsm_red_apply_delayed_update(no_pending)
assert int(unchanged.fsm_host_states[0, host]) == FSM_S


class TestFsmUpdateState:
Expand Down
Loading
Loading