diff --git a/.gitignore b/.gitignore index e061e9d..170b983 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ mlruns/ .jax_cache/ slurm-*.out scripts/dev/catalog_status.json -outputs/* \ No newline at end of file +outputs/* +logs/ \ No newline at end of file diff --git a/recipes/cec_phase6_C00.yaml b/recipes/cec_phase6_C00.yaml new file mode 100644 index 0000000..bc1baf4 --- /dev/null +++ b/recipes/cec_phase6_C00.yaml @@ -0,0 +1,57 @@ +meta: + name: cec_phase6_C00 + source: "Phase 6 plan — plans/jax/cc4/cec/cec-phase6-plan.md (Test 2 — control)" + added: 2026-05-09 + notes: | + Phase 6 Test 2 — arm C00, the canonical-config control. + Forks Matched-Training v2 hyperparameters from recipes/default.yaml. + + No env-diversity banks active: fixed canonical topology, fixed (1,1,1) + mission profile, fixed phase boundaries [0, 167, 333], fixed canonical + phase_rewards. CEC-faithful: training partner is fixed (cc4_stock + variant's red = fsm). Eval-time held-out partner sweep happens via + eval_recipe.py's --eval-red flag. + + Paired against C11 (full env-diversity cocktail) — paired delta on each + held-out red is the Phase 6 headline DV. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 3000000 + variant: cc4_stock + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_C00_10M.yaml b/recipes/cec_phase6_C00_10M.yaml new file mode 100644 index 0000000..2735db0 --- /dev/null +++ b/recipes/cec_phase6_C00_10M.yaml @@ -0,0 +1,49 @@ +meta: + name: cec_phase6_C00_10M + source: "Phase 6 plan — Option B (10M timesteps × 3 seeds, control replicate)" + added: 2026-05-11 + notes: | + Canonical-config control at 10M timesteps. Identical to cec_phase6_C00 + except total_timesteps. Paired against cec_phase6_C11_10M and each + single-axis ablation arm. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_C11.yaml b/recipes/cec_phase6_C11.yaml new file mode 100644 index 0000000..13200bd --- /dev/null +++ b/recipes/cec_phase6_C11.yaml @@ -0,0 +1,102 @@ +meta: + name: cec_phase6_C11 + source: "Phase 6 plan — plans/jax/cc4/cec/cec-phase6-plan.md (Test 2 — full env-diversity cocktail)" + added: 2026-05-09 + notes: | + Phase 6 Test 2 — arm C11, the full env-diversity cocktail. + Forks Matched-Training v2 hyperparameters from recipes/default.yaml. + + Per-reset, FOUR independent banks resample: + 1. topology_bank (16 shapes) — router adjacency, op-zone sizing + 2. mission_bank (anti-corr 4-entry) — (LWF, ASF, RIA) multipliers + 3. phase_boundary_bank (4 entries) — when phase 1/2 transitions fire + 4. phase_rewards_bank (true → 6) — crown-jewel rotation across subnets + + The 2×2 factorial (C00/C01/C10/C11) was collapsed after Test 1 v2 found + that the σ-ratio gate was policy-mediated: an untrained-on-bank policy + can't generate distributional spread under bank variation. C11 is the + "all four banks active" arm; C00 is the canonical-config control. Per-axis + ablations are deferred to Phase 7 if C11 lands. + + CEC-faithful: training partner fixed (cc4_stock variant's red = fsm); + held-out red sweep at eval time via eval_recipe.py --eval-red. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 3000000 + variant: cc4_stock + topology_bank: + - scripts/dev/topology_bank/shape_00.snapshot.npz + - scripts/dev/topology_bank/shape_01.snapshot.npz + - scripts/dev/topology_bank/shape_02.snapshot.npz + - scripts/dev/topology_bank/shape_03.snapshot.npz + - scripts/dev/topology_bank/shape_04.snapshot.npz + - scripts/dev/topology_bank/shape_05.snapshot.npz + - scripts/dev/topology_bank/shape_06.snapshot.npz + - scripts/dev/topology_bank/shape_07.snapshot.npz + - scripts/dev/topology_bank/shape_08.snapshot.npz + - scripts/dev/topology_bank/shape_09.snapshot.npz + - scripts/dev/topology_bank/shape_10.snapshot.npz + - scripts/dev/topology_bank/shape_11.snapshot.npz + - scripts/dev/topology_bank/shape_12.snapshot.npz + - scripts/dev/topology_bank/shape_13.snapshot.npz + - scripts/dev/topology_bank/shape_14.snapshot.npz + - scripts/dev/topology_bank/shape_15.snapshot.npz + # Anti-correlated bank — every non-baseline entry boosts 2 of 3 components + # so a "boost the loud one" memorization fails. Disambiguates "diversity + # itself helps" from "loud reward signal helps" (the Test 1 σ critique). + mission_bank: + - [1, 1, 1] + - [3, 3, 1] + - [1, 3, 3] + - [3, 1, 3] + mission_bank_amplify: 1.0 + # P2 — per-reset phase-boundary jitter. Phase transitions, allow-list flips, + # and per-phase reward emphasis all reindex against the sampled split, so + # "deploy decoys at step 167" memorization breaks. + phase_boundary_bank: + - [0, 167, 333] # canonical CC4 split + - [0, 100, 300] # short setup, balanced mid+late + - [0, 200, 400] # long setup, short late + - [0, 150, 250] # short mid-phase, late starts at 250 + # P3 — crown-jewel rotation. Per-reset sample of which subnet is high-value + # in which phase (OPS_A↔OPS_B swap, ADMIN priority, OFFICE priority, both- + # OPS, full rotation). Forces the policy to read state instead of memorizing + # subnet indices. true → use the canonical 6-entry bank from topology_numpy. + phase_rewards_bank: true + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_C11_10M.yaml b/recipes/cec_phase6_C11_10M.yaml new file mode 100644 index 0000000..ee67f88 --- /dev/null +++ b/recipes/cec_phase6_C11_10M.yaml @@ -0,0 +1,80 @@ +meta: + name: cec_phase6_C11_10M + source: "Phase 6 plan — Option B (10M timesteps × 3 seeds, full cocktail replicate)" + added: 2026-05-11 + notes: | + Full env-diversity cocktail at 10M timesteps — tests whether the cocktail + needs more training to reveal a ZSC benefit. At 3M timesteps Test 2 + REFUTED all four held-out reds (Δs +20..+102, sign flips). Option B + extends training by ~3.3× to see if the cocktail converges on a + transferable representation. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + topology_bank: + - scripts/dev/topology_bank/shape_00.snapshot.npz + - scripts/dev/topology_bank/shape_01.snapshot.npz + - scripts/dev/topology_bank/shape_02.snapshot.npz + - scripts/dev/topology_bank/shape_03.snapshot.npz + - scripts/dev/topology_bank/shape_04.snapshot.npz + - scripts/dev/topology_bank/shape_05.snapshot.npz + - scripts/dev/topology_bank/shape_06.snapshot.npz + - scripts/dev/topology_bank/shape_07.snapshot.npz + - scripts/dev/topology_bank/shape_08.snapshot.npz + - scripts/dev/topology_bank/shape_09.snapshot.npz + - scripts/dev/topology_bank/shape_10.snapshot.npz + - scripts/dev/topology_bank/shape_11.snapshot.npz + - scripts/dev/topology_bank/shape_12.snapshot.npz + - scripts/dev/topology_bank/shape_13.snapshot.npz + - scripts/dev/topology_bank/shape_14.snapshot.npz + - scripts/dev/topology_bank/shape_15.snapshot.npz + mission_bank: + - [1, 1, 1] + - [3, 3, 1] + - [1, 3, 3] + - [3, 1, 3] + mission_bank_amplify: 1.0 + phase_boundary_bank: + - [0, 167, 333] + - [0, 100, 300] + - [0, 200, 400] + - [0, 150, 250] + phase_rewards_bank: true + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_cjewel_10M.yaml b/recipes/cec_phase6_cjewel_10M.yaml new file mode 100644 index 0000000..75dc3ec --- /dev/null +++ b/recipes/cec_phase6_cjewel_10M.yaml @@ -0,0 +1,50 @@ +meta: + name: cec_phase6_cjewel_10M + source: "Phase 6 plan — cocktail ablation (crown-jewel-rotation-only axis)" + added: 2026-05-11 + notes: | + Single-axis ablation: PHASE-REWARDS bank (crown-jewel rotation) active, + other banks OFF. Tests whether crown-jewel-location variation alone + drives ZSC. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + phase_rewards_bank: true + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_miss_10M.yaml b/recipes/cec_phase6_miss_10M.yaml new file mode 100644 index 0000000..b38b917 --- /dev/null +++ b/recipes/cec_phase6_miss_10M.yaml @@ -0,0 +1,54 @@ +meta: + name: cec_phase6_miss_10M + source: "Phase 6 plan — cocktail ablation (mission-bank-only axis)" + added: 2026-05-11 + notes: | + Single-axis ablation: MISSION bank (anti-correlated 4-entry) active, + other banks OFF. Tests whether mission variation alone drives ZSC. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + mission_bank: + - [1, 1, 1] + - [3, 3, 1] + - [1, 3, 3] + - [3, 1, 3] + mission_bank_amplify: 1.0 + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_pbound_10M.yaml b/recipes/cec_phase6_pbound_10M.yaml new file mode 100644 index 0000000..8a907a8 --- /dev/null +++ b/recipes/cec_phase6_pbound_10M.yaml @@ -0,0 +1,54 @@ +meta: + name: cec_phase6_pbound_10M + source: "Phase 6 plan — cocktail ablation (phase-boundary-only axis)" + added: 2026-05-11 + notes: | + Single-axis ablation: PHASE-BOUNDARY bank active (per-reset jitter on + phase 1/2 transitions), other banks OFF. Tests whether phase-timing + variation alone drives ZSC. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + phase_boundary_bank: + - [0, 167, 333] + - [0, 100, 300] + - [0, 200, 400] + - [0, 150, 250] + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/recipes/cec_phase6_topo_10M.yaml b/recipes/cec_phase6_topo_10M.yaml new file mode 100644 index 0000000..67d6051 --- /dev/null +++ b/recipes/cec_phase6_topo_10M.yaml @@ -0,0 +1,65 @@ +meta: + name: cec_phase6_topo_10M + source: "Phase 6 plan — cocktail ablation (topology-only axis)" + added: 2026-05-11 + notes: | + Single-axis ablation: TOPOLOGY bank active, mission/phase-boundary/phase- + rewards banks OFF. Tests whether topology variation alone drives ZSC. + +algorithm: ippo + +core: + lr: 3.0e-4 + gamma: 0.99 + gae_lambda: 0.95 + clip_eps: 0.2 + vf_coef: 0.5 + max_grad_norm: 0.5 + ent_coef: 0.01 + norm_rewards: true + +arch: + name: shared + hidden_dim: 256 + hidden_layers: 2 + activation: tanh + +train: + episode_length: 500 + buffer_size: 5120000 + minibatch_size: 320000 + sgd_iterations: 64 + total_timesteps: 10000000 + variant: cc4_stock + topology_bank: + - scripts/dev/topology_bank/shape_00.snapshot.npz + - scripts/dev/topology_bank/shape_01.snapshot.npz + - scripts/dev/topology_bank/shape_02.snapshot.npz + - scripts/dev/topology_bank/shape_03.snapshot.npz + - scripts/dev/topology_bank/shape_04.snapshot.npz + - scripts/dev/topology_bank/shape_05.snapshot.npz + - scripts/dev/topology_bank/shape_06.snapshot.npz + - scripts/dev/topology_bank/shape_07.snapshot.npz + - scripts/dev/topology_bank/shape_08.snapshot.npz + - scripts/dev/topology_bank/shape_09.snapshot.npz + - scripts/dev/topology_bank/shape_10.snapshot.npz + - scripts/dev/topology_bank/shape_11.snapshot.npz + - scripts/dev/topology_bank/shape_12.snapshot.npz + - scripts/dev/topology_bank/shape_13.snapshot.npz + - scripts/dev/topology_bank/shape_14.snapshot.npz + - scripts/dev/topology_bank/shape_15.snapshot.npz + +eval: + variant: cc4_stock + +jax: + num_envs: 48 + num_minibatches: 16 + update_epochs: 4 + checkpoint_every_updates: 50 + +cleanrl: + num_envs: 48 + rollout_length: 500 + num_rollouts_per_update: 1 + num_minibatches: 16 diff --git a/scripts/dev/build_topology_bank.py b/scripts/dev/build_topology_bank.py new file mode 100644 index 0000000..84daaea --- /dev/null +++ b/scripts/dev/build_topology_bank.py @@ -0,0 +1,278 @@ +"""Build a bank of CC4 topology snapshots covering structural variation. + +Phase 6 stream S1 — produces the "kitchen layouts" analog: 16 (or N) valid +topology snapshots that vary along three axes: + +1. Router adjacency — perturbations of `_ROUTER_LINKS` post-applied to + `data_links` (e.g. add an op-zone-A ↔ op-zone-B router cross-link, drop + the office ↔ admin link). +2. Subnet sizing — per-zone op-server floors targeting totals + `op_zone_servers ∈ {3, 6, 9}` (multiples of 3 so the AUTH/DB/WEB role + assignment splits the candidate pool evenly), plus seed-driven user + host counts. +3. Cross-segment allow-list — perturbations of `allowed_subnet_pairs` + (zero-out one phase pair on certain shapes). + +Every emitted snapshot is loaded via `load_topology` and validated under +`_validate_resilience_topology(CIA_RESILIENCE, [p])`; bad shapes abort +the run. + +Usage:: + + python scripts/dev/build_topology_bank.py \ + --out-dir scripts/dev/topology_bank/ --count 16 --seed 0 +""" + +from __future__ import annotations + +import argparse +import hashlib +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +from jaxborg.constants import GLOBAL_MAX_HOSTS, SUBNET_IDS +from jaxborg.evaluation.jax_env_factory import _validate_resilience_topology +from jaxborg.scenarios.cc4.game_variants import CIA_RESILIENCE +from jaxborg.scenarios.cc4.topology import build_topology, load_topology, save_topology +from jaxborg.state import SimulatorConst + +# 4 perturbation patterns × 4 base seeds = 16 shapes by default. +# Patterns vary along multiple axes simultaneously so that even a small +# bank covers router-adjacency, sizing, and allow-list variation. +# op_zone_min_servers is (alpha-zone-A floor, alpha-zone-B floor). The +# total op-zone server count = a_floor + b_floor; targets are +# {3, 6, 9} — all multiples of 3 so AUTH/DB/WEB role candidates split +# evenly across the bank. +_PATTERNS = ( + { + "name": "P0_baseline_total6", + "op_zone_min_servers": (3, 3), # 6 total + "router_perturbation": "none", + "allowlist_perturbation": "none", + }, + { + "name": "P1_total3_dropOA", + "op_zone_min_servers": (1, 2), # 3 total + "router_perturbation": "drop_office_admin", + "allowlist_perturbation": "none", + }, + { + "name": "P2_total9_OPxlink", + "op_zone_min_servers": (4, 5), # 9 total + "router_perturbation": "add_opzone_xlink", + "allowlist_perturbation": "none", + }, + { + "name": "P3_total6_restrictpairs", + "op_zone_min_servers": (3, 3), # 6 total + "router_perturbation": "none", + "allowlist_perturbation": "drop_phase1_contractor_admin", + }, +) + + +def _router_host_idx_per_subnet(const: SimulatorConst) -> dict[int, int]: + """Return {subnet_id: host_idx_of_router} for active subnets.""" + host_subnet = np.asarray(const.host_subnet) + host_is_router = np.asarray(const.host_is_router) + host_active = np.asarray(const.host_active) + out: dict[int, int] = {} + for h in range(GLOBAL_MAX_HOSTS): + if host_active[h] and host_is_router[h]: + sid = int(host_subnet[h]) + # Only one router per subnet — first wins. + out.setdefault(sid, h) + return out + + +def _apply_router_perturbation(const: SimulatorConst, kind: str) -> SimulatorConst: + if kind == "none": + return const + + routers = _router_host_idx_per_subnet(const) + data_links = np.asarray(const.data_links).copy() + + if kind == "add_opzone_xlink": + # Wire the two operational-zone routers directly to each other. + a = routers.get(SUBNET_IDS["OPERATIONAL_ZONE_A"]) + b = routers.get(SUBNET_IDS["OPERATIONAL_ZONE_B"]) + if a is not None and b is not None: + data_links[a, b] = True + data_links[b, a] = True + elif kind == "drop_office_admin": + # Cut the office ↔ admin secondary path; both still reach PAZ. + # In stock _ROUTER_LINKS, office and admin only connect via + # PUBLIC_ACCESS_ZONE — there's no direct link to drop. To still + # produce a meaningful router-adjacency perturbation, drop the + # OFFICE_NETWORK ↔ PUBLIC_ACCESS_ZONE router link instead, which + # isolates OFFICE_NETWORK from the rest of the routing fabric. + office = routers.get(SUBNET_IDS["OFFICE_NETWORK"]) + paz = routers.get(SUBNET_IDS["PUBLIC_ACCESS_ZONE"]) + if office is not None and paz is not None: + data_links[office, paz] = False + data_links[paz, office] = False + else: + raise ValueError(f"unknown router_perturbation: {kind!r}") + + return const.replace(data_links=jnp.asarray(data_links)) + + +def _apply_allowlist_perturbation(const: SimulatorConst, kind: str) -> SimulatorConst: + if kind == "none": + return const + + pairs = np.asarray(const.allowed_subnet_pairs).copy() + + if kind == "drop_phase1_contractor_admin": + # Zero a single phase-1 cross-segment pair so the allow-list bank + # actually varies. CONTRACTOR_NETWORK ↔ ADMIN_NETWORK is present + # in the stock phase-1 allow-list, so dropping it produces a + # measurable reduction in the bank summary table. + s_con = SUBNET_IDS["CONTRACTOR_NETWORK"] + s_adm = SUBNET_IDS["ADMIN_NETWORK"] + pairs[1, s_con, s_adm] = False + pairs[1, s_adm, s_con] = False + else: + raise ValueError(f"unknown allowlist_perturbation: {kind!r}") + + return const.replace(allowed_subnet_pairs=jnp.asarray(pairs)) + + +def _data_link_router_hash(const: SimulatorConst) -> str: + """Hash of the inter-subnet router-router data_links (adjacency fingerprint).""" + routers = _router_host_idx_per_subnet(const) + sids = sorted(routers.keys()) + data_links = np.asarray(const.data_links) + bits = [] + for i, si in enumerate(sids): + for sj in sids[i + 1 :]: + ri = routers[si] + rj = routers[sj] + bits.append("1" if data_links[ri, rj] else "0") + return hashlib.sha1("".join(bits).encode()).hexdigest()[:8] + + +def _allowed_pair_count(const: SimulatorConst) -> int: + return int(np.asarray(const.allowed_subnet_pairs).sum()) + + +def _summarize(const: SimulatorConst) -> dict: + host_subnet = np.asarray(const.host_subnet) + host_active = np.asarray(const.host_active) + host_is_user = np.asarray(const.host_is_user) + host_is_server = np.asarray(const.host_is_server) + + user_subnets = sorted({int(host_subnet[h]) for h in range(GLOBAL_MAX_HOSTS) if host_active[h] and host_is_user[h]}) + server_subnets = sorted( + {int(host_subnet[h]) for h in range(GLOBAL_MAX_HOSTS) if host_active[h] and host_is_server[h]} + ) + + op_zone_a = int(SUBNET_IDS["OPERATIONAL_ZONE_A"]) + op_zone_b = int(SUBNET_IDS["OPERATIONAL_ZONE_B"]) + op_zone_servers = sum( + 1 + for h in range(GLOBAL_MAX_HOSTS) + if host_active[h] and host_is_server[h] and host_subnet[h] in (op_zone_a, op_zone_b) + ) + + return { + "router_hash": _data_link_router_hash(const), + "num_user_subnets": len(user_subnets), + "num_server_subnets": len(server_subnets), + "op_zone_servers": op_zone_servers, + "allowed_pair_count": _allowed_pair_count(const), + "num_hosts": int(np.asarray(const.num_hosts)), + } + + +def build_one(seed: int, pattern: dict) -> SimulatorConst: + const = build_topology( + jax.random.PRNGKey(seed), + op_zone_min_servers=pattern["op_zone_min_servers"], + ) + const = _apply_router_perturbation(const, pattern["router_perturbation"]) + const = _apply_allowlist_perturbation(const, pattern["allowlist_perturbation"]) + return const + + +def build_bank(out_dir: Path, count: int, base_seed: int) -> list[Path]: + out_dir.mkdir(parents=True, exist_ok=True) + paths: list[Path] = [] + + n_patterns = len(_PATTERNS) + rows: list[tuple[str, dict, int]] = [] + + for i in range(count): + pattern = _PATTERNS[i % n_patterns] + seed = base_seed + (i // n_patterns) + const = build_one(seed, pattern) + # ``np.savez_compressed`` appends ``.npz`` if not present; use the + # extension explicitly so the on-disk filename matches what we + # return and what callers (recipe + tests) reference. + path = out_dir / f"shape_{i:02d}.snapshot.npz" + save_topology( + const, + path, + metadata={ + "source": "phase6_topology_bank", + "shape_index": i, + "pattern_name": pattern["name"], + "source_seed": seed, + "op_zone_min_servers": pattern["op_zone_min_servers"], + "router_perturbation": pattern["router_perturbation"], + "allowlist_perturbation": pattern["allowlist_perturbation"], + }, + ) + paths.append(path) + + # Validate that every snapshot loads cleanly and passes the + # resilience-topology check for the strictest variant we care about. + _validate_resilience_topology(CIA_RESILIENCE, paths) + for p in paths: + # Round-trip load to confirm the on-disk snapshot is parseable. + const = load_topology(p) + rows.append((p.name, _summarize(const), int(p.stat().st_size))) + + _print_summary_table(rows) + return paths + + +def _print_summary_table(rows: list[tuple[str, dict, int]]) -> None: + header = ( + f"{'shape':<22}{'router_hash':<13}{'#user_sub':<11}{'#srv_sub':<10}" + f"{'#opzone_srv':<13}{'#allowpair':<12}{'#hosts':<8}{'bytes':<10}" + ) + print(header) + print("-" * len(header)) + for name, s, nbytes in rows: + print( + f"{name:<22}{s['router_hash']:<13}{s['num_user_subnets']:<11}" + f"{s['num_server_subnets']:<10}{s['op_zone_servers']:<13}" + f"{s['allowed_pair_count']:<12}{s['num_hosts']:<8}{nbytes:<10}" + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Build a bank of CC4 topology snapshots for Phase 6 Axis A.") + parser.add_argument( + "--out-dir", + type=Path, + default=Path("scripts/dev/topology_bank"), + help="Directory to emit shape_NN.snapshot files into.", + ) + parser.add_argument("--count", type=int, default=16, help="Number of snapshots to emit.") + parser.add_argument("--seed", type=int, default=0, help="Base PRNG seed.") + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + paths = build_bank(args.out_dir, count=args.count, base_seed=args.seed) + print(f"\nEmitted {len(paths)} snapshots to {args.out_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/dev/cec_phase5_diversity_spike.py b/scripts/dev/cec_phase5_diversity_spike.py new file mode 100644 index 0000000..b767ab1 --- /dev/null +++ b/scripts/dev/cec_phase5_diversity_spike.py @@ -0,0 +1,112 @@ +"""Phase 5 Test 1 — heuristic-spread diversity audit (no training). + +Run a deterministic blue policy (sleep) across two env distributions: + A. ENV-FIXED: reset() with the same topology key every episode; only + red/green stochasticity varies between rollouts. + B. ENV-DIVERSE: reset() with a fresh topology key each episode; what + `topology_mode='generative'` produces by default on this + branch. + +Per-episode total reward is recorded for each. Pre-registered verdict: +the env-diverse distribution is "meaningful" if its per-episode reward +std is at least 1.5x the env-fixed std (i.e., topology variation +contributes meaningful spread on top of red/green noise). + +Heuristic-spread spike — diagnoses whether what the cec branch already +varies (host counts, services, PIDs) actually changes the problem the +agent faces, before we either trust it or port phase-reward / topology +banks from the diversity branch. +""" + +from __future__ import annotations + +import os +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np + +from jaxborg.parity.fsm_red_env import make_fsm_red_env + + +N_EPISODES = int(os.environ.get("CEC_SPIKE_EPISODES", "32")) +N_STEPS = int(os.environ.get("CEC_SPIKE_STEPS", "500")) + + +def make_rollout_fn(env): + sleep = {f"blue_{i}": jnp.int32(0) for i in range(5)} + + def rollout(init_key: jax.Array, rollout_key: jax.Array) -> jax.Array: + _, state = env.reset(init_key) + + def step_fn(carry, k): + st, total = carry + _, st, r, _, _ = env.step(k, st, sleep) + return (st, total + r["blue_0"]), None + + step_keys = jax.random.split(rollout_key, N_STEPS) + (_, total), _ = jax.lax.scan(step_fn, (state, jnp.float32(0.0)), step_keys) + return total + + return rollout + + +def main() -> int: + print(f"=== CEC Phase 5 Test 1 — diversity spike (N={N_EPISODES} eps, sleep policy) ===") + print(f"JAX backend: {jax.default_backend()} ({jax.devices()})") + + env = make_fsm_red_env(num_steps=N_STEPS, topology_mode="generative") + rollout = make_rollout_fn(env) + rollout_batched = jax.jit(jax.vmap(rollout, in_axes=(0, 0))) + + rng = jax.random.PRNGKey(0) + rng_split = jax.random.split(rng, 3) + init_keys_diverse = jax.random.split(rng_split[0], N_EPISODES) + rollout_keys_a = jax.random.split(rng_split[1], N_EPISODES) + rollout_keys_b = jax.random.split(rng_split[2], N_EPISODES) + + # SAME topology key every episode (env-fixed): broadcast a single key. + fixed_init_key = jax.random.PRNGKey(42) + init_keys_fixed = jnp.broadcast_to(fixed_init_key[None, :], (N_EPISODES, fixed_init_key.shape[0])) + + print("\nCompiling + running ENV-FIXED scan...") + t0 = time.perf_counter() + rewards_fixed = rollout_batched(init_keys_fixed, rollout_keys_a) + rewards_fixed.block_until_ready() + t_fixed = time.perf_counter() - t0 + print(f" wall: {t_fixed:.1f}s") + + print("\nRunning ENV-DIVERSE scan (re-uses compiled cache)...") + t0 = time.perf_counter() + rewards_diverse = rollout_batched(init_keys_diverse, rollout_keys_b) + rewards_diverse.block_until_ready() + t_div = time.perf_counter() - t0 + print(f" wall: {t_div:.1f}s") + + rf = np.asarray(rewards_fixed) + rd = np.asarray(rewards_diverse) + + print("\n--- Per-episode reward summary ---") + fmt = "{:14s} mean={:8.1f} std={:6.1f} min={:8.1f} max={:8.1f} range={:7.1f}" + print(fmt.format("ENV-FIXED", rf.mean(), rf.std(), rf.min(), rf.max(), rf.max() - rf.min())) + print(fmt.format("ENV-DIVERSE", rd.mean(), rd.std(), rd.min(), rd.max(), rd.max() - rd.min())) + + ratio = rd.std() / max(rf.std(), 1e-9) + print(f"\nstd ratio (diverse / fixed): {ratio:.2f}") + + # Pre-registered threshold: 1.5x. Below = topology variation is not + # producing meaningful spread on top of red/green noise. + THRESHOLD = 1.5 + if ratio >= THRESHOLD: + print(f"VERDICT: PASS (ratio {ratio:.2f} >= {THRESHOLD}) — env-diverse has real spread.") + return 0 + else: + print(f"VERDICT: FAIL (ratio {ratio:.2f} < {THRESHOLD}) — env-diverse spread is not meaningfully larger than red/green noise.") + print(" → port phase-reward / topology banks from diversity branch before any Phase 5 training.") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/dev/cec_phase5_diversity_spike_trained.py b/scripts/dev/cec_phase5_diversity_spike_trained.py new file mode 100644 index 0000000..3efd9ed --- /dev/null +++ b/scripts/dev/cec_phase5_diversity_spike_trained.py @@ -0,0 +1,136 @@ +"""Phase 5 Test 1 (rerun) — diversity audit with a TRAINED policy. + +Same protocol as ``cec_phase5_diversity_spike.py`` (sleep variant), but +the blue agents are driven by a 3M-step matched-v2 IPPO checkpoint +instead of always selecting Sleep. Sleep showed a 1.15 std ratio (FAIL); +a responsive policy ought to react to topology variation, so its reward +should be more sensitive to the fixed-vs-diverse env distinction than a +pure noise-floor policy is. + +Default checkpoint: ``default_seed42`` (no resilience red, fixed env at +training time) — the parity-gate baseline. Override via +``CEC_SPIKE_CHECKPOINT``. +""" + +from __future__ import annotations + +import os +import sys +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +from jaxborg.evaluation.jax_runner import load_jax_checkpoint +from jaxborg.parity.fsm_red_env import make_fsm_red_env + +DEFAULT_CKPT = ( + "/home/local/KHQ/paul.elliott/src/cyber/jaxborg-exp/ippo_jax/" + "default_seed42/model_default_seed42.safetensors" +) + +N_EPISODES = int(os.environ.get("CEC_SPIKE_EPISODES", "128")) +N_STEPS = int(os.environ.get("CEC_SPIKE_STEPS", "500")) +CKPT = Path(os.environ.get("CEC_SPIKE_CHECKPOINT", DEFAULT_CKPT)) + + +def make_rollout_fn(env, policy, params): + agents = [f"blue_{i}" for i in range(5)] + + def policy_step(obs_dict, mask_dict): + # Stack across 5 agents, vmap policy.apply. + obs_stack = jnp.stack([obs_dict[a] for a in agents]) + mask_stack = jnp.stack([mask_dict[a] for a in agents]) + + def fwd(o, m): + pi, _ = policy.apply(params, o, m) + return pi.logits + + logits = jax.vmap(fwd)(obs_stack, mask_stack) + actions = jnp.argmax(logits, axis=-1) + return {a: actions[i] for i, a in enumerate(agents)} + + def rollout(init_key: jax.Array, rollout_key: jax.Array) -> jax.Array: + obs0, state0 = env.reset(init_key) + mask0 = env.get_avail_actions(state0) + + def step_fn(carry, k): + st, obs, mask, total = carry + actions = policy_step(obs, mask) + new_obs, new_st, r, _, _ = env.step(k, st, actions) + new_mask = env.get_avail_actions(new_st) + return (new_st, new_obs, new_mask, total + r["blue_0"]), None + + step_keys = jax.random.split(rollout_key, N_STEPS) + (_, _, _, total), _ = jax.lax.scan( + step_fn, (state0, obs0, mask0, jnp.float32(0.0)), step_keys + ) + return total + + return rollout + + +def main() -> int: + print(f"=== CEC Phase 5 Test 1 — diversity spike (TRAINED policy) ===") + print(f"Checkpoint: {CKPT}") + print(f"N_EPISODES={N_EPISODES}, N_STEPS={N_STEPS}") + print(f"JAX backend: {jax.default_backend()} ({jax.devices()})") + + if not CKPT.is_file(): + print(f"ERROR: checkpoint not found: {CKPT}", file=sys.stderr) + return 2 + + policy, params, recipe = load_jax_checkpoint(CKPT) + red_agent = recipe.get("train", {}).get("red_agent", "fsm") + print(f"Recipe red_agent: {red_agent}") + + env = make_fsm_red_env(num_steps=N_STEPS, topology_mode="generative", red_agent=red_agent) + rollout = make_rollout_fn(env, policy, params) + rollout_batched = jax.jit(jax.vmap(rollout, in_axes=(0, 0))) + + rng = jax.random.PRNGKey(0) + rng_split = jax.random.split(rng, 3) + init_keys_diverse = jax.random.split(rng_split[0], N_EPISODES) + rollout_keys_a = jax.random.split(rng_split[1], N_EPISODES) + rollout_keys_b = jax.random.split(rng_split[2], N_EPISODES) + + fixed_init_key = jax.random.PRNGKey(42) + init_keys_fixed = jnp.broadcast_to(fixed_init_key[None, :], (N_EPISODES, fixed_init_key.shape[0])) + + print("\nCompiling + running ENV-FIXED scan...") + t0 = time.perf_counter() + rewards_fixed = rollout_batched(init_keys_fixed, rollout_keys_a) + rewards_fixed.block_until_ready() + t_fixed = time.perf_counter() - t0 + print(f" wall: {t_fixed:.1f}s") + + print("\nRunning ENV-DIVERSE scan (re-uses compiled cache)...") + t0 = time.perf_counter() + rewards_diverse = rollout_batched(init_keys_diverse, rollout_keys_b) + rewards_diverse.block_until_ready() + t_div = time.perf_counter() - t0 + print(f" wall: {t_div:.1f}s") + + rf = np.asarray(rewards_fixed) + rd = np.asarray(rewards_diverse) + + print("\n--- Per-episode reward summary (TRAINED policy, deterministic) ---") + fmt = "{:14s} mean={:8.1f} std={:6.1f} min={:8.1f} max={:8.1f} range={:7.1f}" + print(fmt.format("ENV-FIXED", rf.mean(), rf.std(), rf.min(), rf.max(), rf.max() - rf.min())) + print(fmt.format("ENV-DIVERSE", rd.mean(), rd.std(), rd.min(), rd.max(), rd.max() - rd.min())) + + ratio = rd.std() / max(rf.std(), 1e-9) + print(f"\nstd ratio (diverse / fixed): {ratio:.2f}") + + THRESHOLD = 1.5 + if ratio >= THRESHOLD: + print(f"VERDICT: PASS (ratio {ratio:.2f} >= {THRESHOLD}) — env-diverse has real spread under a trained policy.") + return 0 + print(f"VERDICT: FAIL (ratio {ratio:.2f} < {THRESHOLD}) — even a trained policy doesn't see meaningful spread.") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/dev/cec_phase6_aggregate.py b/scripts/dev/cec_phase6_aggregate.py new file mode 100644 index 0000000..6912ed0 --- /dev/null +++ b/scripts/dev/cec_phase6_aggregate.py @@ -0,0 +1,216 @@ +"""Phase 6 Test 2 aggregator — paired delta C11−C00 across 3 seeds × 5 held-out reds. + +Reads result rows from ``$JAXBORG_EXP_DIR/eval/phase6_*.jsonl`` (written by +``cec_phase6_eval_jax.py``), pivots into a (arm, seed, red) → mean_reward +table, computes: + + - Per-red mean across seeds, ± stderr. + - Paired delta C11 − C00 per seed (pair by seed), then mean ± stderr. + - Pre-registered falsification verdict per the plan: + * confirmed: paired delta ≥ +200 reward AND lower bound > 0 + * refuted: paired delta ≤ +50 reward OR sign flip on ≥1 seed + * inconclusive otherwise + +Usage: + uv run python scripts/dev/cec_phase6_aggregate.py + uv run python scripts/dev/cec_phase6_aggregate.py --eval-dir /custom/eval/dir +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +from pathlib import Path + +# Override via PHASE6_ARMS env var (space-separated, e.g. "C00_10M C11_10M"). +ARMS = tuple(os.environ.get("PHASE6_ARMS", "C00 C11").split()) +# "random" was originally in the plan's noise-floor list, but the JAX-native +# eval doesn't include a random red (CybORG-only). Drop from the JAX agg. +REDS = ("fsm", "cia_c", "cia_i", "cia_a") +SEEDS = (42, 142, 242) + +CONFIRM_REWARD_DELTA = 200.0 +REFUTE_REWARD_DELTA = 50.0 + + +def _load_rows(eval_dir: Path) -> list[dict]: + files = sorted(eval_dir.glob("phase6_*.jsonl")) + rows = [] + for f in files: + try: + row = json.loads(f.read_text()) + rows.append(row) + except Exception as e: + print(f"WARN: failed to parse {f}: {e}") + return rows + + +def _stderr(values: list[float]) -> float: + n = len(values) + if n < 2: + return 0.0 + mean = sum(values) / n + var = sum((v - mean) ** 2 for v in values) / (n - 1) + return math.sqrt(var / n) + + +def _arm_from_recipe(recipe_name: str) -> str: + # Exact-suffix match: "cec_phase6_C00_10M" matches arm "C00_10M" but NOT + # arm "C00". Prevents 3M-tag rows from pooling with 10M-tag rows. + for arm in ARMS: + if recipe_name.endswith(f"_{arm}") or recipe_name == arm: + return arm + return "?" + + +def _seed_from_model(model_path: str) -> int: + import re + + m = re.search(r"seed(\d+)", model_path) + return int(m.group(1)) if m else -1 + + +def aggregate(eval_dir: Path) -> dict: + rows = _load_rows(eval_dir) + # Collect ALL per-episode rewards per (arm, seed, red) across every + # eval batch (e.g. seed=1000 + seed=2000 result rows), then mean across + # the union — that's a larger effective N per cell than any single batch. + per_episode_pool: dict[tuple[str, int, str], list[float]] = {} + for row in rows: + arm = _arm_from_recipe(row.get("recipe_name", "")) + seed = _seed_from_model(row.get("model", "")) + red = row.get("eval_red", "?") + if arm in ARMS and seed in SEEDS and red in REDS: + key = (arm, seed, red) + pe = row.get("per_episode") or [] + if pe: + per_episode_pool.setdefault(key, []).extend(float(x) for x in pe) + else: + # fallback for rows without per_episode (older format) + per_episode_pool.setdefault(key, []).append(float(row["mean_reward"])) + table: dict[tuple[str, int, str], float] = {k: sum(v) / len(v) for k, v in per_episode_pool.items() if v} + batches_per_cell: dict[tuple[str, int, str], int] = {k: len(v) for k, v in per_episode_pool.items()} + summary_n_episodes = batches_per_cell + + summary = { + "per_arm_red": {}, + "paired_deltas": {}, + "verdicts": {}, + "missing_cells": [], + } + for arm in ARMS: + for red in REDS: + vals = [table.get((arm, s, red)) for s in SEEDS] + present = [v for v in vals if v is not None] + for s, v in zip(SEEDS, vals): + if v is None: + summary["missing_cells"].append((arm, s, red)) + mean = sum(present) / len(present) if present else float("nan") + stderr = _stderr(present) if len(present) >= 2 else 0.0 + n_eps = [summary_n_episodes.get((arm, s, red), 0) for s in SEEDS] + summary["per_arm_red"][f"{arm}/{red}"] = { + "n": len(present), + "mean": mean, + "stderr": stderr, + "per_seed": vals, + "episodes_per_seed": n_eps, + } + + # Pick the two arms used for the paired-delta. With the default + # ("C00", "C11"), this matches the original 3M experiment. With + # PHASE6_ARMS="C00_10M C11_10M" it matches Option B. The first arm is the + # control, the second is the treatment; the delta is treatment − control. + if len(ARMS) < 2: + raise SystemExit(f"Need at least 2 arms for paired delta; got {ARMS}") + arm_ctrl, arm_trt = ARMS[0], ARMS[1] + for red in REDS: + deltas = [] + for s in SEEDS: + v00 = table.get((arm_ctrl, s, red)) + v11 = table.get((arm_trt, s, red)) + if v00 is None or v11 is None: + continue + deltas.append(v11 - v00) + n = len(deltas) + if n == 0: + summary["paired_deltas"][red] = None + summary["verdicts"][red] = "no-data" + continue + mean = sum(deltas) / n + stderr = _stderr(deltas) if n >= 2 else 0.0 + lb = mean - stderr + signs_match = all(d > 0 for d in deltas) or all(d < 0 for d in deltas) + summary["paired_deltas"][red] = { + "n": n, + "mean": mean, + "stderr": stderr, + "lower_bound": lb, + "per_seed": deltas, + "signs_match": signs_match, + } + if mean >= CONFIRM_REWARD_DELTA and lb > 0: + verdict = "CONFIRMED" + elif mean <= REFUTE_REWARD_DELTA or not signs_match: + verdict = "REFUTED" + else: + verdict = "INCONCLUSIVE" + summary["verdicts"][red] = verdict + + return summary + + +def _print_table(summary: dict) -> None: + print("\n=== Per-arm × held-out red mean reward (across seeds) ===") + print(f"{'arm/red':<12} {'n':>3} {'mean':>10} {'± stderr':>10} per-seed (eps/seed)") + for k, v in summary["per_arm_red"].items(): + ps = ", ".join(f"{x:.0f}" if x is not None else "—" for x in v["per_seed"]) + eps = ", ".join(str(e) for e in v.get("episodes_per_seed", [])) + print(f"{k:<12} {v['n']:>3} {v['mean']:>10.1f} {v['stderr']:>10.1f} [{ps}] ({eps} eps)") + + print("\n=== Paired delta (C11 − C00), per held-out red ===") + print(f"{'red':<10} {'n':>3} {'Δmean':>10} {'± stderr':>10} {'lower_bound':>12} verdict") + for red in REDS: + d = summary["paired_deltas"].get(red) + verdict = summary["verdicts"].get(red, "—") + if d is None: + print(f"{red:<10} {0:>3} {'—':>10} {'—':>10} {'—':>12} {verdict}") + else: + print(f"{red:<10} {d['n']:>3} {d['mean']:>10.1f} {d['stderr']:>10.1f} {d['lower_bound']:>12.1f} {verdict}") + + if summary["missing_cells"]: + print(f"\nMISSING ({len(summary['missing_cells'])} cells):") + for arm, seed, red in summary["missing_cells"]: + print(f" {arm} seed={seed} red={red}") + + print("\nPre-registered thresholds:") + print(f" CONFIRMED: paired Δ ≥ +{CONFIRM_REWARD_DELTA:.0f} reward AND mean−stderr > 0") + print(f" REFUTED: paired Δ ≤ +{REFUTE_REWARD_DELTA:.0f} OR sign flip across seeds") + print(" INCONCLUSIVE band in between → escalate (e.g. mission_bank_amplify=10 or more seeds)") + + +def main(): + parser = argparse.ArgumentParser(description="Aggregate Phase 6 Test 2 eval results") + parser.add_argument( + "--eval-dir", + type=str, + default=os.environ.get("JAXBORG_EXP_DIR", "jaxborg-exp") + "/eval", + help="Directory containing phase6_*.jsonl eval rows", + ) + parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON instead of a table") + args = parser.parse_args() + + eval_dir = Path(args.eval_dir).resolve() + if not eval_dir.is_dir(): + raise SystemExit(f"eval-dir not found: {eval_dir}") + + summary = aggregate(eval_dir) + if args.json: + print(json.dumps(summary, indent=2, default=str)) + else: + _print_table(summary) + + +if __name__ == "__main__": + main() diff --git a/scripts/dev/topology_bank/shape_00.snapshot.npz b/scripts/dev/topology_bank/shape_00.snapshot.npz new file mode 100644 index 0000000..d2879cb Binary files /dev/null and b/scripts/dev/topology_bank/shape_00.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_01.snapshot.npz b/scripts/dev/topology_bank/shape_01.snapshot.npz new file mode 100644 index 0000000..c6a7a7b Binary files /dev/null and b/scripts/dev/topology_bank/shape_01.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_02.snapshot.npz b/scripts/dev/topology_bank/shape_02.snapshot.npz new file mode 100644 index 0000000..ac117fd Binary files /dev/null and b/scripts/dev/topology_bank/shape_02.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_03.snapshot.npz b/scripts/dev/topology_bank/shape_03.snapshot.npz new file mode 100644 index 0000000..9540a59 Binary files /dev/null and b/scripts/dev/topology_bank/shape_03.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_04.snapshot.npz b/scripts/dev/topology_bank/shape_04.snapshot.npz new file mode 100644 index 0000000..e607eb4 Binary files /dev/null and b/scripts/dev/topology_bank/shape_04.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_05.snapshot.npz b/scripts/dev/topology_bank/shape_05.snapshot.npz new file mode 100644 index 0000000..1046af8 Binary files /dev/null and b/scripts/dev/topology_bank/shape_05.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_06.snapshot.npz b/scripts/dev/topology_bank/shape_06.snapshot.npz new file mode 100644 index 0000000..0f4ea8d Binary files /dev/null and b/scripts/dev/topology_bank/shape_06.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_07.snapshot.npz b/scripts/dev/topology_bank/shape_07.snapshot.npz new file mode 100644 index 0000000..8cb8176 Binary files /dev/null and b/scripts/dev/topology_bank/shape_07.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_08.snapshot.npz b/scripts/dev/topology_bank/shape_08.snapshot.npz new file mode 100644 index 0000000..c044c77 Binary files /dev/null and b/scripts/dev/topology_bank/shape_08.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_09.snapshot.npz b/scripts/dev/topology_bank/shape_09.snapshot.npz new file mode 100644 index 0000000..2f54096 Binary files /dev/null and b/scripts/dev/topology_bank/shape_09.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_10.snapshot.npz b/scripts/dev/topology_bank/shape_10.snapshot.npz new file mode 100644 index 0000000..5162f42 Binary files /dev/null and b/scripts/dev/topology_bank/shape_10.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_11.snapshot.npz b/scripts/dev/topology_bank/shape_11.snapshot.npz new file mode 100644 index 0000000..a084598 Binary files /dev/null and b/scripts/dev/topology_bank/shape_11.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_12.snapshot.npz b/scripts/dev/topology_bank/shape_12.snapshot.npz new file mode 100644 index 0000000..5ffb561 Binary files /dev/null and b/scripts/dev/topology_bank/shape_12.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_13.snapshot.npz b/scripts/dev/topology_bank/shape_13.snapshot.npz new file mode 100644 index 0000000..108d2a9 Binary files /dev/null and b/scripts/dev/topology_bank/shape_13.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_14.snapshot.npz b/scripts/dev/topology_bank/shape_14.snapshot.npz new file mode 100644 index 0000000..eba6a1c Binary files /dev/null and b/scripts/dev/topology_bank/shape_14.snapshot.npz differ diff --git a/scripts/dev/topology_bank/shape_15.snapshot.npz b/scripts/dev/topology_bank/shape_15.snapshot.npz new file mode 100644 index 0000000..df9a2ae Binary files /dev/null and b/scripts/dev/topology_bank/shape_15.snapshot.npz differ diff --git a/scripts/eval/cec_phase6_eval_jax.py b/scripts/eval/cec_phase6_eval_jax.py new file mode 100644 index 0000000..baaeb48 --- /dev/null +++ b/scripts/eval/cec_phase6_eval_jax.py @@ -0,0 +1,184 @@ +"""JAX-native eval for Phase 6 Test 2 — held-out red sweep. + +Loads a JAX-trained checkpoint, builds a JAX env with the eval variant +(red overridable via ``--eval-red``), runs ``--episodes`` deterministic- +argmax rollouts vmapped, and writes a result row to a JSONL file. + +The plan's eval pipeline routes via eval_recipe.py (CybORG-side), but for +held-out-generalization claims the plan also accepts JAX-internal eval +(plan §"Training & implementation" / "Eval"). This script is the +JAX-internal version: cheaper, faster, no CybORG dependency, and the +held-out signal is a JAX-internal claim regardless. + +Usage: + uv run python scripts/eval/cec_phase6_eval_jax.py \\ + --model jaxborg-exp/ippo_jax/cec_phase6_C11_seed42/model_cec_phase6_C11_seed42.safetensors \\ + --eval-red cia_c --episodes 90 --seed 1000 + +Result schema mirrors eval_recipe.py for downstream aggregation. +""" + +from __future__ import annotations + +import argparse +import json +import os +import time +from pathlib import Path +from statistics import mean, stdev + +# Set persistent XLA compile cache BEFORE importing jax — non-interactive bash +# (slurm --wrap, bash -c) doesn't source ~/.bashrc so we set this defensively +# so re-runs hit the cache instead of paying ~20 min CPU compile per cell. +os.environ.setdefault("JAX_ENABLE_COMPILATION_CACHE", "1") +os.environ.setdefault("JAX_COMPILATION_CACHE_DIR", os.path.expanduser("~/.cache/jaxborg/xla")) +os.environ.setdefault("JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS", "1") + +import jax +import jax.numpy as jnp +import numpy as np + +from jaxborg.constants import NUM_BLUE_AGENTS +from jaxborg.evaluation.jax_env_factory import make_jax_env +from jaxborg.evaluation.jax_runner import load_jax_checkpoint +from jaxborg.scenarios.cc4.game_variants import variant_for_red + +EXP_DIR = Path(os.environ.get("JAXBORG_EXP_DIR", "jaxborg-exp")).resolve() + + +def _git_commit() -> str: + import subprocess + + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL).decode().strip() + except Exception: + return "" + + +def _build_eval_env(variant_name: str, *, resilience_roles: bool): + """Build a clean canonical-config eval env for the given red. + + Eval intentionally uses NO env-diversity banks — the held-out + generalization claim is "policy trained on diverse env handles a held-out + red on the canonical env." Topology is the canonical generative one; + mission profile is (1, 1, 1); phase boundaries are canonical. + """ + variant = variant_for_red(variant_name, resilience_roles=resilience_roles) + return variant, make_jax_env(variant) + + +def run_eval( + *, + model_path: Path, + eval_red: str, + episodes: int, + seed: int, +) -> dict: + policy, params, recipe = load_jax_checkpoint(model_path) + train_red = recipe.get("train", {}).get("variant", "cc4_stock") + # CIA-biased reds need resilience_roles=True for their selectors; + # cc4_stock and fsm reds don't. + resilience_roles = eval_red in ("cia_c", "cia_i", "cia_a", "resilience", "c", "i", "a") + variant, env = _build_eval_env(eval_red, resilience_roles=resilience_roles) + + blue_agents = tuple(f"blue_{i}" for i in range(NUM_BLUE_AGENTS)) + num_steps = variant.num_steps + + @jax.jit + def _run_one(key): + reset_key, scan_key = jax.random.split(key) + obs, env_state = env.reset(reset_key) + mask = env.get_avail_actions(env_state) + + def step_fn(carry, _): + state, obs, mask, k = carry + k, step_key = jax.random.split(k) + obs_stack = jnp.stack([obs[a] for a in blue_agents]) + mask_stack = jnp.stack([mask[a] for a in blue_agents]) + + def _fwd(o, m): + pi, _ = policy.apply(params, o, m) + return pi.logits + + logits = jax.vmap(_fwd)(obs_stack, mask_stack) + acts = jnp.argmax(logits, axis=-1) + actions = {a: acts[i] for i, a in enumerate(blue_agents)} + new_obs, new_state, rewards, _, _ = env.step(step_key, state, actions) + new_mask = env.get_avail_actions(new_state) + mean_reward = jnp.stack([rewards[a] for a in blue_agents]).mean() + return (new_state, new_obs, new_mask, k), mean_reward + + (_, _, _, _), per_step = jax.lax.scan(step_fn, (env_state, obs, mask, scan_key), None, length=num_steps) + return per_step.sum() + + keys = jax.random.split(jax.random.PRNGKey(seed), episodes) + t0 = time.perf_counter() + totals = jax.vmap(_run_one)(keys) + totals.block_until_ready() + wall = time.perf_counter() - t0 + rewards_list = [float(x) for x in np.asarray(totals)] + + return { + "model": str(model_path), + "recipe_name": recipe.get("meta", {}).get("name", ""), + "trained_backend": "jax", + "eval_env": "jax", + "eval_red": eval_red, + "variant": variant.name, + "train_variant": train_red, + "seed": seed, + "episodes": episodes, + "deterministic": True, + "mean_reward": mean(rewards_list), + "std_reward": stdev(rewards_list) if len(rewards_list) > 1 else 0.0, + "n_episodes": len(rewards_list), + "wall_time_s": wall, + "git_commit": _git_commit(), + "per_episode": rewards_list, + } + + +def main(): + parser = argparse.ArgumentParser(description="JAX-native held-out red eval for Phase 6 Test 2") + parser.add_argument("--model", required=True, help=".safetensors checkpoint with sibling recipe sidecar") + # NOTE: "random" (CybORG's RandomSelectRedAgent) is not in the JAX red- + # selector REGISTRY; route that through eval_recipe.py (CybORG eval). + parser.add_argument( + "--eval-red", required=True, choices=["fsm", "cia_c", "cia_i", "cia_a", "resilience", "sleep"] + ) + parser.add_argument("--episodes", type=int, default=90, help="Episodes (plan default 90 for stat power)") + parser.add_argument("--seed", type=int, default=1000, help="PRNG seed for the rollout batch") + parser.add_argument("--output", type=str, default=None, help="Override result jsonl path") + args = parser.parse_args() + + model_path = Path(args.model).resolve() + if not model_path.exists(): + raise FileNotFoundError(f"Model not found: {model_path}") + + print(f"=== JAX eval: {model_path.name} vs {args.eval_red} ({args.episodes} eps, seed {args.seed}) ===", flush=True) + print(f"JAX backend: {jax.default_backend()} ({jax.devices()})", flush=True) + + row = run_eval( + model_path=model_path, + eval_red=args.eval_red, + episodes=args.episodes, + seed=args.seed, + ) + eval_id = f"{time.strftime('%Y%m%d_%H%M%S')}_{args.seed}_{args.eval_red}" + row["eval_id"] = eval_id + + out_dir = EXP_DIR / "eval" + out_dir.mkdir(parents=True, exist_ok=True) + if args.output: + out_path = Path(args.output) + else: + out_path = out_dir / f"phase6_{row['recipe_name']}_{model_path.stem}_{eval_id}.jsonl" + out_path.write_text(json.dumps(row, indent=2) + "\n") + + print(f"\nmean: {row['mean_reward']:.2f} ± {row['std_reward']:.2f} (n={row['n_episodes']})", flush=True) + print(f"wall: {row['wall_time_s']:.1f}s", flush=True) + print(f"wrote: {out_path}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval/cec_phase6_eval_sweep.py b/scripts/eval/cec_phase6_eval_sweep.py new file mode 100644 index 0000000..219d330 --- /dev/null +++ b/scripts/eval/cec_phase6_eval_sweep.py @@ -0,0 +1,82 @@ +"""Phase 6 Test 2 eval sweep — single-process driver. + +Loops through (checkpoint, held-out red) cells sequentially in ONE Python +process so JIT compilation amortizes: 4 reds × 1 compile each, not 24 +compiles across separate processes. ~5× faster on a single machine. + +Output: one phase6_*.jsonl row per cell, identical schema to +``cec_phase6_eval_jax.py``. + +Usage: + JAX_PLATFORMS=cpu uv run python scripts/eval/cec_phase6_eval_sweep.py \\ + --episodes 300 --seed 2000 +""" + +from __future__ import annotations + +import argparse +import json +import os +import time +from pathlib import Path + +# Import the run_eval helper from the per-cell script — shares the JIT +# rollout body so behavior is identical. +import sys + +_REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_REPO_ROOT / "scripts" / "eval")) +from cec_phase6_eval_jax import run_eval # noqa: E402 + +EXP_DIR = Path(os.environ.get("JAXBORG_EXP_DIR", "jaxborg-exp")).resolve() + +ARMS = ("C00", "C11") +SEEDS = (42, 142, 242) +REDS = ("fsm", "cia_c", "cia_i", "cia_a") + + +def main(): + parser = argparse.ArgumentParser(description="Phase 6 Test 2 sequential eval sweep") + parser.add_argument("--episodes", type=int, default=300) + parser.add_argument("--seed", type=int, default=2000, help="Rollout PRNG seed") + parser.add_argument("--reds", nargs="+", default=list(REDS)) + parser.add_argument("--arms", nargs="+", default=list(ARMS)) + parser.add_argument("--train-seeds", nargs="+", type=int, default=list(SEEDS)) + args = parser.parse_args() + + out_dir = EXP_DIR / "eval" + out_dir.mkdir(parents=True, exist_ok=True) + + cells = [] + for arm in args.arms: + for ts in args.train_seeds: + tag = f"cec_phase6_{arm}_seed{ts}" + model = EXP_DIR / "ippo_jax" / tag / f"model_{tag}.safetensors" + if not model.is_file(): + print(f"SKIP missing checkpoint: {model}", flush=True) + continue + for red in args.reds: + cells.append((tag, model, red)) + + print(f"=== sweep: {len(cells)} cells, episodes={args.episodes}, seed={args.seed} ===", flush=True) + t0_all = time.perf_counter() + + # Group by red so JIT cache hits within each red across all 6 checkpoints. + cells.sort(key=lambda c: (c[2], c[0])) + + for i, (tag, model, red) in enumerate(cells, 1): + print(f"\n[{i}/{len(cells)}] {tag} vs {red}", flush=True) + t0 = time.perf_counter() + row = run_eval(model_path=model, eval_red=red, episodes=args.episodes, seed=args.seed) + eval_id = f"{time.strftime('%Y%m%d_%H%M%S')}_{args.seed}_{red}" + row["eval_id"] = eval_id + out_path = out_dir / f"phase6_{row['recipe_name']}_{model.stem}_{eval_id}.jsonl" + out_path.write_text(json.dumps(row, indent=2) + "\n") + print(f" mean={row['mean_reward']:.1f} ± {row['std_reward']:.1f} n={row['n_episodes']} wall={time.perf_counter() - t0:.1f}s", flush=True) + print(f" wrote {out_path.name}", flush=True) + + print(f"\nTotal wall: {time.perf_counter() - t0_all:.1f}s", flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval/cec_phase6_eval_test2.sh b/scripts/eval/cec_phase6_eval_test2.sh new file mode 100755 index 0000000..7841d66 --- /dev/null +++ b/scripts/eval/cec_phase6_eval_test2.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash +# Phase 6 Test 2 eval orchestrator — runs JAX-native held-out red sweep on CPU. +# +# Per checkpoint × held-out red, runs scripts/eval/cec_phase6_eval_jax.py +# at --episodes 90 (plan default for stat power; --episodes 30 for a smoke +# pass). Each (ckpt, red) job is independent and embarrassingly parallel — +# we submit them as individual CPU sbatch jobs so they fan out across the +# cluster without competing with training/diagnostic GPU jobs. +# +# Total: 6 ckpts × 5 reds = 30 jobs. CPU JAX is mostly bottlenecked on +# JIT compile (~3 min) plus a sub-minute rollout — ~5 min wall per cell, +# ~30 min total if 5+ jobs run concurrently on the cluster. +# +# Usage: +# ./scripts/eval/cec_phase6_eval_test2.sh # full sweep +# ./scripts/eval/cec_phase6_eval_test2.sh --episodes 30 # smoke pass +# ./scripts/eval/cec_phase6_eval_test2.sh --dry-run # print commands +# ./scripts/eval/cec_phase6_eval_test2.sh --arm C11 --red cia_c # one cell + +set -euo pipefail + +WORKDIR=$(git rev-parse --show-toplevel) +EXP_DIR="${JAXBORG_EXP_DIR:-$WORKDIR/../../jaxborg-exp}" +SLURM_LOG_DIR="$EXP_DIR/slurm" +mkdir -p "$SLURM_LOG_DIR" + +if [ -n "${PHASE6_ARMS:-}" ]; then + # shellcheck disable=SC2206 + ARMS=($PHASE6_ARMS) +else + ARMS=("C00" "C11") +fi +SEEDS=(42 142 242) +# RandomSelectRedAgent is a CybORG-side construct; the JAX red-selector +# REGISTRY only has fsm + the CIA-biased reds, so the JAX-native eval skips +# "random" here. To get a random-red noise-floor row, run eval_recipe.py +# (CybORG-side) with --eval-red random instead. +REDS=("fsm" "cia_c" "cia_i" "cia_a") +EPISODES=90 +EVAL_SEED=1000 +DRY=0 +ARM_FILTER="" +RED_FILTER="" + +while [ "$#" -gt 0 ]; do + case "$1" in + --episodes) EPISODES="$2"; shift 2 ;; + --seed) EVAL_SEED="$2"; shift 2 ;; + --arm) ARM_FILTER="$2"; shift 2 ;; + --red) RED_FILTER="$2"; shift 2 ;; + --dry-run) DRY=1; shift ;; + *) echo "unrecognized arg: $1" >&2; exit 2 ;; + esac +done + +submit_one() { + local tag="$1" + local red="$2" + local model="$EXP_DIR/ippo_jax/${tag}/model_${tag}.safetensors" + if [ ! -f "$model" ]; then + echo "SKIP $tag vs $red — model not found: $model" >&2 + return + fi + local jobname="eval_${tag}_${red}" + # CPU-only — eval jobs do not need GPU and should not compete with + # training jobs for GPU allocation. JAX_PLATFORMS=cpu pins the backend. + local cmd=( + sbatch + --cpus-per-task=8 + --mem=32G + --time=01:30:00 + --partition=community + --job-name="${jobname}" + --output="${SLURM_LOG_DIR}/${jobname}_%j.log" + --wrap "set -eu +cd ${WORKDIR} +JAX_PLATFORMS=cpu JAXBORG_EXP_DIR=${EXP_DIR} uv run python scripts/eval/cec_phase6_eval_jax.py --model ${model} --eval-red ${red} --episodes ${EPISODES} --seed ${EVAL_SEED}" + ) + if [ "$DRY" -eq 1 ]; then + printf '%q ' "${cmd[@]}"; echo + else + "${cmd[@]}" + fi +} + +for arm in "${ARMS[@]}"; do + if [ -n "$ARM_FILTER" ] && [ "$arm" != "$ARM_FILTER" ]; then continue; fi + for seed in "${SEEDS[@]}"; do + tag="cec_phase6_${arm}_seed${seed}" + for red in "${REDS[@]}"; do + if [ -n "$RED_FILTER" ] && [ "$red" != "$RED_FILTER" ]; then continue; fi + submit_one "$tag" "$red" + done + done +done + +if [ "$DRY" -eq 0 ]; then + echo + echo "Submitted. After completion, aggregate with:" + echo " uv run python scripts/dev/cec_phase6_aggregate.py --eval-dir ${EXP_DIR}/eval" +fi diff --git a/scripts/eval/eval_recipe.py b/scripts/eval/eval_recipe.py index bc169a9..dbd5ebc 100644 --- a/scripts/eval/eval_recipe.py +++ b/scripts/eval/eval_recipe.py @@ -93,6 +93,17 @@ def main(): help="Parallel rollout workers (1 = single process). Default: cpu_count() - 2.", ) parser.add_argument("--output", type=str, default=None, help="Override result jsonl path") + parser.add_argument( + "--eval-red", + type=str, + default=None, + help=( + "Force the eval env's red selector regardless of recipe. " + "Precedence: CLI --eval-red > recipe eval.red > recipe eval.variant > " + "train.variant. One of: fsm | cia_c | cia_i | cia_a | random | " + "resilience | sleep." + ), + ) args = parser.parse_args() model_path = Path(args.model).resolve() @@ -107,6 +118,8 @@ def main(): from jaxborg.recipe import eval_variant recipe = read_sidecar(model_path) + if args.eval_red is not None: + recipe.setdefault("eval", {})["red"] = args.eval_red variant = eval_variant(recipe) print(f"Loaded recipe sidecar: {recipe.get('meta', {}).get('name', '?')}", flush=True) print( @@ -130,6 +143,8 @@ def main(): from jaxborg.recipe import eval_variant recipe = read_sidecar(model_path) + if args.eval_red is not None: + recipe.setdefault("eval", {})["red"] = args.eval_red variant = eval_variant(recipe) t0 = time.perf_counter() diff --git a/scripts/eval/score_trajectories.py b/scripts/eval/score_trajectories.py index 11cb362..54314a4 100644 --- a/scripts/eval/score_trajectories.py +++ b/scripts/eval/score_trajectories.py @@ -40,9 +40,24 @@ def main(): parser.add_argument("--summary-json", default=None) parser.add_argument("--per-episode-json", default=None) parser.add_argument("--recipe", default=None, help="Path or name of recipe yaml") + parser.add_argument( + "--eval-red", + default=None, + help=( + "Override the recipe's eval.red selector before resolving the " + "variant. CLI > recipe. Only meaningful if the scorer or its " + "downstream consumers branch on the variant." + ), + ) args = parser.parse_args() - eval_cfg = project_eval(load(args.recipe)) if args.recipe is not None else {} + if args.recipe is not None: + recipe = load(args.recipe) + if args.eval_red is not None: + recipe.setdefault("eval", {})["red"] = args.eval_red + eval_cfg = project_eval(recipe) + else: + eval_cfg = {} scorer = get_cia_scorer(eval_cfg) traj_dir = Path(args.traj_dir) diff --git a/scripts/train/algorithms/ippo_jax.py b/scripts/train/algorithms/ippo_jax.py index d3b4fee..a0a2dc9 100644 --- a/scripts/train/algorithms/ippo_jax.py +++ b/scripts/train/algorithms/ippo_jax.py @@ -87,9 +87,15 @@ def make_train(config, network): """Build env and a single JIT'd collect_and_update fn from a flat config.""" num_envs = config["NUM_ENVS"] variant: GameVariant = config["TRAIN_VARIANT"] + topology_bank = config.get("TOPOLOGY_BANK") or None inner_env = make_jax_env( variant, training_mode=bool(config.get("TRAINING_MODE", True)), + topology_path=list(topology_bank) if topology_bank else None, + mission_bank=config.get("MISSION_BANK"), + mission_bank_amplify=config.get("MISSION_BANK_AMPLIFY", 1.0), + phase_boundary_bank=config.get("PHASE_BOUNDARY_BANK"), + phase_rewards_bank=config.get("PHASE_REWARDS_BANK"), ) agents = list(inner_env.agents) num_agents = inner_env.num_agents @@ -335,9 +341,15 @@ def main(): # Build a throwaway env to get action_dim for network init. variant: GameVariant = config["TRAIN_VARIANT"] + topology_bank = config.get("TOPOLOGY_BANK") or None inner_env = make_jax_env( variant, topology_mode=config.get("TOPOLOGY_MODE", "generative"), + topology_path=list(topology_bank) if topology_bank else None, + mission_bank=config.get("MISSION_BANK"), + mission_bank_amplify=config.get("MISSION_BANK_AMPLIFY", 1.0), + phase_boundary_bank=config.get("PHASE_BOUNDARY_BANK"), + phase_rewards_bank=config.get("PHASE_REWARDS_BANK"), ) action_dim = inner_env.action_space(inner_env.agents[0]).n network = make_jax_policy( diff --git a/scripts/train/cec_phase6_optionb_ablation.sh b/scripts/train/cec_phase6_optionb_ablation.sh new file mode 100755 index 0000000..30b02d5 --- /dev/null +++ b/scripts/train/cec_phase6_optionb_ablation.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Phase 6 Option B + cocktail ablation launcher. +# +# Submits 18 training jobs (6 arms × 3 seeds × 10M timesteps) chained via +# slurm --dependency=afterany so they execute SERIALLY on one GPU slot. +# This avoids crowding GPUs that other users may need. +# +# Arms (each 10M timesteps): +# C00_10M — canonical control (no banks) +# C11_10M — full env-diversity cocktail (all 4 banks) +# topo_10M — ablation: topology bank only +# miss_10M — ablation: mission bank only +# pbound_10M — ablation: phase-boundary bank only +# cjewel_10M — ablation: phase-rewards (crown-jewel) bank only +# +# Estimated wall: ~1.5 hr per run × 18 runs = ~27 hr on 1 GPU. +# +# Usage: +# ./scripts/train/cec_phase6_optionb_ablation.sh # submit all +# ./scripts/train/cec_phase6_optionb_ablation.sh --dry-run # print only +# ./scripts/train/cec_phase6_optionb_ablation.sh --arm C00_10M # one arm +# ./scripts/train/cec_phase6_optionb_ablation.sh --no-dep # parallel (no chain) + +set -euo pipefail + +WORKDIR=$(git rev-parse --show-toplevel) +EXP_DIR="${JAXBORG_EXP_DIR:-$WORKDIR/../../jaxborg-exp}" +SLURM_LOG_DIR="$EXP_DIR/slurm" +mkdir -p "$SLURM_LOG_DIR" + +ARMS=("C00_10M" "C11_10M" "topo_10M" "miss_10M" "pbound_10M" "cjewel_10M") +SEEDS=(42 142 242) + +DRY=0 +ARM_FILTER="" +SEED_FILTER="" +NO_DEP=0 +while [ "$#" -gt 0 ]; do + case "$1" in + --dry-run) DRY=1; shift ;; + --no-dep) NO_DEP=1; shift ;; + --arm) ARM_FILTER="$2"; shift 2 ;; + --seed) SEED_FILTER="$2"; shift 2 ;; + *) echo "unrecognized arg: $1" >&2; exit 2 ;; + esac +done + +PREV_JID="" +submit_one() { + local arm="$1" + local seed="$2" + local recipe="cec_phase6_${arm}" + local tag="${recipe}_seed${seed}" + local dep_args=() + if [ "$NO_DEP" -eq 0 ] && [ -n "$PREV_JID" ]; then + dep_args=(--dependency=afterany:"$PREV_JID") + fi + local cmd=( + sbatch + --parsable + "${dep_args[@]}" + --gres=gpu:1 + --mem=64G + --time=14:00:00 + --partition=community + --job-name="${tag}" + --output="${SLURM_LOG_DIR}/${tag}_%j.log" + --wrap "set -eu +cd ${WORKDIR} +unset JAX_PLATFORMS +JAXBORG_EXP_DIR=${EXP_DIR} uv run --extra cuda python scripts/train/algorithms/ippo_jax.py --recipe ${recipe} --seed ${seed} --tag ${tag}" + ) + if [ "$DRY" -eq 1 ]; then + printf '%q ' "${cmd[@]}"; echo + PREV_JID="DRY" + else + local jid + jid=$("${cmd[@]}") + echo "submitted $tag → job $jid${PREV_JID:+ (depends on $PREV_JID)}" + PREV_JID="$jid" + fi +} + +for arm in "${ARMS[@]}"; do + if [ -n "$ARM_FILTER" ] && [ "$arm" != "$ARM_FILTER" ]; then continue; fi + for seed in "${SEEDS[@]}"; do + if [ -n "$SEED_FILTER" ] && [ "$seed" != "$SEED_FILTER" ]; then continue; fi + submit_one "$arm" "$seed" + done +done + +if [ "$DRY" -eq 0 ]; then + echo + echo "Watch with: squeue -u \$USER" + echo "Logs: $SLURM_LOG_DIR/cec_phase6_*_10M_seed*.log" +fi diff --git a/scripts/train/cec_phase6_test2.sh b/scripts/train/cec_phase6_test2.sh new file mode 100755 index 0000000..9c539ef --- /dev/null +++ b/scripts/train/cec_phase6_test2.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# Phase 6 Test 2 launcher — submits 6 sbatch jobs: +# 2 arms (C00 control, C11 full env-diversity cocktail) × 3 seeds (42, 142, 242) +# +# Each job runs ippo_jax with --extra cuda so JAX uses the allocated A6000. +# Total wall: ~1.5 hr per run × 6 runs / 3 GPUs = ~3 hr (parallel) / ~9 hr (serial). +# +# Usage: +# ./scripts/train/cec_phase6_test2.sh # submit all 6 +# ./scripts/train/cec_phase6_test2.sh --dry-run # print sbatch commands without submitting +# ./scripts/train/cec_phase6_test2.sh C00 42 # submit one specific run +# ./scripts/train/cec_phase6_test2.sh C11 # submit only C11 × all seeds + +set -euo pipefail + +WORKDIR=$(git rev-parse --show-toplevel) +EXP_DIR="${JAXBORG_EXP_DIR:-$WORKDIR/../../jaxborg-exp}" +SLURM_LOG_DIR="$EXP_DIR/slurm" +mkdir -p "$SLURM_LOG_DIR" + +ARMS=("C00" "C11") +SEEDS=(42 142 242) + +DRY=0 +ARM_FILTER="" +SEED_FILTER="" +for arg in "$@"; do + case "$arg" in + --dry-run) DRY=1 ;; + C00|C11) ARM_FILTER="$arg" ;; + [0-9]*) SEED_FILTER="$arg" ;; + *) echo "unrecognized arg: $arg" >&2; exit 2 ;; + esac +done + +submit_one() { + local arm="$1" + local seed="$2" + local recipe="cec_phase6_${arm}" + local tag="${recipe}_seed${seed}" + local cmd=( + sbatch + --gres=gpu:1 + --mem=64G + --time=04:00:00 + --partition=community + --job-name="${tag}" + --output="${SLURM_LOG_DIR}/${tag}_%j.log" + --wrap "set -eu +cd ${WORKDIR} +unset JAX_PLATFORMS +JAXBORG_EXP_DIR=${EXP_DIR} uv run --extra cuda python scripts/train/algorithms/ippo_jax.py --recipe ${recipe} --seed ${seed} --tag ${tag}" + ) + if [ "$DRY" -eq 1 ]; then + printf '%q ' "${cmd[@]}"; echo + else + "${cmd[@]}" + fi +} + +for arm in "${ARMS[@]}"; do + if [ -n "$ARM_FILTER" ] && [ "$arm" != "$ARM_FILTER" ]; then continue; fi + for seed in "${SEEDS[@]}"; do + if [ -n "$SEED_FILTER" ] && [ "$seed" != "$SEED_FILTER" ]; then continue; fi + submit_one "$arm" "$seed" + done +done + +if [ "$DRY" -eq 0 ]; then + echo + echo "Submitted. Watch with: squeue -u \$USER" + echo "Logs: $SLURM_LOG_DIR/cec_phase6_*_*.log" +fi diff --git a/src/jaxborg/env.py b/src/jaxborg/env.py index 615a0ed..395d5af 100644 --- a/src/jaxborg/env.py +++ b/src/jaxborg/env.py @@ -314,6 +314,10 @@ def __init__( topology_path: str | Path | Sequence[str | Path] | None = None, scenario_config: ScenarioConfig = CC4_CONFIG, op_zone_min_servers: int | None = None, + mission_bank: Sequence[Sequence[float]] | None = None, + mission_bank_amplify: float = 1.0, + phase_boundary_bank: Sequence[Sequence[int]] | None = None, + phase_rewards_bank: Sequence | None = None, ): self.cfg = scenario_config self.num_steps = num_steps if num_steps is not None else scenario_config.max_steps @@ -321,6 +325,60 @@ def __init__( self.op_zone_min_servers = op_zone_min_servers self._const_bank = None self._const_bank_size = 0 + + # Phase 6 axis B — per-reset CIA-component multiplier bank. When + # ``mission_bank`` is None or empty, we behave exactly as before + # (no per-reset variation, multipliers are implicitly (1, 1, 1)). + # Otherwise stack the bank as a (N, 3) float32 array on self and + # sample an index in ``_select_const``. ``mission_bank_amplify`` + # scales the *entire* sampled triple element-wise — so amplify=10 + # with bank entry (1, 3, 1) yields (10, 30, 10), not (1, 30, 1). + if mission_bank is None or len(mission_bank) == 0: + self._mission_bank = None + self._mission_bank_size = 0 + else: + arr = jnp.asarray(mission_bank, dtype=jnp.float32) + if arr.ndim != 2 or arr.shape[1] != 3: + raise ValueError(f"mission_bank must be a sequence of 3-tuples (LWF, ASF, RIA); got shape {arr.shape}") + amplify = float(mission_bank_amplify) + self._mission_bank = arr * amplify + self._mission_bank_size = int(arr.shape[0]) + + # Phase 6 P2 — phase-boundary jitter bank. Per-reset sample of + # (phase0_start, phase1_start, phase2_start). Replaces + # ``const.phase_boundaries`` so phase transitions, allow-list flips, + # and phase_rewards index switches all move with the sampled split. + # Empty/None → fast path, behavior unchanged. + if phase_boundary_bank is None or len(phase_boundary_bank) == 0: + self._phase_boundary_bank = None + self._phase_boundary_bank_size = 0 + else: + pb_arr = jnp.asarray(phase_boundary_bank, dtype=jnp.int32) + if pb_arr.ndim != 2 or pb_arr.shape[1] != 3: + raise ValueError( + f"phase_boundary_bank must be a sequence of 3-tuples; got shape {pb_arr.shape}" + ) + self._phase_boundary_bank = pb_arr + self._phase_boundary_bank_size = int(pb_arr.shape[0]) + + # Phase 6 P3 — phase_rewards bank (crown-jewel rotation). Per-reset + # sample of an entire (MISSION_PHASES, NUM_SUBNETS, 3) phase_rewards + # array. Each entry rotates which subnet is "high-value" in which + # phase, so the same physical topology generates different reward + # gradients per episode and the policy must read state instead of + # memorizing subnet indices. Empty/None → fast path, unchanged. + if phase_rewards_bank is None or len(phase_rewards_bank) == 0: + self._phase_rewards_bank = None + self._phase_rewards_bank_size = 0 + else: + pr_arr = jnp.asarray(phase_rewards_bank, dtype=jnp.float32) + if pr_arr.ndim != 4: + raise ValueError( + "phase_rewards_bank entries must be (MISSION_PHASES, NUM_SUBNETS, 3) " + f"arrays; got shape {pr_arr.shape}" + ) + self._phase_rewards_bank = pr_arr + self._phase_rewards_bank_size = int(pr_arr.shape[0]) if topology_path is not None: if topology_mode != "generative": raise ValueError( @@ -357,21 +415,56 @@ def __init__( self.observation_spaces[agent] = Box(low=0.0, high=1.0, shape=(self.cfg.blue_obs_size,), dtype=jnp.float32) def _select_const(self, key: chex.PRNGKey) -> SimulatorConst: + # Split the input key so each bank's index sample is independent of + # (and reproducible from) the same input key — same input key → + # same const + same triples + same boundaries + same crown-jewel. + key_const, key_mission, key_pb, key_pr = jax.random.split(key, 4) + if self._const_bank is None: - return build_topology( - key, + const = build_topology( + key_const, num_steps=self.num_steps, training_mode=self.training_mode, op_zone_min_servers=self.op_zone_min_servers, ) - - bank_idx = jax.random.randint(key, (), 0, self._const_bank_size) - const = jax.tree.map(lambda x: x[bank_idx], self._const_bank) - # Snapshots save the ``max_steps`` they were generated against (e.g. - # 500), but the env's caller may want a different episode length via - # ``num_steps``. Override here so ``done = state.time >= max_steps`` - # honours the env's configuration rather than the snapshot's default. - return const.replace(max_steps=jnp.int32(self.num_steps)) + else: + bank_idx = jax.random.randint(key_const, (), 0, self._const_bank_size) + const = jax.tree.map(lambda x: x[bank_idx], self._const_bank) + # Snapshots save the ``max_steps`` they were generated against + # (e.g. 500), but the env's caller may want a different episode + # length via ``num_steps``. Override here so ``done = state.time + # >= max_steps`` honours the env's configuration rather than the + # snapshot's default. + const = const.replace(max_steps=jnp.int32(self.num_steps)) + + # Phase 6 P3 — crown-jewel rotation. Replace ``const.phase_rewards`` + # with a sampled bank entry BEFORE the mission-bank multiplier so + # both transformations compose: bank picks "which subnet is the + # crown jewel this episode," then multiplier scales the per-component + # weights. Skipped on fast path when no bank is configured. + if self._phase_rewards_bank is not None: + pr_idx = jax.random.randint(key_pr, (), 0, self._phase_rewards_bank_size) + const = const.replace(phase_rewards=self._phase_rewards_bank[pr_idx]) + + # Phase 6 axis B — post-multiply ``const.phase_rewards`` by a sampled + # ``(LWF, ASF, RIA)`` triple. Mirrors the diversity-branch approach + # (topology.py:540–555): keeps SimulatorState shape unchanged and + # leaves rewards.py untouched. Skipped on the fast path when no + # bank was configured. + if self._mission_bank is not None: + mp_idx = jax.random.randint(key_mission, (), 0, self._mission_bank_size) + mp_multipliers = self._mission_bank[mp_idx] # (3,) float32 + phase_rewards = const.phase_rewards * mp_multipliers[None, None, :] + const = const.replace(phase_rewards=phase_rewards) + + # Phase 6 P2 — phase-boundary jitter. Replace ``const.phase_boundaries`` + # with a sampled bank entry. Phase transitions, allow-list flips, + # and per-phase reward emphasis all reindex against the sampled split. + if self._phase_boundary_bank is not None: + pb_idx = jax.random.randint(key_pb, (), 0, self._phase_boundary_bank_size) + const = const.replace(phase_boundaries=self._phase_boundary_bank[pb_idx]) + + return const def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], ScenarioEnvState]: const = self._select_const(key) diff --git a/src/jaxborg/evaluation/jax_env_factory.py b/src/jaxborg/evaluation/jax_env_factory.py index b2a18c4..b216222 100644 --- a/src/jaxborg/evaluation/jax_env_factory.py +++ b/src/jaxborg/evaluation/jax_env_factory.py @@ -65,6 +65,10 @@ def make_jax_env( topology_mode: str = "generative", training_mode: bool = False, topology_path: str | Path | Sequence[str | Path] | None = None, + mission_bank: Sequence[Sequence[float]] | None = None, + mission_bank_amplify: float = 1.0, + phase_boundary_bank: Sequence[Sequence[int]] | None = None, + phase_rewards_bank: Sequence | None = None, name: str | None = None, ) -> FsmRedCC4Env: if variant.resilience_roles: @@ -79,5 +83,9 @@ def make_jax_env( red_selector=selector, extras_factory=extras, op_zone_min_servers=variant.op_zone_servers, + mission_bank=mission_bank, + mission_bank_amplify=mission_bank_amplify, + phase_boundary_bank=phase_boundary_bank, + phase_rewards_bank=phase_rewards_bank, name=name, ) diff --git a/src/jaxborg/parity/fsm_red_env.py b/src/jaxborg/parity/fsm_red_env.py index bb91345..e7f61b6 100644 --- a/src/jaxborg/parity/fsm_red_env.py +++ b/src/jaxborg/parity/fsm_red_env.py @@ -95,6 +95,10 @@ def __init__( red_selector: RedSelector = fsm_selector, extras_factory: ExtrasFactory = _empty_extras_factory, op_zone_min_servers: int | None = None, + mission_bank: Sequence[Sequence[float]] | None = None, + mission_bank_amplify: float = 1.0, + phase_boundary_bank: Sequence[Sequence[int]] | None = None, + phase_rewards_bank: Sequence | None = None, name: Optional[str] = None, ): self._env = ScenarioEnv( @@ -103,6 +107,10 @@ def __init__( training_mode=training_mode, topology_path=topology_path, op_zone_min_servers=op_zone_min_servers, + mission_bank=mission_bank, + mission_bank_amplify=mission_bank_amplify, + phase_boundary_bank=phase_boundary_bank, + phase_rewards_bank=phase_rewards_bank, ) self._red_selector = red_selector self._extras_factory = extras_factory diff --git a/src/jaxborg/recipe.py b/src/jaxborg/recipe.py index 7f3b344..493ba5f 100644 --- a/src/jaxborg/recipe.py +++ b/src/jaxborg/recipe.py @@ -27,7 +27,7 @@ import yaml from jaxborg.scenarios.cc4.game_variant import GameVariant -from jaxborg.scenarios.cc4.game_variants import VARIANTS +from jaxborg.scenarios.cc4.game_variants import VARIANTS, variant_for_red REPO_ROOT = Path(__file__).resolve().parents[2] RECIPES_DIR = REPO_ROOT / "recipes" @@ -68,9 +68,26 @@ def train_variant(recipe: dict[str, Any]) -> GameVariant: def eval_variant(recipe: dict[str, Any]) -> GameVariant: + """Resolve the eval-time GameVariant. + + Precedence: + 1. ``eval.red`` (if set) — overrides the variant's red selector. The + base variant (``eval.variant`` or ``train.variant``) is used only to + decide ``resilience_roles`` for the fsm path; CIA-biased reds + (``cia_c`` / ``cia_i`` / ``cia_a`` / ``resilience``) carry their + own resilience_roles=True since their selectors require role tags. + This means setting ``eval.red: cia_a`` on a ``cc4_stock`` recipe + forces ``resilience_roles=True`` to keep the selector consistent. + 2. ``eval.variant`` — full variant name in ``VARIANTS``. + 3. ``train.variant`` — fallback if no eval section is configured. + """ eval_cfg = recipe.get("eval") or {} - name = eval_cfg.get("variant") or recipe.get("train", {}).get("variant", "cc4_stock") - return VARIANTS[name] + base_name = eval_cfg.get("variant") or recipe.get("train", {}).get("variant", "cc4_stock") + base = VARIANTS[base_name] + red = eval_cfg.get("red") + if red is None: + return base + return variant_for_red(red, resilience_roles=base.resilience_roles) def resolve_eval_variant( @@ -98,6 +115,78 @@ def resolve_eval_variant( return default if default is not None else CC4_STOCK +def _project_mission_bank(train: dict[str, Any]) -> list[list[float]] | None: + """Project ``train.mission_bank`` to a list of (LWF, ASF, RIA) triples. + + Returns ``None`` when the recipe omits ``mission_bank`` or supplies an + empty list — both signal "no per-reset variation" (legacy behavior). + """ + bank = train.get("mission_bank") + if bank is None: + return None + out: list[list[float]] = [] + for entry in bank: + triple = list(entry) + if len(triple) != 3: + raise ValueError(f"train.mission_bank entries must be (LWF, ASF, RIA) triples; got {entry!r}") + out.append([float(x) for x in triple]) + if not out: + return None + return out + + +def _resolve_topology_bank(train: dict[str, Any]) -> tuple[Path, ...]: + """Resolve ``train.topology_bank`` paths against the repo root.""" + bank = train.get("topology_bank") or () + if isinstance(bank, (str, Path)): + bank = [bank] + resolved: list[Path] = [] + for entry in bank: + p = Path(entry) + if not p.is_absolute(): + p = REPO_ROOT / p + resolved.append(p) + return tuple(resolved) + + +def _project_phase_boundary_bank(train: dict[str, Any]) -> list[list[int]] | None: + """Project ``train.phase_boundary_bank`` to a list of int triples.""" + bank = train.get("phase_boundary_bank") + if bank is None: + return None + out: list[list[int]] = [] + for entry in bank: + triple = list(entry) + if len(triple) != 3: + raise ValueError( + "train.phase_boundary_bank entries must be 3-tuples " + f"(phase0_start, phase1_start, phase2_start); got {entry!r}" + ) + out.append([int(x) for x in triple]) + if not out: + return None + return out + + +def _project_phase_rewards_bank(train: dict[str, Any]): + """Resolve ``train.phase_rewards_bank`` to a stacked numpy bank or None. + + Recipe accepts either a bool (``true`` → use the canonical 6-entry + crown-jewel rotation bank from topology_numpy) or a list of explicit + ``(MISSION_PHASES, NUM_SUBNETS, 3)`` arrays for custom banks. Most + callers want the bool form. + """ + val = train.get("phase_rewards_bank") + if val is None or val is False: + return None + if val is True: + from jaxborg.scenarios.cc4.topology_numpy import get_phase_rewards_bank + + return get_phase_rewards_bank() + # Explicit list of arrays — pass through (caller stacks). + return val + + def project_jax(recipe: dict[str, Any]) -> dict[str, Any]: """Flatten recipe into the dict shape ippo_jax.py's config expects.""" core = recipe["core"] @@ -131,6 +220,11 @@ def project_jax(recipe: dict[str, Any]) -> dict[str, Any]: "EVAL_VARIANT": eval_variant(recipe), "TRAINING_MODE": True, "MLFLOW_ENABLED": True, + "MISSION_BANK": _project_mission_bank(train), + "MISSION_BANK_AMPLIFY": float(train.get("mission_bank_amplify", 1.0)), + "TOPOLOGY_BANK": _resolve_topology_bank(train), + "PHASE_BOUNDARY_BANK": _project_phase_boundary_bank(train), + "PHASE_REWARDS_BANK": _project_phase_rewards_bank(train), } diff --git a/src/jaxborg/scenarios/cc4/topology.py b/src/jaxborg/scenarios/cc4/topology.py index 52209dd..14875ec 100644 --- a/src/jaxborg/scenarios/cc4/topology.py +++ b/src/jaxborg/scenarios/cc4/topology.py @@ -224,7 +224,7 @@ def build_topology( num_steps: int = 500, *, training_mode: bool = False, - op_zone_min_servers: int | None = None, + op_zone_min_servers: int | tuple[int, int] | None = None, ) -> SimulatorConst: """Build CC4 topology in pure JAX — JIT-compatible. @@ -232,6 +232,13 @@ def build_topology( 1 router + random server hosts (1-6) + random user hosts (3-10). Internet subnet gets 1 host (root_internet_host_0). + ``op_zone_min_servers`` controls the operational-zone server floor. + Pass an int to force both OPS-A and OPS-B to the same value (legacy + behavior). Pass a 2-tuple ``(a_floor, b_floor)`` to set them + independently — used by the topology bank builder to produce totals + that aren't multiples of 2 (e.g. 3 = 1+2 for balanced AUTH/DB/WEB + role assignment across three resilience candidates). + Host indices follow alphabetical hostname ordering (same as build_const_from_cyborg): subnets ordered by CYBORG_SUBNET_SUFFIX, within each subnet: router < servers < users. """ @@ -241,8 +248,15 @@ def build_topology( n_users = jax.random.randint(k_users, (8,), 3, 11) random_n = jax.random.randint(k_servers, (8,), 1, 7) if op_zone_min_servers is not None: - op_zone_alpha = jnp.array([False, False, False, True, True, False, False, False]) - n_servers = jnp.where(op_zone_alpha, jnp.int32(op_zone_min_servers), random_n) + if isinstance(op_zone_min_servers, tuple): + a_floor, b_floor = op_zone_min_servers + else: + a_floor = b_floor = int(op_zone_min_servers) + # alpha-order positions 3 and 4 are OPERATIONAL_ZONE_A / OPERATIONAL_ZONE_B. + floor_per_alpha = jnp.array( + [-1, -1, -1, int(a_floor), int(b_floor), -1, -1, -1], dtype=jnp.int32 + ) + n_servers = jnp.where(floor_per_alpha >= 0, floor_per_alpha, random_n) else: n_servers = random_n diff --git a/src/jaxborg/scenarios/cc4/topology_numpy.py b/src/jaxborg/scenarios/cc4/topology_numpy.py index 5775f29..30a4f17 100644 --- a/src/jaxborg/scenarios/cc4/topology_numpy.py +++ b/src/jaxborg/scenarios/cc4/topology_numpy.py @@ -19,6 +19,147 @@ SUBNET_NAMES, ) +# Phase 6 axis B (mission-profile multiplier bank). Per ``env.reset``, sample +# one ``(LWF, ASF, RIA)`` triple from this bank to scale ``const.phase_rewards``. +# Bank[0] is the default ``(1, 1, 1)`` — when the bank is reduced to that single +# entry (or disabled), behavior matches legacy CC4 exactly. +# +# The default 4-entry bank uses 3× amplification on one CIA component at a time +# (Phase 6 plan §"Axis B"), with a 10× fallback exposed via the +# ``mission_bank_amplify`` recipe knob. ``mission_bank_amplify`` multiplies the +# *entire* sampled triple element-wise — applied after sampling, so amplify=10 +# with a (1, 3, 1) entry yields (10, 30, 10), not (1, 30, 1). This keeps the +# implementation simple (no special-casing of the off-axis 1.0 entries). +# +# Order: ``LWF=0, ASF=1, RIA=2`` per ``src/jaxborg/rewards.py``. +MISSION_PROFILE_MULTIPLIERS: tuple[tuple[float, float, float], ...] = ( + # (LWF, ASF, RIA) + (1.0, 1.0, 1.0), # default — balanced + (3.0, 1.0, 1.0), # productivity-heavy: amplify LWF + (1.0, 3.0, 1.0), # availability-heavy: amplify ASF + (1.0, 1.0, 3.0), # CI-heavy: amplify RIA +) +NUM_MISSION_PROFILES = len(MISSION_PROFILE_MULTIPLIERS) + +# Anti-correlated bank: each non-baseline entry boosts TWO components, never +# just one — so a "boost the loud component" memorization fails because every +# component is sometimes loud and sometimes quiet. This is the answer to the +# Phase 6 Test 1 critique that Axis B's σ-ratio PASS was partly mechanical +# scaling: anti-correlated profiles can't be solved by scaling-up one channel. +MISSION_PROFILE_ANTI_CORR: tuple[tuple[float, float, float], ...] = ( + (1.0, 1.0, 1.0), # baseline so legacy default behavior is reachable + (3.0, 3.0, 1.0), # boost LWF + ASF, dampen relative weight of RIA + (1.0, 3.0, 3.0), # boost ASF + RIA + (3.0, 1.0, 3.0), # boost LWF + RIA +) + + +def get_mission_profile_multipliers() -> np.ndarray: + """(NUM_MISSION_PROFILES, 3) float32 multipliers in (LWF, ASF, RIA) order.""" + return np.asarray(MISSION_PROFILE_MULTIPLIERS, dtype=np.float32) + + +def get_mission_profile_anti_corr() -> np.ndarray: + """(4, 3) float32 anti-correlated multipliers in (LWF, ASF, RIA) order.""" + return np.asarray(MISSION_PROFILE_ANTI_CORR, dtype=np.float32) + + +# Phase-boundary jitter bank. Each entry is a 3-tuple of step indices +# ``(phase0_start, phase1_start, phase2_start)``; phase0 always starts at 0. +# Assumes 500-step episodes (the canonical CC4 episode length); shorter banks +# scale linearly. Boundaries control when the allow-list flips and when +# ``phase_rewards`` switches its emphasis between OPS-A (phase 1) and OPS-B +# (phase 2), so jittering these breaks "deploy decoys at step 167" memorization. +PHASE_BOUNDARIES_BANK: tuple[tuple[int, int, int], ...] = ( + (0, 167, 333), # canonical CC4 split (3 ~equal phases) + (0, 100, 300), # short setup, balanced mid+late + (0, 200, 400), # long setup, short late + (0, 150, 250), # short mid-phase, late starts at 250 +) + + +def get_phase_boundaries_bank() -> np.ndarray: + """(N, 3) int32 phase-boundary triples for 500-step episodes.""" + return np.asarray(PHASE_BOUNDARIES_BANK, dtype=np.int32) + + +def _build_phase_rewards_bank() -> np.ndarray: + """Build the crown-jewel rotation bank. + + Each entry is a ``(MISSION_PHASES, NUM_SUBNETS, 3)`` phase_rewards array; + bank[0] is the canonical CC4 table (matches ``_build_phase_rewards`` + exactly), so a recipe with ``phase_rewards_bank: true`` and a 1-entry bank + reproduces legacy behavior. Remaining entries permute *which subnet is + high-value in which phase*: the canonical table emphasizes OPS_A in + phase 1 and OPS_B in phase 2; bank entries 1+ rotate that across other + operational/administrative subnets so the same physical topology + generates different reward gradients per episode. + + The policy must read which-subnet-is-which from the observation rather + than memorizing "phase 1 → focus on subnet index 3" — directly addresses + the Phase 6 Test 1 finding that Axis A's σ-ratio was null because subnet + *labels* were stable across the topology bank. + """ + canonical = _build_phase_rewards() + S = SUBNET_IDS + OA, OB = S["OPERATIONAL_ZONE_A"], S["OPERATIONAL_ZONE_B"] + RA, RB = S["RESTRICTED_ZONE_A"], S["RESTRICTED_ZONE_B"] + ADMIN, OFFICE = S["ADMIN_NETWORK"], S["OFFICE_NETWORK"] + + bank = [canonical] + + # Entry 1: swap OPS_A ↔ OPS_B in phases 1 and 2 (and their RZ pairs). + # Phase 1 now emphasizes OPS_B; phase 2 emphasizes OPS_A. The "primary + # mission target" rotates between episodes. + swap_AB = canonical.copy() + swap_AB[1, OA] = canonical[1, OB] + swap_AB[1, OB] = canonical[1, OA] + swap_AB[1, RA] = canonical[1, RB] + swap_AB[1, RB] = canonical[1, RA] + swap_AB[2, OA] = canonical[2, OB] + swap_AB[2, OB] = canonical[2, OA] + swap_AB[2, RA] = canonical[2, RB] + swap_AB[2, RB] = canonical[2, RA] + bank.append(swap_AB) + + # Entry 2: emphasize ADMIN_NETWORK as a phase-1 priority (analyst console + # network — same shape topology, different "what blue is told to protect"). + admin_priority = canonical.copy() + # Boost ADMIN per-component weight in phase 1 to match the OPS_A intensity. + admin_priority[1, ADMIN] = np.array([-5, -2, -5], dtype=np.float32) + bank.append(admin_priority) + + # Entry 3: emphasize OFFICE_NETWORK in phase 2 (insider-threat scenario). + office_priority = canonical.copy() + office_priority[2, OFFICE] = np.array([-5, -2, -5], dtype=np.float32) + bank.append(office_priority) + + # Entry 4: phase 1 protects BOTH OPS_A and OPS_B simultaneously + # (heightened-alert scenario; no rotation, just intensity in phase 1). + both_ops = canonical.copy() + both_ops[1, OB] = canonical[1, OA] # OPS_B gets the OPS_A treatment too + bank.append(both_ops) + + # Entry 5: full rotation — phase 1 emphasizes OPS_B + ADMIN; phase 2 + # emphasizes OPS_A + OFFICE. Tests whether the policy can adapt to a + # full reframing of mission priority structure. + full_rotate = canonical.copy() + full_rotate[1, OA] = canonical[1, OB] + full_rotate[1, OB] = canonical[1, OA] + full_rotate[1, ADMIN] = np.array([-5, -2, -5], dtype=np.float32) + full_rotate[2, OA] = canonical[2, OB] + full_rotate[2, OB] = canonical[2, OA] + full_rotate[2, OFFICE] = np.array([-5, -2, -5], dtype=np.float32) + bank.append(full_rotate) + + return np.stack(bank, axis=0) + + +def get_phase_rewards_bank() -> np.ndarray: + """(N, MISSION_PHASES, NUM_SUBNETS, 3) float32 crown-jewel rotation bank.""" + return _build_phase_rewards_bank() + + _ROUTER_LINKS = { "INTERNET": [ "RESTRICTED_ZONE_A", diff --git a/tests/test_eval_red_override.py b/tests/test_eval_red_override.py new file mode 100644 index 0000000..fd78802 --- /dev/null +++ b/tests/test_eval_red_override.py @@ -0,0 +1,219 @@ +"""Tests for the ``eval.red`` override hook (Phase 6, Stream S3). + +Covers: +- ``eval.red`` override in :func:`jaxborg.recipe.eval_variant`. +- Resilience-role coupling: setting ``eval.red: cia_a`` on a ``cc4_stock`` + recipe must force ``resilience_roles=True`` because the cia_a selector + requires role tags. +- Falls back to ``eval.variant`` when ``eval.red`` is null/missing. +- ``--eval-red`` CLI flag on ``scripts/eval/eval_recipe.py`` overrides + whatever the loaded recipe says (CLI > recipe). +""" + +from __future__ import annotations + +import sys +import types +from pathlib import Path + +from jaxborg.recipe import eval_variant + + +def _make_recipe(*, eval_section: dict | None = None, train_variant: str = "cc4_stock") -> dict: + """Build a minimal recipe dict that satisfies eval_variant.""" + recipe: dict = { + "meta": {"name": "test"}, + "algorithm": "ippo", + "core": {"lr": 3e-4, "gamma": 0.99, "gae_lambda": 0.95}, + "arch": {"name": "mlp"}, + "train": { + "variant": train_variant, + "episode_length": 500, + "total_timesteps": 1000, + }, + } + if eval_section is not None: + recipe["eval"] = eval_section + return recipe + + +def test_eval_red_cia_c_with_cia_resilience_variant(): + """Override to cia_c on a cia_resilience base preserves resilience_roles.""" + recipe = _make_recipe(eval_section={"variant": "cia_resilience", "red": "cia_c"}) + v = eval_variant(recipe) + assert v.red_agent == "c" + assert v.resilience_roles is True + + +def test_eval_red_fsm_on_cc4_stock_drops_resilience_roles(): + """Override to fsm on a cc4_stock base keeps resilience_roles off.""" + recipe = _make_recipe(eval_section={"variant": "cc4_stock", "red": "fsm"}) + v = eval_variant(recipe) + assert v.red_agent == "finite_state" + assert v.resilience_roles is False + + +def test_eval_red_cia_a_on_cc4_stock_forces_resilience_roles(): + """Override to cia_a on a cc4_stock base must force resilience_roles=True. + + The cia_a selector requires per-host role assignments (AUTH/DB/WEB), so + the variant_for_red("cia_a", ...) helper always returns CIA_A which has + resilience_roles=True regardless of the base. This is the deliberate + corner-case behavior documented on eval_variant(). + """ + recipe = _make_recipe(eval_section={"variant": "cc4_stock", "red": "cia_a"}) + v = eval_variant(recipe) + assert v.red_agent == "a" + assert v.resilience_roles is True + + +def test_eval_red_missing_falls_back_to_variant(): + """No eval.red → existing eval.variant resolution (regression).""" + recipe = _make_recipe(eval_section={"variant": "cia_resilience"}) + v = eval_variant(recipe) + assert v.name == "cia_resilience" + assert v.red_agent == "resilience" + assert v.resilience_roles is True + + +def test_eval_red_null_falls_back_to_variant(): + """Explicit null eval.red is treated as 'unset'.""" + recipe = _make_recipe(eval_section={"variant": "cc4_stock", "red": None}) + v = eval_variant(recipe) + assert v.name == "cc4_stock" + assert v.red_agent == "finite_state" + + +def test_eval_section_missing_falls_back_to_train_variant(): + """No eval section → train.variant drives the eval variant.""" + recipe = _make_recipe(eval_section=None, train_variant="cia_resilience") + v = eval_variant(recipe) + assert v.name == "cia_resilience" + + +def test_eval_red_resilience_keeps_resilience_roles(): + """Override to 'resilience' returns CIA_RESILIENCE.""" + recipe = _make_recipe(eval_section={"variant": "cc4_stock", "red": "resilience"}) + v = eval_variant(recipe) + assert v.red_agent == "resilience" + assert v.resilience_roles is True + + +def test_eval_red_sleep_on_cc4_stock(): + """sleep red is a thin variant on top of cc4_stock.""" + recipe = _make_recipe(eval_section={"variant": "cc4_stock", "red": "sleep"}) + v = eval_variant(recipe) + assert v.red_agent == "sleep" + assert v.resilience_roles is False + + +def test_eval_red_cli_flag_overrides_recipe(monkeypatch, tmp_path): + """--eval-red CLI flag must override whatever the loaded recipe sidecar says. + + We stub out the heavy CybORG/JAX runner imports and capture the variant + that ``eval_recipe.main`` actually resolves, exercising the CLI > recipe + precedence end to end. + """ + repo_root = Path(__file__).resolve().parents[1] + script_path = repo_root / "scripts" / "eval" / "eval_recipe.py" + spec_root = str(repo_root / "scripts" / "eval") + if spec_root not in sys.path: + sys.path.insert(0, spec_root) + + # Recipe sidecar says cia_c; CLI says cia_a. CLI must win. + sidecar_recipe = _make_recipe(eval_section={"variant": "cia_resilience", "red": "cia_c"}) + sidecar_recipe["meta"]["name"] = "test_cli_override" + + captured: dict = {} + + def fake_evaluate_on_cyborg(model_path, *, variant, seeds, episodes_per_seed, deterministic, workers): + captured["variant"] = variant + return ([0.0], [0]) + + fake_runner = types.ModuleType("jaxborg.evaluation.cyborg_runner") + fake_runner.evaluate_on_cyborg = fake_evaluate_on_cyborg + + fake_model = tmp_path / "model.pt" + fake_model.write_bytes(b"") + + # Patch read_sidecar to return our test recipe. + monkeypatch.setattr("jaxborg.checkpoint.read_sidecar", lambda p: dict(sidecar_recipe)) + # Stub the CybORG runner module entirely so import doesn't pull CybORG. + monkeypatch.setitem(sys.modules, "jaxborg.evaluation.cyborg_runner", fake_runner) + # Skip MLflow attach. + monkeypatch.setattr("jaxborg.mlflow_setup.attach_eval_metrics", lambda *a, **kw: None) + + # Run the CLI with --eval-red cia_a. + argv = [ + "eval_recipe.py", + "--model", + str(fake_model), + "--episodes", + "1", + "--seeds", + "0", + "--eval-red", + "cia_a", + "--output", + str(tmp_path / "out.jsonl"), + ] + monkeypatch.setattr(sys, "argv", argv) + + # Load the script as a module so `main()` is callable. + import importlib.util + + spec = importlib.util.spec_from_file_location("eval_recipe_under_test", script_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.main() + + variant = captured["variant"] + assert variant.red_agent == "a", f"expected red_agent='a' from CLI override, got {variant.red_agent}" + assert variant.resilience_roles is True + + +def test_eval_red_cli_flag_unset_uses_recipe_red(monkeypatch, tmp_path): + """When --eval-red is not passed, recipe eval.red still applies.""" + repo_root = Path(__file__).resolve().parents[1] + script_path = repo_root / "scripts" / "eval" / "eval_recipe.py" + + sidecar_recipe = _make_recipe(eval_section={"variant": "cia_resilience", "red": "cia_i"}) + sidecar_recipe["meta"]["name"] = "test_no_cli" + + captured: dict = {} + + def fake_evaluate_on_cyborg(model_path, *, variant, seeds, episodes_per_seed, deterministic, workers): + captured["variant"] = variant + return ([0.0], [0]) + + fake_runner = types.ModuleType("jaxborg.evaluation.cyborg_runner") + fake_runner.evaluate_on_cyborg = fake_evaluate_on_cyborg + + fake_model = tmp_path / "model.pt" + fake_model.write_bytes(b"") + + monkeypatch.setattr("jaxborg.checkpoint.read_sidecar", lambda p: dict(sidecar_recipe)) + monkeypatch.setitem(sys.modules, "jaxborg.evaluation.cyborg_runner", fake_runner) + monkeypatch.setattr("jaxborg.mlflow_setup.attach_eval_metrics", lambda *a, **kw: None) + + argv = [ + "eval_recipe.py", + "--model", + str(fake_model), + "--episodes", + "1", + "--seeds", + "0", + "--output", + str(tmp_path / "out.jsonl"), + ] + monkeypatch.setattr(sys, "argv", argv) + + import importlib.util + + spec = importlib.util.spec_from_file_location("eval_recipe_under_test_2", script_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.main() + + assert captured["variant"].red_agent == "i" diff --git a/tests/test_mission_bank_sampling.py b/tests/test_mission_bank_sampling.py new file mode 100644 index 0000000..a505004 --- /dev/null +++ b/tests/test_mission_bank_sampling.py @@ -0,0 +1,307 @@ +"""Phase 6 stream S2 — per-reset mission-profile multiplier bank. + +Tests the ``mission_bank`` / ``mission_bank_amplify`` plumbing on +:class:`jaxborg.env.ScenarioEnv`. Per the Phase 6 plan (axis B) the bank is +sampled at reset and post-multiplies ``const.phase_rewards`` — there is no +``state.mission_multipliers`` field, and ``rewards.py`` is unchanged. + +Asserts: + +1. **Determinism** — same PRNG key → same multiplier triple sampled. +2. **Uniformity** — across 10000 keys each of the 4 default-bank entries is + sampled within 5% of uniform (chi-square at α=0.01). +3. **Single-entry bank applies the triple** — with ``mission_bank=[(1, 3, 1)]`` + the resulting ``const.phase_rewards`` equals 3× the baseline ASF channel + exactly (rewards.py would then produce 3× ASF reward at any step where ASF + triggers). +4. **Amplify scales the entire triple** — ``mission_bank_amplify=10.0`` with + bank ``[(1, 3, 1)]`` yields a (10, 30, 10) effective multiplier, not + (1, 30, 1). Documented this way for simplicity (no special-casing). +5. **Empty / None bank is the legacy fast path** — ``const.phase_rewards`` + is byte-identical to the baseline (no bank configured). +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jaxborg.env import ScenarioEnv +from jaxborg.scenarios.cc4.topology_numpy import ( + MISSION_PROFILE_MULTIPLIERS, + NUM_MISSION_PROFILES, + get_mission_profile_multipliers, +) + +# Use a short num_steps so reset/topology build is cheap. +NUM_STEPS = 10 + + +def _make_env(mission_bank=None, mission_bank_amplify=1.0): + return ScenarioEnv( + num_steps=NUM_STEPS, + mission_bank=mission_bank, + mission_bank_amplify=mission_bank_amplify, + ) + + +def _sample_const(env, key): + return env._select_const(key) + + +# --------------------------------------------------------------------------- +# Default bank metadata + + +class TestMissionProfileTable: + def test_default_bank_has_four_entries(self): + assert NUM_MISSION_PROFILES == 4 + assert len(MISSION_PROFILE_MULTIPLIERS) == 4 + + def test_default_bank_first_entry_is_unit(self): + assert MISSION_PROFILE_MULTIPLIERS[0] == (1.0, 1.0, 1.0) + + def test_get_mission_profile_multipliers_shape(self): + arr = get_mission_profile_multipliers() + assert arr.shape == (NUM_MISSION_PROFILES, 3) + assert arr.dtype == np.float32 + np.testing.assert_array_equal(arr[0], np.array([1.0, 1.0, 1.0], dtype=np.float32)) + + +# --------------------------------------------------------------------------- +# Determinism + uniformity of the sampled index + + +class TestSamplingDeterminism: + def test_same_key_same_multiplier(self): + bank = [[1.0, 1.0, 1.0], [3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]] + env = _make_env(mission_bank=bank) + key = jax.random.PRNGKey(7) + c1 = _sample_const(env, key) + c2 = _sample_const(env, key) + np.testing.assert_array_equal(np.asarray(c1.phase_rewards), np.asarray(c2.phase_rewards)) + + def test_different_keys_can_yield_different_multipliers(self): + bank = [[1.0, 1.0, 1.0], [3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]] + env = _make_env(mission_bank=bank) + + # Build a baseline (no-bank) const so we can recover the per-key + # multiplier triple by dividing through the LWF/ASF/RIA channels at any + # nonzero (phase, subnet) cell. + env_baseline = _make_env(mission_bank=None) + + seen = set() + for s in range(20): + key = jax.random.PRNGKey(s) + c_bank = _sample_const(env, key) + c_base = _sample_const(env_baseline, key) + triple = _recover_multiplier(c_base.phase_rewards, c_bank.phase_rewards) + seen.add(tuple(round(float(x), 4) for x in triple)) + # With 4 entries × 20 draws we expect to hit at least 2 distinct entries. + assert len(seen) >= 2 + + +def _recover_multiplier(baseline_pr: jnp.ndarray, bank_pr: jnp.ndarray) -> tuple[float, float, float]: + """Recover the (LWF, ASF, RIA) triple applied to baseline_pr. + + For each component we pick a (phase, subnet) cell where the baseline weight + is nonzero, then ratio. Returns floats. + """ + base = np.asarray(baseline_pr) + bank = np.asarray(bank_pr) + triple = [] + for c in range(3): + base_c = base[..., c] + bank_c = bank[..., c] + nz = base_c != 0.0 + if not nz.any(): + triple.append(1.0) + continue + ratios = bank_c[nz] / base_c[nz] + # All ratios for a given component must be identical (we multiply by + # a scalar) — assert that for safety, then take the first. + assert np.allclose(ratios, ratios[0], atol=1e-5), f"non-uniform ratio for component {c}: {ratios}" + triple.append(float(ratios[0])) + return tuple(triple) # type: ignore[return-value] + + +class TestSamplingUniformity: + """Across 10k keys each of 4 default-bank entries sampled within 5% of uniform.""" + + def test_uniform_chisquare(self): + bank = [[1.0, 1.0, 1.0], [3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]] + # Build env to exercise the construction path (validates bank shape). + _make_env(mission_bank=bank) + # Sampling the *index* directly is much cheaper than building the whole + # const for each key; mirror the env's exact per-key derivation. + n = 10_000 + + @jax.jit + def sample_idx(key): + _, key_mission = jax.random.split(key) + return jax.random.randint(key_mission, (), 0, len(bank)) + + keys = jax.random.split(jax.random.PRNGKey(0), n) + idxs = jax.vmap(sample_idx)(keys) + idxs_np = np.asarray(idxs) + counts = np.bincount(idxs_np, minlength=len(bank)) + expected = n / len(bank) + + # 5% relative tolerance per entry. + rel_err = np.abs(counts - expected) / expected + assert (rel_err < 0.05).all(), f"counts={counts.tolist()} expected={expected}, rel_err={rel_err.tolist()}" + + # Chi-square at α=0.01 — 3 dof, critical value = 11.345. + chi2 = float(((counts - expected) ** 2 / expected).sum()) + assert chi2 < 11.345, f"chi2={chi2} ≥ 11.345 (4-bin uniform, α=0.01)" + + +# --------------------------------------------------------------------------- +# phase_rewards multiplication semantics + + +class TestPhaseRewardsScaling: + """Verify that the bank entry actually scales const.phase_rewards. + + rewards.py reads ``const.phase_rewards[phase, subnet, channel]`` directly, + so scaling a channel here is equivalent to scaling that channel's per-step + contribution to the total reward (proportional to the number of triggered + hosts on that subnet, but the *per-host weight* scales linearly). + """ + + def test_single_entry_bank_scales_asf_channel(self): + """bank=[(1, 3, 1)] → ASF channel exactly 3× baseline; LWF & RIA unchanged.""" + env_base = _make_env(mission_bank=None) + env_bank = _make_env(mission_bank=[[1.0, 3.0, 1.0]]) + + key = jax.random.PRNGKey(123) + c_base = _sample_const(env_base, key) + c_bank = _sample_const(env_bank, key) + + base_pr = np.asarray(c_base.phase_rewards) + bank_pr = np.asarray(c_bank.phase_rewards) + + np.testing.assert_allclose(bank_pr[..., 0], base_pr[..., 0] * 1.0, rtol=0, atol=1e-6) + np.testing.assert_allclose(bank_pr[..., 1], base_pr[..., 1] * 3.0, rtol=0, atol=1e-6) + np.testing.assert_allclose(bank_pr[..., 2], base_pr[..., 2] * 1.0, rtol=0, atol=1e-6) + + def test_amplify_multiplies_entire_triple(self): + """``mission_bank_amplify=10`` × bank ``[(1, 3, 1)]`` → effective (10, 30, 10). + + Documented behavior: amplify scales the entire sampled triple element-wise. + amplify=10 with (1, 3, 1) yields (10, 30, 10), NOT (1, 30, 1). + """ + env_base = _make_env(mission_bank=None) + env_bank = _make_env(mission_bank=[[1.0, 3.0, 1.0]], mission_bank_amplify=10.0) + + key = jax.random.PRNGKey(99) + c_base = _sample_const(env_base, key) + c_bank = _sample_const(env_bank, key) + + base_pr = np.asarray(c_base.phase_rewards) + bank_pr = np.asarray(c_bank.phase_rewards) + + np.testing.assert_allclose(bank_pr[..., 0], base_pr[..., 0] * 10.0, rtol=0, atol=1e-5) + np.testing.assert_allclose(bank_pr[..., 1], base_pr[..., 1] * 30.0, rtol=0, atol=1e-5) + np.testing.assert_allclose(bank_pr[..., 2], base_pr[..., 2] * 10.0, rtol=0, atol=1e-5) + + +class TestEmptyBankIsLegacyFastPath: + """When ``mission_bank`` is None or empty, ``const.phase_rewards`` is unchanged.""" + + @pytest.mark.parametrize("bank", [None, []]) + def test_no_bank_no_change(self, bank): + env_base = _make_env(mission_bank=None) + env_other = _make_env(mission_bank=bank) + # Internal flags should reflect "disabled." + assert env_other._mission_bank is None + assert env_other._mission_bank_size == 0 + + key = jax.random.PRNGKey(2026) + c_base = _sample_const(env_base, key) + c_other = _sample_const(env_other, key) + + np.testing.assert_array_equal(np.asarray(c_base.phase_rewards), np.asarray(c_other.phase_rewards)) + + def test_amplify_alone_without_bank_is_noop(self): + """``mission_bank=None`` ignores ``mission_bank_amplify`` (no scaling).""" + env_base = _make_env(mission_bank=None) + env_amp = _make_env(mission_bank=None, mission_bank_amplify=10.0) + + key = jax.random.PRNGKey(2027) + c_base = _sample_const(env_base, key) + c_amp = _sample_const(env_amp, key) + + np.testing.assert_array_equal(np.asarray(c_base.phase_rewards), np.asarray(c_amp.phase_rewards)) + + +# --------------------------------------------------------------------------- +# Construction validation + + +class TestBankShapeValidation: + def test_rejects_non_triple_entries(self): + with pytest.raises(ValueError, match="3-tuples"): + _make_env(mission_bank=[[1.0, 1.0]]) + + def test_accepts_default_4_entry_bank(self): + bank = [list(t) for t in MISSION_PROFILE_MULTIPLIERS] + env = _make_env(mission_bank=bank) + assert env._mission_bank_size == 4 + assert env._mission_bank.shape == (4, 3) + + +# --------------------------------------------------------------------------- +# Recipe projection wiring + + +class TestRecipeProjection: + """``project_jax`` reads ``train.mission_bank`` and ``train.mission_bank_amplify``.""" + + def _recipe(self, **train_extra): + train = { + "variant": "cc4_stock", + "episode_length": 10, + "total_timesteps": 100, + } + train.update(train_extra) + return { + "meta": {"name": "test"}, + "algorithm": {"name": "ippo_jax"}, + "core": {"lr": 3e-4, "gamma": 0.99, "gae_lambda": 0.95}, + "arch": {"name": "mlp"}, + "train": train, + } + + def test_default_no_bank(self): + from jaxborg.recipe import project_jax + + cfg = project_jax(self._recipe()) + assert cfg["MISSION_BANK"] is None + assert cfg["MISSION_BANK_AMPLIFY"] == 1.0 + + def test_projects_bank_and_amplify(self): + from jaxborg.recipe import project_jax + + cfg = project_jax( + self._recipe( + mission_bank=[[1.0, 1.0, 1.0], [3.0, 1.0, 1.0]], + mission_bank_amplify=10.0, + ) + ) + assert cfg["MISSION_BANK"] == [[1.0, 1.0, 1.0], [3.0, 1.0, 1.0]] + assert cfg["MISSION_BANK_AMPLIFY"] == 10.0 + + def test_empty_bank_projects_to_none(self): + from jaxborg.recipe import project_jax + + cfg = project_jax(self._recipe(mission_bank=[])) + assert cfg["MISSION_BANK"] is None + + def test_bad_triple_rejected(self): + from jaxborg.recipe import project_jax + + with pytest.raises(ValueError, match="LWF, ASF, RIA"): + project_jax(self._recipe(mission_bank=[[1.0, 2.0]])) diff --git a/tests/test_phase6_recipes.py b/tests/test_phase6_recipes.py new file mode 100644 index 0000000..22d2209 --- /dev/null +++ b/tests/test_phase6_recipes.py @@ -0,0 +1,99 @@ +"""Phase 6 / S4 — sanity tests for the 2×2 factorial recipe matrix. + +Validates ``recipes/cec_phase6_{C00,C10,C01,C11}.yaml``: + +* every recipe loads and projects to a JAX config without raising; +* ``TOPOLOGY_BANK`` length is 0 (no bank) or 16 (full bank), per arm; +* every path in a configured ``TOPOLOGY_BANK`` exists on disk; +* ``MISSION_BANK`` is ``None`` (no bank) or matches the plan's 4-entry + default ``[(1,1,1), (3,1,1), (1,3,1), (1,1,3)]`` with amplify=1.0. + +The 2×2 factorial (per plans/jax/cc4/cec/cec-phase6-plan.md): + + | arm | topology bank | mission bank | + | --- | ------------- | ------------ | + | C00 | (none) | (none) | + | C10 | 16 snapshots | (none) | + | C01 | (none) | 4-entry | + | C11 | 16 snapshots | 4-entry | +""" + +from __future__ import annotations + +import pytest + +from jaxborg.recipe import load, project_jax + +EXPECTED_MISSION_BANK = [[1.0, 1.0, 1.0], [3.0, 3.0, 1.0], [1.0, 3.0, 3.0], [3.0, 1.0, 3.0]] +# Test 2 collapsed to just the control + full-cocktail arm: C00 (no banks) and +# C11 (topology + mission anti-corr + phase-boundary + phase-rewards rotation +# all active). The original 2×2 factorial arms C01/C10 were dropped after +# the σ-ratio policy-mediation finding made the per-axis ablations +# uninformative without training-time exposure. +ARMS = { + "cec_phase6_C00": (0, None), + "cec_phase6_C11": (16, EXPECTED_MISSION_BANK), +} + + +@pytest.fixture(scope="module", params=sorted(ARMS)) +def arm(request): + name = request.param + recipe = load(name) + cfg = project_jax(recipe) + return name, recipe, cfg + + +def test_recipe_loads(arm): + name, recipe, _cfg = arm + assert recipe["meta"]["name"] == name + # Plan citation is required so future readers can find the rationale. + blob = " ".join(str(recipe["meta"].get(k, "")) for k in ("source", "notes")) + assert "cec-phase6-plan.md" in blob + + +def test_train_variant_is_cc4_stock(arm): + _name, recipe, _cfg = arm + # CEC-faithful: training partner fixed (cc4_stock variant's red = fsm). + assert recipe["train"]["variant"] == "cc4_stock" + + +def test_total_timesteps_3m(arm): + _name, recipe, _cfg = arm + assert int(recipe["train"]["total_timesteps"]) == 3_000_000 + + +def test_topology_bank_length(arm): + name, _recipe, cfg = arm + expected_len, _ = ARMS[name] + assert len(cfg["TOPOLOGY_BANK"]) == expected_len, ( + f"{name}: expected TOPOLOGY_BANK length {expected_len}, got {len(cfg['TOPOLOGY_BANK'])}" + ) + + +def test_topology_bank_paths_exist(arm): + name, _recipe, cfg = arm + expected_len, _ = ARMS[name] + if expected_len == 0: + pytest.skip("no topology bank configured for this arm") + for p in cfg["TOPOLOGY_BANK"]: + assert p.exists(), f"{name}: topology snapshot missing on disk: {p}" + + +def test_mission_bank_matches(arm): + name, _recipe, cfg = arm + _, expected_mission = ARMS[name] + if expected_mission is None: + assert cfg["MISSION_BANK"] is None, f"{name}: expected no mission bank, got {cfg['MISSION_BANK']!r}" + else: + assert cfg["MISSION_BANK"] == expected_mission, f"{name}: mission bank mismatch (got {cfg['MISSION_BANK']!r})" + + +def test_mission_bank_amplify_is_unity_for_bank_arms(arm): + name, _recipe, cfg = arm + _, expected_mission = ARMS[name] + if expected_mission is None: + return + assert cfg["MISSION_BANK_AMPLIFY"] == 1.0, ( + f"{name}: expected mission_bank_amplify=1.0 (plan default), got {cfg['MISSION_BANK_AMPLIFY']}" + ) diff --git a/tests/test_phase_boundary_bank.py b/tests/test_phase_boundary_bank.py new file mode 100644 index 0000000..612bf90 --- /dev/null +++ b/tests/test_phase_boundary_bank.py @@ -0,0 +1,105 @@ +"""Tests for the phase-boundary jitter bank (Phase 6 / P2).""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest + +from jaxborg.evaluation.jax_env_factory import make_jax_env +from jaxborg.scenarios.cc4.game_variants import CC4_STOCK +from jaxborg.scenarios.cc4.topology_numpy import ( + PHASE_BOUNDARIES_BANK, + get_phase_boundaries_bank, +) + +DEFAULT_BANK = [list(t) for t in PHASE_BOUNDARIES_BANK] + + +def test_default_bank_metadata(): + bank = get_phase_boundaries_bank() + assert bank.shape == (4, 3) + assert bank.dtype == np.int32 + # phase 0 always starts at step 0 + assert (bank[:, 0] == 0).all() + # boundaries strictly increasing within each entry + for entry in bank: + assert entry[0] < entry[1] < entry[2] + + +def test_bank_overrides_const_phase_boundaries(): + """Sampling from a 1-entry bank produces that exact triple in const.""" + target = [(0, 50, 100)] + env = make_jax_env(CC4_STOCK, phase_boundary_bank=target) + _, state = env.reset(jax.random.PRNGKey(0)) + np.testing.assert_array_equal(np.asarray(state.const.phase_boundaries), [0, 50, 100]) + + +def test_no_bank_preserves_canonical_boundaries(): + env = make_jax_env(CC4_STOCK) + _, state = env.reset(jax.random.PRNGKey(0)) + pb = np.asarray(state.const.phase_boundaries) + # canonical CC4 split for 500-step episodes (3 ~equal phases) + assert pb[0] == 0 + assert pb[1] > 0 and pb[2] > pb[1] + + +def test_determinism_same_key_same_index(): + env = make_jax_env(CC4_STOCK, phase_boundary_bank=DEFAULT_BANK) + key = jax.random.PRNGKey(7) + _, state_a = env.reset(key) + _, state_b = env.reset(key) + np.testing.assert_array_equal( + np.asarray(state_a.const.phase_boundaries), + np.asarray(state_b.const.phase_boundaries), + ) + + +def test_distribution_uniform_chi_square(): + """Across 10000 keys, each of 4 entries is sampled within ~5% of uniform.""" + env = make_jax_env(CC4_STOCK, phase_boundary_bank=DEFAULT_BANK) + keys = jax.random.split(jax.random.PRNGKey(123), 10000) + _, state = jax.vmap(env.reset)(keys) + # Use phase_boundaries[1] (the unique phase-1 start) as the bucket key — + # the four bank entries have phase 1 starts {167, 100, 200, 150}. + p1_starts = np.asarray(state.const.phase_boundaries[:, 1]) + targets = np.array([t[1] for t in PHASE_BOUNDARIES_BANK]) + counts = np.array([(p1_starts == t).sum() for t in targets]) + expected = 10000 / len(targets) + chi2 = float(((counts - expected) ** 2 / expected).sum()) + # df=3 chi-square critical at α=0.001 is 16.27. 0.001 picks up real bias + # while leaving room for natural sampling jitter. + assert chi2 < 16.27, f"phase-boundary bank under-uniform: counts={counts.tolist()}, chi2={chi2:.2f}" + + +def test_empty_bank_fast_path(): + env = make_jax_env(CC4_STOCK, phase_boundary_bank=[]) + _, state = env.reset(jax.random.PRNGKey(0)) + pb = np.asarray(state.const.phase_boundaries) + # Identical to no-bank canonical + env_ref = make_jax_env(CC4_STOCK) + _, state_ref = env_ref.reset(jax.random.PRNGKey(0)) + np.testing.assert_array_equal(pb, np.asarray(state_ref.const.phase_boundaries)) + + +def test_invalid_shape_raises(): + with pytest.raises(ValueError, match="phase_boundary_bank"): + make_jax_env(CC4_STOCK, phase_boundary_bank=[[0, 100], [0, 200]]) + + +def test_recipe_projection(): + from jaxborg.recipe import project_jax + + recipe = { + "meta": {"name": "test"}, + "algorithm": "ippo", + "core": {"lr": 1e-3, "gamma": 0.99, "gae_lambda": 0.95}, + "arch": {"name": "shared"}, + "train": { + "episode_length": 500, + "total_timesteps": 1000, + "phase_boundary_bank": [[0, 100, 300], [0, 200, 400]], + }, + } + cfg = project_jax(recipe) + assert cfg["PHASE_BOUNDARY_BANK"] == [[0, 100, 300], [0, 200, 400]] diff --git a/tests/test_phase_rewards_bank.py b/tests/test_phase_rewards_bank.py new file mode 100644 index 0000000..18bf7d3 --- /dev/null +++ b/tests/test_phase_rewards_bank.py @@ -0,0 +1,147 @@ +"""Tests for the phase_rewards crown-jewel rotation bank (Phase 6 / P3).""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest + +from jaxborg.constants import SUBNET_IDS +from jaxborg.evaluation.jax_env_factory import make_jax_env +from jaxborg.scenarios.cc4.game_variants import CC4_STOCK +from jaxborg.scenarios.cc4.topology_numpy import ( + _build_phase_rewards, + get_phase_rewards_bank, +) + + +def test_bank_metadata(): + bank = get_phase_rewards_bank() + # (N, MISSION_PHASES, NUM_SUBNETS, 3) + assert bank.ndim == 4 + assert bank.shape[1:] == (3, 9, 3) # MISSION_PHASES=3, NUM_SUBNETS=9, components=3 + assert bank.dtype == np.float32 + assert bank.shape[0] >= 2, "bank needs at least the canonical entry plus one rotation" + + +def test_bank_entry_0_matches_canonical(): + """bank[0] is the canonical phase_rewards table — legacy reproducibility.""" + bank = get_phase_rewards_bank() + canonical = _build_phase_rewards() + np.testing.assert_array_equal(bank[0], canonical) + + +def test_bank_entries_are_distinct(): + """No two bank entries should be identical — otherwise sampling is wasted.""" + bank = get_phase_rewards_bank() + n = bank.shape[0] + for i in range(n): + for j in range(i + 1, n): + assert not np.array_equal(bank[i], bank[j]), f"bank[{i}] == bank[{j}]" + + +def test_swap_entry_rotates_ops_a_ops_b(): + """Bank entry 1 should swap OPS_A ↔ OPS_B in phases 1 and 2.""" + bank = get_phase_rewards_bank() + canonical = bank[0] + swapped = bank[1] + OA, OB = SUBNET_IDS["OPERATIONAL_ZONE_A"], SUBNET_IDS["OPERATIONAL_ZONE_B"] + np.testing.assert_array_equal(swapped[1, OA], canonical[1, OB]) + np.testing.assert_array_equal(swapped[1, OB], canonical[1, OA]) + np.testing.assert_array_equal(swapped[2, OA], canonical[2, OB]) + np.testing.assert_array_equal(swapped[2, OB], canonical[2, OA]) + + +def test_env_uses_bank_when_supplied(): + """Single-entry custom bank → const.phase_rewards == that entry.""" + canonical = _build_phase_rewards() + custom = canonical.copy() + custom[0, 0, 0] = 42.0 # arbitrary marker + env = make_jax_env(CC4_STOCK, phase_rewards_bank=[custom]) + _, state = env.reset(jax.random.PRNGKey(0)) + np.testing.assert_array_equal(np.asarray(state.const.phase_rewards), custom) + + +def test_no_bank_preserves_canonical_phase_rewards(): + env = make_jax_env(CC4_STOCK) + _, state = env.reset(jax.random.PRNGKey(0)) + np.testing.assert_array_equal(np.asarray(state.const.phase_rewards), _build_phase_rewards()) + + +def test_determinism_same_key_same_entry(): + env = make_jax_env(CC4_STOCK, phase_rewards_bank=get_phase_rewards_bank()) + key = jax.random.PRNGKey(11) + _, state_a = env.reset(key) + _, state_b = env.reset(key) + np.testing.assert_array_equal( + np.asarray(state_a.const.phase_rewards), + np.asarray(state_b.const.phase_rewards), + ) + + +def test_distribution_samples_multiple_entries(): + """Across many resets, ≥3 distinct bank entries should be observed.""" + bank = get_phase_rewards_bank() + env = make_jax_env(CC4_STOCK, phase_rewards_bank=bank) + keys = jax.random.split(jax.random.PRNGKey(99), 256) + _, state = jax.vmap(env.reset)(keys) + pr = np.asarray(state.const.phase_rewards) # (256, MP, NS, 3) + signatures = {tuple(slab.flatten().tolist()) for slab in pr} + assert len(signatures) >= 3, f"only {len(signatures)} distinct bank entries hit in 256 resets" + + +def test_composes_with_mission_bank(): + """Phase-rewards bank applies BEFORE mission multiplier — both transform.""" + bank = get_phase_rewards_bank() + env = make_jax_env( + CC4_STOCK, + phase_rewards_bank=[bank[1]], # single non-canonical entry + mission_bank=[[3.0, 1.0, 1.0]], # single LWF=3 multiplier + ) + _, state = env.reset(jax.random.PRNGKey(0)) + expected = bank[1].copy() + expected[..., 0] *= 3.0 # LWF column + np.testing.assert_allclose(np.asarray(state.const.phase_rewards), expected, rtol=1e-5) + + +def test_invalid_shape_raises(): + bad = np.zeros((3, 9), dtype=np.float32) # missing the (3,) component axis + with pytest.raises(ValueError, match="phase_rewards_bank"): + make_jax_env(CC4_STOCK, phase_rewards_bank=[bad]) + + +def test_recipe_projection_bool_form(): + from jaxborg.recipe import project_jax + + recipe = { + "meta": {"name": "test"}, + "algorithm": "ippo", + "core": {"lr": 1e-3, "gamma": 0.99, "gae_lambda": 0.95}, + "arch": {"name": "shared"}, + "train": { + "episode_length": 500, + "total_timesteps": 1000, + "phase_rewards_bank": True, + }, + } + cfg = project_jax(recipe) + assert cfg["PHASE_REWARDS_BANK"] is not None + assert cfg["PHASE_REWARDS_BANK"].ndim == 4 + + +def test_recipe_projection_false_form(): + from jaxborg.recipe import project_jax + + recipe = { + "meta": {"name": "test"}, + "algorithm": "ippo", + "core": {"lr": 1e-3, "gamma": 0.99, "gae_lambda": 0.95}, + "arch": {"name": "shared"}, + "train": { + "episode_length": 500, + "total_timesteps": 1000, + "phase_rewards_bank": False, + }, + } + cfg = project_jax(recipe) + assert cfg["PHASE_REWARDS_BANK"] is None diff --git a/tests/test_recipe_phase6_smoke.py b/tests/test_recipe_phase6_smoke.py new file mode 100644 index 0000000..9448915 --- /dev/null +++ b/tests/test_recipe_phase6_smoke.py @@ -0,0 +1,129 @@ +"""Phase 6 C11 integration smoke test. + +Builds the env that arm C11 (both topology bank + mission bank) trains on, +resets a vmapped batch, steps through one full episode with sampled actions, +and asserts: + + 1. No NaN/Inf in rewards or observations. + 2. Both banks were exercised — at least 3 distinct topology snapshots and + at least 3 distinct mission-multiplier triples were observed across the + batch (catches a regression where one bank silently degrades to a + singleton, e.g. PRNG splitting bug). + +Uses a smaller batch than the production training run (64 envs × 50 steps +vs 1024 × 500) to keep the test fast under CPU JAX. Bank diversity is still +observable: 16-entry topology bank with 64 draws → expected ~4 hits/entry +under uniform sampling; same logic for the 4-entry mission bank. +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jaxborg.evaluation.jax_env_factory import make_jax_env +from jaxborg.recipe import load, project_jax + +pytestmark = pytest.mark.slow + +NUM_ENVS = 64 +NUM_STEPS = 50 + + +def _build_c11_env(): + recipe = load("cec_phase6_C11") + cfg = project_jax(recipe) + + topology_bank = cfg.get("TOPOLOGY_BANK") or None + return cfg, make_jax_env( + cfg["TRAIN_VARIANT"], + training_mode=True, + topology_path=list(topology_bank) if topology_bank else None, + mission_bank=cfg.get("MISSION_BANK"), + mission_bank_amplify=cfg.get("MISSION_BANK_AMPLIFY", 1.0), + ) + + +def test_c11_smoke_no_nan_and_bank_diversity(): + cfg, env = _build_c11_env() + + assert cfg.get("TOPOLOGY_BANK"), "C11 must have a non-empty TOPOLOGY_BANK" + assert cfg.get("MISSION_BANK"), "C11 must have a non-empty MISSION_BANK" + assert len(cfg["TOPOLOGY_BANK"]) == 16 + assert len(cfg["MISSION_BANK"]) == 4 + + keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS) + obs, state = jax.vmap(env.reset)(keys) + + # Bank diversity check, post-reset. + # Topology: const.host_subnet is one of 16 distinct snapshots. + host_subnets = np.asarray(state.const.host_subnet) # (NUM_ENVS, GLOBAL_MAX_HOSTS) + topo_signatures = {tuple(row.tolist()) for row in host_subnets} + assert len(topo_signatures) >= 3, ( + f"topology bank under-sampled: only {len(topo_signatures)} distinct snapshots " + f"in {NUM_ENVS} resets (expected ≥3 from a 16-entry bank)" + ) + + # Mission: const.phase_rewards is the unscaled bank entry × multiplier triple. + # Extract phase 0, subnet 0 — its (LWF, ASF, RIA) row encodes the active triple + # up to a per-snapshot constant. Stacking across the batch yields the + # diversity signal regardless of which topology snapshot was drawn. + pr = np.asarray(state.const.phase_rewards) # (NUM_ENVS, MISSION_PHASES, NUM_SUBNETS, 3) + + # Normalize per-env by the per-snapshot baseline so the mission multiplier + # ratio is what we count. Use the first non-zero ratio across phases/subnets + # as the signature. + def _ratio_signature(env_idx: int) -> tuple[float, float, float]: + slab = pr[env_idx] # (PHASES, SUBNETS, 3) + # find a (phase, subnet) where the row is not all zero + nonzero = np.nonzero(np.linalg.norm(slab, axis=-1)) + if nonzero[0].size == 0: + return (0.0, 0.0, 0.0) + p, s = int(nonzero[0][0]), int(nonzero[1][0]) + triple = slab[p, s] + # quantize to 4 decimals to defang float dust + peak = float(np.max(np.abs(triple))) or 1.0 + return tuple(round(float(x) / peak, 4) for x in triple) + + mission_signatures = {_ratio_signature(i) for i in range(NUM_ENVS)} + assert len(mission_signatures) >= 3, ( + f"mission bank under-sampled: only {len(mission_signatures)} distinct triples " + f"in {NUM_ENVS} resets (expected ≥3 from a 4-entry bank): {mission_signatures}" + ) + + # JIT'd scan rollout — random actions per env per step, accumulating + # max(|reward|) and max(|obs|) so we can assert finiteness once at the end. + agents = list(env.agents) + action_dim = env.action_space(agents[0]).n + + @jax.jit + def _rollout(state, rng): + def _step(carry, _): + state, rng, max_r, max_o = carry + rng, rng_act, rng_step = jax.random.split(rng, 3) + act_keys = jax.random.split(rng_act, NUM_ENVS) + actions = { + a: jax.vmap( + lambda k, off=i: jax.random.randint(jax.random.fold_in(k, off), (), 0, action_dim, dtype=jnp.int32) + )(act_keys) + for i, a in enumerate(agents) + } + step_keys = jax.random.split(rng_step, NUM_ENVS) + obs, state, rewards, dones, info = jax.vmap(env.step)(step_keys, state, actions) + r_stack = jnp.stack([rewards[a] for a in agents]) + o_stack = jnp.concatenate([obs[a].reshape(-1) for a in agents]) + max_r = jnp.maximum(max_r, jnp.max(jnp.abs(r_stack))) + max_o = jnp.maximum(max_o, jnp.max(jnp.abs(o_stack))) + return (state, rng, max_r, max_o), None + + init_carry = (state, rng, jnp.float32(0.0), jnp.float32(0.0)) + (final_state, _, max_r, max_o), _ = jax.lax.scan(_step, init_carry, None, NUM_STEPS) + return final_state, max_r, max_o + + _, max_r, max_o = _rollout(state, jax.random.PRNGKey(1)) + max_r = float(max_r) + max_o = float(max_o) + assert np.isfinite(max_r), f"NaN/Inf in rewards across rollout (max |r| = {max_r})" + assert np.isfinite(max_o), f"NaN/Inf in observations across rollout (max |o| = {max_o})" diff --git a/tests/test_topology_bank_sampling.py b/tests/test_topology_bank_sampling.py new file mode 100644 index 0000000..8a2459f --- /dev/null +++ b/tests/test_topology_bank_sampling.py @@ -0,0 +1,124 @@ +"""Tests for the Phase 6 Axis A topology-shape bank. + +Validates: +* `_select_const` is deterministic in the PRNG key. +* `_select_const` samples uniformly across the bank (chi-square @ α=0.01). +* All 16 emitted snapshots load cleanly via ``load_topology``. +* All 16 snapshots pass ``_validate_resilience_topology(CIA_RESILIENCE)``. +* Resetting an env wired to the full bank produces multiple distinct + ``host_subnet`` arrays (i.e. the bank actually drives const variation). +""" + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jaxborg.env import ScenarioEnv +from jaxborg.evaluation.jax_env_factory import _validate_resilience_topology, make_jax_env +from jaxborg.scenarios.cc4.game_variants import CIA_RESILIENCE +from jaxborg.scenarios.cc4.topology import load_topology + +REPO_ROOT = Path(__file__).resolve().parents[1] +BANK_DIR = REPO_ROOT / "scripts" / "dev" / "topology_bank" +BUILDER = REPO_ROOT / "scripts" / "dev" / "build_topology_bank.py" +EXPECTED_COUNT = 16 + + +def _bank_paths() -> list[Path]: + return sorted(BANK_DIR.glob("shape_*.snapshot.npz")) + + +@pytest.fixture(scope="module") +def bank_paths() -> list[Path]: + paths = _bank_paths() + if len(paths) != EXPECTED_COUNT: + # Build the bank if missing/incomplete so the test is hermetic. + subprocess.check_call( + [ + sys.executable, + str(BUILDER), + "--out-dir", + str(BANK_DIR), + "--count", + str(EXPECTED_COUNT), + "--seed", + "0", + ], + cwd=REPO_ROOT, + ) + paths = _bank_paths() + assert len(paths) == EXPECTED_COUNT, f"expected {EXPECTED_COUNT} snapshots, got {len(paths)}" + return paths + + +def test_select_const_is_deterministic(bank_paths: list[Path]) -> None: + """Same PRNG key → same bank index, same const tree.""" + env = ScenarioEnv(topology_path=bank_paths[:4]) + key = jax.random.PRNGKey(123) + a = env._select_const(key) + b = env._select_const(key) + assert jnp.array_equal(a.host_subnet, b.host_subnet) + assert jnp.array_equal(a.data_links, b.data_links) + assert jnp.array_equal(a.allowed_subnet_pairs, b.allowed_subnet_pairs) + + +def test_select_const_samples_uniformly(bank_paths: list[Path]) -> None: + """Across many keys with a 4-entry bank, distribution is ~uniform. + + Uses chi-square goodness-of-fit at α=0.01 (df=3, critical ≈ 11.345). + """ + bank_size = 4 + n_samples = 10000 + + keys = jax.random.split(jax.random.PRNGKey(7), n_samples) + # Mirror the bank-index draw inside ``_select_const`` exactly: a single + # ``randint`` over [0, bank_size) per key. + indices_fn = jax.jit(jax.vmap(lambda k: jax.random.randint(k, (), 0, bank_size))) + indices = np.asarray(indices_fn(keys)) + + counts = np.bincount(indices, minlength=bank_size) + expected = n_samples / bank_size + chi2 = float(((counts - expected) ** 2 / expected).sum()) + + # df = bank_size - 1 = 3; α=0.01 critical ≈ 11.345. + chi2_crit_001 = 11.345 + assert chi2 < chi2_crit_001, ( + f"chi-square={chi2:.3f} (counts={counts.tolist()}) exceeds α=0.01 critical {chi2_crit_001}; " + "sampling is non-uniform" + ) + + # Also sanity-check that no bin is more than 5% off expected — a stronger + # check that's redundant with chi-square but easier to interpret. + max_dev = max(abs(c - expected) / expected for c in counts) + assert max_dev < 0.05, f"max deviation {max_dev:.4f} exceeds 5% — counts={counts.tolist()}" + + +def test_each_snapshot_loads_cleanly(bank_paths: list[Path]) -> None: + for p in bank_paths: + const = load_topology(p) + # Sanity: const has the expected shape ABI. + assert const.host_subnet.shape == (const.host_active.shape[0],) + + +def test_each_snapshot_passes_resilience_validator(bank_paths: list[Path]) -> None: + # Single batched call to ensure every snapshot would be acceptable + # under the strictest variant we care about. + _validate_resilience_topology(CIA_RESILIENCE, bank_paths) + + +def test_env_with_full_bank_produces_distinct_consts(bank_paths: list[Path]) -> None: + """Reset 32 times across the full 16-shape bank → ≥3 distinct host_subnet arrays.""" + env = make_jax_env(CIA_RESILIENCE, topology_path=list(bank_paths)) + keys = jax.random.split(jax.random.PRNGKey(2026), 32) + seen: set[bytes] = set() + for k in keys: + _, state = env.reset(k) + seen.add(np.asarray(state.const.host_subnet).tobytes()) + assert len(seen) >= 3, f"expected ≥3 distinct host_subnet arrays across 32 resets, got {len(seen)}"