Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ uv run python scripts/eval/cc4_trajectory_eval.py \
--recipe recipes/resilience.yaml

# 2. Score them (CIA + resilience)
uv run python scripts/eval/cc4_score_trajectories.py trajs/resilience_seed42 \
uv run python scripts/eval/score_trajectories.py trajs/resilience_seed42 \
--recipe recipes/resilience.yaml

```
Expand Down
2 changes: 1 addition & 1 deletion recipes/resilience.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ cleanrl:
num_epochs: 4

eval:
# CIA scorer for cc4_score_trajectories.py.
# CIA scorer for score_trajectories.py.
# Only "resilience" is currently registered (auth/db/web role impacts).
cia_metric: resilience

Expand Down
94 changes: 94 additions & 0 deletions scripts/dev/check_red_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Empirical check: does each registered red selector bias as advertised?

For each selector, roll out N episodes of a sleep-blue env and count which
hosts red actually attacked. Group counts by role (NONE / AUTH / DB / WEB) and
print share-of-attacks per role.

Expectations from the PR description:
fsm — ~92% attacks on NONE (untagged) hosts (no bias)
resilience — AUTH+DB+WEB share lifts ~5× over fsm (target_weight=5)
cia_c — AUTH+DB share lifts ~10× over fsm; tagged hosts at FSM_R get
Impact/Degrade vs vanilla's Discover (action shift visible if
we count action types — out of scope here)
cia_i — AUTH+WEB share lifts ~10×
cia_a — AUTH+DB+WEB share lifts ~10×

This script doesn't enforce thresholds — it prints, you eyeball.
"""

from __future__ import annotations

# ruff: noqa: E402
import os
import sys
from collections import Counter
from pathlib import Path

os.environ.setdefault("JAX_PLATFORMS", "cpu")

ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT / "src"))

import jax
import jax.numpy as jnp

from jaxborg.parity.fsm_red_env import make_fsm_red_env
from jaxborg.scenarios.cc4.topology_roles import ROLE_AUTH, ROLE_DB, ROLE_NONE, ROLE_WEB

NUM_EPISODES = 3
EPISODE_STEPS = 30 # short — enough for bias to surface in expectation
SELECTORS = ("fsm", "resilience", "cia_c", "cia_i", "cia_a")
_ROLE_LABEL = {ROLE_NONE: "NONE", ROLE_AUTH: "AUTH", ROLE_DB: "DB", ROLE_WEB: "WEB"}


def rollout_one(env, key) -> Counter:
"""Run one episode; return counter of (role, attacked) for each red attack step."""
obs, state = env.reset(key)
counts: Counter = Counter()
blue_actions = {a: jnp.int32(0) for a in env.agents}
roles = state.extras["host_resilience_role"] # (GLOBAL_MAX_HOSTS,) int32

step_fn = jax.jit(env.step)
for _ in range(EPISODE_STEPS):
key, sk = jax.random.split(key)
obs, state, _rew, dones, _info = step_fn(sk, state, blue_actions)
# Each red agent's pending_target_host is the host they currently target.
# Active agents only: red_agent_active[r] && red_pending_ticks[r] >= 0.
targets = state.state.red_pending_target_host # (NUM_RED,)
active = state.state.red_agent_active # (NUM_RED,)
for r in range(targets.shape[0]):
if not bool(active[r]):
continue
host_idx = int(targets[r])
role = int(roles[host_idx])
counts[role] += 1
if bool(dones["__all__"]):
break
return counts


def main():
print(f"\n{'selector':<12s} {'NONE':>10s} {'AUTH':>10s} {'DB':>10s} {'WEB':>10s} {'tagged%':>8s}")
print("-" * 70)
print(
"(All selectors run with role_assignment='resilience' so the baseline is\n"
" apples-to-apples — fsm shows uniform-attack share on tagged hosts.)\n"
)
for name in SELECTORS:
# Force role assignment for all selectors so "NONE %" reflects the same
# underlying tag set; biased rows then show their lift vs the fsm baseline.
env = make_fsm_red_env(num_steps=EPISODE_STEPS, red_agent=name, role_assignment="resilience")
agg: Counter = Counter()
for ep in range(NUM_EPISODES):
agg += rollout_one(env, jax.random.PRNGKey(1000 + ep))
total = sum(agg.values()) or 1
tagged = agg[ROLE_AUTH] + agg[ROLE_DB] + agg[ROLE_WEB]
share = lambda r: f"{agg[r] / total * 100:>9.1f}%" # noqa: E731
print(
f"{name:<12s} {share(ROLE_NONE)} {share(ROLE_AUTH)} "
f"{share(ROLE_DB)} {share(ROLE_WEB)} {tagged / total * 100:>7.1f}%"
)


if __name__ == "__main__":
main()
22 changes: 8 additions & 14 deletions scripts/eval/cc4_trajectory_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
other availability-event metrics derivable post-hoc without re-rolling out)
- one `footer` record (total reward, steps)

Trajectories are then scored post-hoc by `cc4_score_trajectories.py`. Decoupling
Trajectories are then scored post-hoc by `score_trajectories.py`. Decoupling
rollout from scoring lets us re-evaluate with new alignment metrics without
re-running CybORG (CPU-bound, ~2 min/episode).
"""
Expand All @@ -33,7 +33,7 @@

from jaxborg.constants import BLUE_OBS_SIZE
from jaxborg.evaluation.cyborg_runner import load_torch_policy
from jaxborg.scenarios.cc4.topology_roles import assign_resilience_roles
from jaxborg.scenarios.cc4.cyborg_resilience_agents import inject_role_map

NUM_AGENTS = 5
AGENT_IDS = [f"blue_agent_{i}" for i in range(NUM_AGENTS)]
Expand All @@ -44,23 +44,16 @@

def make_env(seed, red_agent: str = "finite_state", target_weight: float = 5.0):
from CybORG import CybORG
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
from CybORG.Agents import EnterpriseGreenAgent, SleepAgent
from CybORG.Agents.Wrappers import EnterpriseMAE
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.scenarios.cc4.cyborg_resilience_agents import ResilienceRedAgent

_red_classes = {
"finite_state": FiniteStateRedAgent,
"sleep": SleepAgent,
"resilience": ResilienceRedAgent.with_weight(target_weight),
}
red_cls = _red_classes.get(red_agent, FiniteStateRedAgent)
from jaxborg.evaluation.cyborg_red_dispatch import cyborg_red_class

sg = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=red_cls,
red_agent_class=cyborg_red_class(red_agent, target_weight),
steps=EPISODE_LENGTH,
)
return EnterpriseMAE(CybORG(sg, "sim", seed=seed))
Expand Down Expand Up @@ -218,8 +211,9 @@ def evaluate(model_path, episodes, seed, deterministic, output_dir, tag, recipe_
env = make_env(ep_seed, red_agent=red_agent, target_weight=target_weight)
resilience_roles = None
if resilience_mode:
hosts = list(env.unwrapped.environment_controller.state.hosts.keys())
resilience_roles = assign_resilience_roles(hosts)
# Build per-episode role map from the env's full host list and push
# it into every red agent so red bias and recorded roles agree.
resilience_roles = inject_role_map(env, ep_seed)
out_path = output_dir / f"{tag}_seed{ep_seed}.jsonl"
r, n = rollout_episode(env, agent, device, deterministic, ep_seed, model_path, out_path, resilience_roles)
rewards.append(r)
Expand Down
35 changes: 25 additions & 10 deletions scripts/eval/eval_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def main():
parser.add_argument("--episodes", type=int, default=10, help="Episodes per seed")
parser.add_argument("--seeds", type=str, default="42-51", help="e.g. '42-51' or '42,43,44'")
parser.add_argument("--deterministic", action="store_true")
parser.add_argument(
"--workers",
type=int,
default=max(1, (os.cpu_count() or 4) - 2),
help="Parallel rollout workers (1 = single process). Default: cpu_count() - 2.",
)
parser.add_argument("--output", type=str, default=None, help="Override result jsonl path")
args = parser.parse_args()

Expand All @@ -97,9 +103,7 @@ def main():
seeds = _parse_seeds(args.seeds)

if trained_backend == "cyborg":
import torch

from jaxborg.evaluation.cyborg_runner import evaluate_on_cyborg, load_torch_policy_from_recipe
from jaxborg.evaluation.cyborg_runner import evaluate_on_cyborg
from jaxborg.recipe import project_eval

recipe = read_sidecar(model_path)
Expand All @@ -109,36 +113,47 @@ def main():
print(f"Loaded recipe sidecar: {recipe.get('meta', {}).get('name', '?')}", flush=True)
print(
f" trained=cyborg arch={recipe['arch']['name']} seeds={seeds} "
f"eps/seed={args.episodes} red_agent={red_agent}",
f"eps/seed={args.episodes} red_agent={red_agent} workers={args.workers}",
flush=True,
)

state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
agent = load_torch_policy_from_recipe(recipe, state_dict)

t0 = time.perf_counter()
rewards, seed_log = evaluate_on_cyborg(
agent,
model_path,
seeds=seeds,
episodes_per_seed=args.episodes,
deterministic=args.deterministic,
red_agent=red_agent,
target_weight=target_weight,
workers=args.workers,
)
wall = time.perf_counter() - t0
else:
from jaxborg.evaluation.jax_runner import evaluate_jax_on_cyborg
from jaxborg.recipe import project_eval

recipe = read_sidecar(model_path)
eval_cfg = project_eval(recipe)
red_agent = eval_cfg["red_agent"]
target_weight = eval_cfg["resilience_target_weight"]

t0 = time.perf_counter()
rewards, seed_log, recipe = evaluate_jax_on_cyborg(
model_path,
seeds=seeds,
episodes_per_seed=args.episodes,
deterministic=args.deterministic,
red_agent=red_agent,
target_weight=target_weight,
workers=args.workers,
)
wall = time.perf_counter() - t0
print(f"Loaded recipe (sidecar or fallback): {recipe.get('meta', {}).get('name', '?')}", flush=True)
print(f" trained=jax arch={recipe['arch']['name']} seeds={seeds} eps/seed={args.episodes}", flush=True)
print(
f" trained=jax arch={recipe['arch']['name']} seeds={seeds} "
f"eps/seed={args.episodes} red_agent={red_agent} workers={args.workers}",
flush=True,
)

m = mean(rewards)
s = stdev(rewards) if len(rewards) > 1 else 0.0
Expand All @@ -152,7 +167,7 @@ def main():
"recipe_path": recipe.get("meta", {}).get("source_path") or recipe.get("__source_path__", ""),
"trained_backend": trained_backend,
"eval_env": "cyborg",
"red_agent": red_agent if trained_backend == "cyborg" else "finite_state",
"red_agent": red_agent,
"seeds": seeds,
"episodes_per_seed": args.episodes,
"stochastic": not args.deterministic,
Expand Down
36 changes: 21 additions & 15 deletions scripts/train/algorithms/ippo_cyborg.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,47 +52,53 @@

def make_cyborg_env(red_agent: str = "finite_state", target_weight: float = 5.0):
from CybORG import CybORG
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
from CybORG.Agents import EnterpriseGreenAgent, SleepAgent
from CybORG.Agents.Wrappers import EnterpriseMAE
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.scenarios.cc4.cyborg_resilience_agents import ARedAgent, CRedAgent, IRedAgent, ResilienceRedAgent

_red_classes = {
"finite_state": FiniteStateRedAgent,
"sleep": SleepAgent,
"resilience": ResilienceRedAgent.with_weight(target_weight),
"c": CRedAgent.with_weight(target_weight),
"i": IRedAgent.with_weight(target_weight),
"a": ARedAgent.with_weight(target_weight),
}
red_cls = _red_classes.get(red_agent, FiniteStateRedAgent)
from jaxborg.evaluation.cyborg_red_dispatch import cyborg_red_class

sg = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=red_cls,
red_agent_class=cyborg_red_class(red_agent, target_weight),
steps=500,
)
return EnterpriseMAE(CybORG(scenario_generator=sg))


def env_worker(pipe, env_id, red_agent: str = "finite_state", target_weight: float = 5.0):
import random as _random

from jaxborg.scenarios.cc4.cyborg_resilience_agents import inject_role_map

signal.signal(signal.SIGINT, signal.SIG_IGN)
env = make_cyborg_env(red_agent=red_agent, target_weight=target_weight)
# Per-worker RNG for per-episode resilience-role seeds. Distinct per worker
# so vmap-equivalent envs see different role-map sequences. Reproducible
# given env_id.
seed_rng = _random.Random(env_id)
needs_roles = red_agent in ("resilience", "c", "i", "a", "cia_c", "cia_i", "cia_a")

def _reset_and_inject():
obs, info = env.reset()
if needs_roles:
inject_role_map(env, ep_seed=seed_rng.randrange(2**31))
return obs, info

while True:
try:
cmd, data = pipe.recv()
except EOFError:
break
if cmd == "reset":
obs, info = env.reset()
obs, info = _reset_and_inject()
pipe.send((obs, info))
elif cmd == "step":
obs, rew, term, trunc, info = env.step(data)
done = any(term.values()) or any(trunc.values())
if done:
obs, info = env.reset()
obs, info = _reset_and_inject()
pipe.send((obs, rew, done, info))
elif cmd == "close":
pipe.close()
Expand Down
32 changes: 12 additions & 20 deletions scripts/train/algorithms/ippo_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
from jaxborg.checkpoint import save_jax_params, write_sidecar
from jaxborg.metrics_schema import make_row
from jaxborg.mlflow_setup import start_run
from jaxborg.parity.fsm_red_env import FsmRedCC4Env
from jaxborg.parity.resilience_red_env import ResilienceRedCC4Env
from jaxborg.parity.fsm_red_env import make_fsm_red_env
from jaxborg.policies import make_jax_policy
from jaxborg.recipe import load as load_recipe
from jaxborg.recipe import project_jax
Expand Down Expand Up @@ -86,18 +85,13 @@ def compute_value_loss(value, old_value, targets, clip_eps, clip_value_loss):
def make_train(config, network):
"""Build env and a single JIT'd collect_and_update fn from a flat config."""
num_envs = config["NUM_ENVS"]
_env_kwargs = dict(
_red_agent = config.get("RED_AGENT") or ("resilience" if config.get("RESILIENCE_MODE", False) else "fsm")
inner_env = make_fsm_red_env(
num_steps=500,
training_mode=bool(config.get("TRAINING_MODE", True)),
red_agent=_red_agent,
target_weight=float(config.get("RESILIENCE_TARGET_WEIGHT", 5.0)),
)
_red_agent = config.get("RED_AGENT", "")
if _red_agent in ("c", "i", "a"):
inner_env = ResilienceRedCC4Env(cia_target=_red_agent, **_env_kwargs)
elif _red_agent == "resilience" or config.get("RESILIENCE_MODE", False):
_env_kwargs["target_weight"] = float(config.get("RESILIENCE_TARGET_WEIGHT", 5.0))
inner_env = ResilienceRedCC4Env(**_env_kwargs)
else:
inner_env = FsmRedCC4Env(**_env_kwargs)
agents = list(inner_env.agents)
num_agents = inner_env.num_agents
config["NUM_ACTORS"] = num_agents * num_envs
Expand Down Expand Up @@ -341,15 +335,13 @@ def main():
print(f"XLA compilation cache: {cache_dir}", flush=True)

# Build a throwaway env to get action_dim for network init.
_probe_kwargs = dict(num_steps=500, topology_mode=config.get("TOPOLOGY_MODE", "generative"))
_red_agent = config.get("RED_AGENT", "")
if _red_agent in ("c", "i", "a"):
inner_env = ResilienceRedCC4Env(cia_target=_red_agent, **_probe_kwargs)
elif _red_agent == "resilience" or config.get("RESILIENCE_MODE", False):
_probe_kwargs["target_weight"] = float(config.get("RESILIENCE_TARGET_WEIGHT", 5.0))
inner_env = ResilienceRedCC4Env(**_probe_kwargs)
else:
inner_env = FsmRedCC4Env(**_probe_kwargs)
_red_agent = config.get("RED_AGENT") or ("resilience" if config.get("RESILIENCE_MODE", False) else "fsm")
inner_env = make_fsm_red_env(
num_steps=500,
topology_mode=config.get("TOPOLOGY_MODE", "generative"),
red_agent=_red_agent,
target_weight=float(config.get("RESILIENCE_TARGET_WEIGHT", 5.0)),
)
action_dim = inner_env.action_space(inner_env.agents[0]).n
network = make_jax_policy(
recipe["arch"]["name"],
Expand Down
Loading
Loading